I have a dataset that looks something like the one below.
from datetime import date

# use datetime.date literals here; to_date() builds a Column expression
# and cannot be applied to plain Python values inside createDataFrame
df = spark.createDataFrame(
    [
        ("001A", 105, "foo", date(2022, 1, 1)),
        ("001A", 25, "foo", date(2022, 1, 1)),
        ("002B", 85, "bar", date(2022, 1, 15)),
        ("002B", 15, "bar", date(2022, 1, 15)),
    ],
    ["id", "num_col1", "str_col1", "date_col1"],
)
df.show()
+----+--------+--------+----------+
|  id|num_col1|str_col1| date_col1|
+----+--------+--------+----------+
|001A|     105|     foo|2022-01-01|
|001A|      25|     foo|2022-01-01|
|002B|      85|     bar|2022-01-15|
|002B|      15|     bar|2022-01-15|
+----+--------+--------+----------+
What I want to achieve is an aggregated form of df: group by the values in id, then aggregate across all of the remaining columns in the dataframe. The resulting dataframe would look like this:
+----+--------+--------+----------+
|  id|num_col1|str_col1| date_col1|
+----+--------+--------+----------+
|001A|     130|     foo|2022-01-01|
|002B|     100|     bar|2022-01-15|
+----+--------+--------+----------+
The dataframe contains a mixture of:
- numeric columns - which need to be summed
- string columns - which are always the same within each group - so I just need to take the existing value
- date columns - which are also always the same within each group - so I just need to take the existing value
The dataframe also contains many, many more columns, so any method that involves writing out every single column will not work.
I have searched quite comprehensively across the net but have not found any similar questions or solutions that I could adapt to my data. I am quite new to PySpark, so my attempts have been pretty futile: I tried to use the collect_set function to collapse each group down to a single row, intending to then apply a map function as in Merge multiple spark rows to one, but without success. Roughly, my attempt looked like the sketch below.
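A sketch of that attempt (it runs, but leaves me with array columns rather than the summed/single values I actually want):

import pyspark.sql.functions as F

# collapse every non-id column into a set per group
df_sets = df.groupBy("id").agg(
    *[F.collect_set(c).alias(c) for c in df.columns if c != "id"]
)
df_sets.show()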
CodePudding user response:
You can use df.dtypes to classify the columns by type: group by the string and date columns, and sum the numeric ones.
import pyspark.sql.functions as F

# group by every string/date column (constant within each id group)
# and sum every remaining numeric column, keeping its original name
df = df.groupBy(*[c for c, t in df.dtypes if t in ('string', 'date')]) \
    .agg(*[F.sum(c).alias(c) for c, t in df.dtypes if t not in ('string', 'date')])
df.printSchema()
df.show(truncate=False)
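On the sample data above this prints something like:

root
 |-- id: string (nullable = true)
 |-- str_col1: string (nullable = true)
 |-- date_col1: date (nullable = true)
 |-- num_col1: long (nullable = true)

+----+--------+----------+--------+
|id  |str_col1|date_col1 |num_col1|
+----+--------+----------+--------+
|001A|foo     |2022-01-01|130     |
|002B|bar     |2022-01-15|100     |
+----+--------+----------+--------+

Note that the group-by columns come first, so the column order differs from the input.

If you would rather group strictly by id, a variant (a sketch, starting again from the original df and relying on the stated guarantee that string/date values are constant within each group) is to take F.first of every non-numeric column instead of grouping by it:

import pyspark.sql.functions as F

# sum numeric columns; take the (constant) first value of everything else
aggs = [
    F.sum(c).alias(c) if t not in ('string', 'date') else F.first(c).alias(c)
    for c, t in df.dtypes
    if c != 'id'
]
df = df.groupBy('id').agg(*aggs)
df.show()

This also preserves the original column order: id first, then the remaining columns as they appear in df.dtypes.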