Pyspark - collapse all columns in dataframe by group variable


I have a dataset that looks something like below.

from datetime import date

df = spark.createDataFrame(
    [
        ("001A", 105, "foo", date(2022, 1, 1)),
        ("001A", 25, "foo", date(2022, 1, 1)),
        ("002B", 85, "bar", date(2022, 1, 15)),
        ("002B", 15, "bar", date(2022, 1, 15)),
    ],
    ["id", "num_col1", "str_col1", "date_col1"],
)
df.show()
+----+--------+--------+----------+
|  id|num_col1|str_col1| date_col1|
+----+--------+--------+----------+
|001A|     105|     foo|2022-01-01|
|001A|      25|     foo|2022-01-01|
|002B|      85|     bar|2022-01-15|
|002B|      15|     bar|2022-01-15|
+----+--------+--------+----------+

What I want to achieve is an aggregated form of df: group by the values in id and then aggregate across all the remaining columns in the dataframe. The resulting dataframe would look like this:

+----+--------+--------+----------+
|  id|num_col1|str_col1| date_col1|
+----+--------+--------+----------+
|001A|     130|     foo|2022-01-01|
|002B|     100|     bar|2022-01-15|
+----+--------+--------+----------+

The dataframe contains a mixture of:

  • numeric columns - which need to be summed
  • string columns - which are always the same within each group - so just need to take the existing value
  • date columns - which are also always the same within each group - so just need to take the existing value

The dataframe also contains many, many more columns so any method that involves writing out every single column will not work.
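
For this toy example I could write out the aggregation by hand, something like the sketch below (F.first is just a stand-in for "take the existing value"), but spelling out every column like this is exactly what will not scale to my real data:

import pyspark.sql.functions as F

# Hand-written aggregation for the toy example above:
# sum the numeric column, keep the (constant) string and date values.
df.groupBy("id").agg(
    F.sum("num_col1").alias("num_col1"),
    F.first("str_col1").alias("str_col1"),
    F.first("date_col1").alias("date_col1"),
).show()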

I have searched quite comprehensively but have not found any similar question whose solution I could adapt to work on my data.

I am quite new to PySpark, so my attempts have been fairly futile. I tried using the collect_set function to collapse each group down to one row, intending to then apply a map function as in Merge multiple spark rows to one, but without much success.
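
Roughly what I was attempting (a sketch of the collect_set idea, not something that got me to the result above):

import pyspark.sql.functions as F

# Collapse every non-id column into a set per group, e.g. num_col1 becomes [105, 25];
# the plan was to then map these arrays back down to single values.
collapsed = df.groupBy("id").agg(
    *[F.collect_set(c).alias(c) for c in df.columns if c != "id"]
)
collapsed.show(truncate=False)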

CodePudding user response:

You can use df.dtypes to classify the columns: group by the string and date columns, and sum the numeric columns.

import pyspark.sql.functions as F

# Group by every string/date column (including id) and sum every other column.
df = df.groupBy(*[t[0] for t in df.dtypes if t[1] in ('string', 'date')]) \
    .agg(*[F.sum(t[0]).alias(t[0]) for t in df.dtypes if t[1] not in ('string', 'date')])
df.printSchema()
df.show(truncate=False)
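
If you prefer to group on id alone (relying on the string and date columns being constant within each group), here is a minimal sketch of the same dtypes-driven idea applied to the original df: sum the numeric columns and take F.first of everything else.

import pyspark.sql.functions as F

# Split columns by dtype, leaving id as the only grouping key.
numeric_cols = [c for c, t in df.dtypes if t not in ("string", "date")]
other_cols = [c for c, t in df.dtypes if t in ("string", "date") and c != "id"]

result = df.groupBy("id").agg(
    *[F.sum(c).alias(c) for c in numeric_cols],
    *[F.first(c).alias(c) for c in other_cols],
)
result.show()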