If I have the table below:
+---+-----+---+-----+-----+-----+-----+-----+-----+
|  a|    b| id|m2000|m2001|m2002|m2003|m2004|m2005|
+---+-----+---+-----+-----+-----+-----+-----+-----+
|  a|world|  1|    0|    0|    1|    0|    0|    1|
+---+-----+---+-----+-----+-----+-----+-----+-----+
How do I create a new dataframe, like the one below, that checks columns m2000 to m2014 and sees whether any of these fields are 1? It should produce the table below, where the day/month parts (10/10 and 12/12) are static. 2002 and 2005 are used because they are the only two columns between m2000 and m2014 with a 1 in the table above.
| id | year       | yearend    |
|----|------------|------------|
| 1  | 10/10/2002 | 12/12/2005 |
Code to create the first dataframe:
from pyspark.shell import spark
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data2 = [("a", "world", "1", 0, 0, 1, 0, 0, 1)]

schema = StructType([
    StructField("a", StringType(), True),
    StructField("b", StringType(), True),
    StructField("id", StringType(), True),
    StructField("m2000", IntegerType(), True),
    StructField("m2001", IntegerType(), True),
    StructField("m2002", IntegerType(), True),
    StructField("m2003", IntegerType(), True),
    StructField("m2004", IntegerType(), True),
    StructField("m2005", IntegerType(), True),
])

df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
CodePudding user response:
Assuming a dataframe with a more complete scenario, where some rows have no year set to 1 and some rows have more than two 1s:
from pyspark.shell import spark
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data2 = [("a", "world", "1", 0, 0, 1, 0, 0, 1),
         ("b", "world", "2", 0, 1, 0, 1, 0, 1),
         ("c", "world", "3", 0, 0, 0, 0, 0, 0)
         ]

schema = StructType([
    StructField("a", StringType(), True),
    StructField("b", StringType(), True),
    StructField("id", StringType(), True),
    StructField("m2000", IntegerType(), True),
    StructField("m2001", IntegerType(), True),
    StructField("m2002", IntegerType(), True),
    StructField("m2003", IntegerType(), True),
    StructField("m2004", IntegerType(), True),
    StructField("m2005", IntegerType(), True),
])

df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
|   | a | b     | id | m2000 | m2001 | m2002 | m2003 | m2004 | m2005 |
|---|---|-------|----|-------|-------|-------|-------|-------|-------|
| 0 | a | world | 1  | 0     | 0     | 1     | 0     | 0     | 1     |
| 1 | b | world | 2  | 0     | 1     | 0     | 1     | 0     | 1     |
| 2 | c | world | 3  | 0     | 0     | 0     | 0     | 0     | 0     |
For convenience, I convert your dataframe to pandas, but I use simple iterative structures that you can integrate back into Spark (a sketch of converting the result back follows the output below).
pandas_df = df.toPandas()
We retrieve the list of year columns by skipping the first 3 columns:
years = list(pandas_df.columns)[3:]
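With the sample schema above, this yields the m-prefixed column names:
# ['m2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005']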
Finally, the code needed to generate the required dataframe is as follows (see the inline comments):
import pandas as pd

tmp_df_data_list = []
# iterate over the rows of the dataframe
for _, row in pandas_df.iterrows():
    flagged_years = []
    # for each year column, check whether the value is 1
    for y in years:
        if row[y]:  # if it is 1, append the column name
            flagged_years.append(y)
    if len(flagged_years) >= 2:
        # take the first occurrence as 'year' and the last as 'yearend',
        # stripping the leading 'm' from the column name
        min_year = flagged_years[0][1:]
        max_year = flagged_years[-1][1:]
        tmp_df_data_list.append([row.id, '10/10/' + min_year, '12/12/' + max_year])

res_df = pd.DataFrame(tmp_df_data_list, columns=['id', 'year', 'yearend'])
Output will be:
|   | id | year       | yearend    |
|---|----|------------|------------|
| 0 | 1  | 10/10/2002 | 12/12/2005 |
| 1 | 2  | 10/10/2001 | 12/12/2005 |
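If you need the result back in Spark, the pandas dataframe can be handed straight to spark.createDataFrame. A minimal sketch, assuming the res_df built above is in scope:
# convert the pandas result back to a Spark dataframe; the schema is inferred
res_sdf = spark.createDataFrame(res_df)
res_sdf.show(truncate=False)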
CodePudding user response:
We can use PySpark native functions to create an array of the column names that have the value 1. The array can then be used to get the min and max of the years, which are concatenated with the static "10/10/" and "12/12/" prefixes.
Here's an example:
from pyspark.sql import functions as func

data_ls = [
    ("a", "world", "1", 0, 0, 1, 0, 0, 1),
    ("b", "world", "2", 0, 1, 0, 1, 0, 1),
    ("c", "world", "3", 0, 0, 0, 0, 0, 0)
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['a', 'b', 'id', 'm2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005'])
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|    b| id|m2000|m2001|m2002|m2003|m2004|m2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|world|  1|    0|    0|    1|    0|    0|    1|
# |  b|world|  2|    0|    1|    0|    1|    0|    1|
# |  c|world|  3|    0|    0|    0|    0|    0|    0|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
yearcols = [k for k in data_sdf.columns if k.startswith('m20')]
data_sdf. \
    withColumn('yearcol_structs',
               # build an array of (year, value) structs from the m20xx columns
               func.array(*[func.struct(func.lit(int(c[-4:])).alias('year'), func.col(c).alias('value'))
                            for c in yearcols]
                          )
               ). \
    withColumn('yearcol_1s',
               # keep only the years whose value is 1
               func.expr('transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)')
               ). \
    filter(func.size('yearcol_1s') >= 1). \
    withColumn('year_start', func.concat(func.lit('10/10/'), func.array_min('yearcol_1s'))). \
    withColumn('year_end', func.concat(func.lit('12/12/'), func.array_max('yearcol_1s'))). \
    show(truncate=False)
# +---+-----+---+-----+-----+-----+-----+-----+-----+-------------------------------------------------------------------+------------------+----------+----------+
# |a  |b    |id |m2000|m2001|m2002|m2003|m2004|m2005|yearcol_structs                                                    |yearcol_1s        |year_start|year_end  |
# +---+-----+---+-----+-----+-----+-----+-----+-----+-------------------------------------------------------------------+------------------+----------+----------+
# |a  |world|1  |0    |0    |1    |0    |0    |1    |[{2000, 0}, {2001, 0}, {2002, 1}, {2003, 0}, {2004, 0}, {2005, 1}] |[2002, 2005]      |10/10/2002|12/12/2005|
# |b  |world|2  |0    |1    |0    |1    |0    |1    |[{2000, 0}, {2001, 1}, {2002, 0}, {2003, 1}, {2004, 0}, {2005, 1}] |[2001, 2003, 2005]|10/10/2001|12/12/2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+-------------------------------------------------------------------+------------------+----------+----------+
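Since the question only asks for id, year, and yearend, a final select can drop the helper columns. A minimal sketch, assuming the chained expression above is first assigned to a variable (result_sdf is a name introduced here for illustration):
# keep only the requested columns, renaming them to match the expected output
result_sdf. \
    select('id',
           func.col('year_start').alias('year'),
           func.col('year_end').alias('yearend')). \
    show(truncate=False)

# +---+----------+----------+
# |id |year      |yearend   |
# +---+----------+----------+
# |1  |10/10/2002|12/12/2005|
# |2  |10/10/2001|12/12/2005|
# +---+----------+----------+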