I have a PySpark DataFrame that looks like the one below; the Item column is a map column of type Map<Str,Int>:
Date       | Item (Map<Str,Int>)                    | Total Items | ColA
2021-02-01 | Item_A -> 3, Item_B -> 10, Item_C -> 2 | 15          | 10
2021-02-02 | Item_A -> 1, Item_D -> 5, Item_E -> 7  | 13          | 20
2021-02-03 | Item_A -> 8, Item_E -> 3, Item_C -> 1  | 12          | 30
I want the sum of every column, including the map column. For the map column, the values should be summed per key, so I want something like this:
[[Item_A -> 12, Item_B -> 10, Item_C -> 3, Item_D -> 5, Item_E -> 10], 40, 60]
It doesn't have to be a list of lists; I just want the sum of each column.
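To pin down what "summed per key" means, here is a plain-Python sketch of the expected result for the three sample rows. It assumes each map is represented as a Python dict (which is how PySpark returns a MapType value once it reaches the driver); it is not Spark code, only a statement of the target numbers. `collections.Counter.update` adds values key by key, and a key missing from a row simply contributes nothing.

```python
from collections import Counter

# Sample rows from the question: (Date, Item map, Total Items, ColA)
rows = [
    ("2021-02-01", {"Item_A": 3, "Item_B": 10, "Item_C": 2}, 15, 10),
    ("2021-02-02", {"Item_A": 1, "Item_D": 5, "Item_E": 7}, 13, 20),
    ("2021-02-03", {"Item_A": 8, "Item_E": 3, "Item_C": 1}, 12, 30),
]

# Key-wise sum of the maps: Counter.update adds to existing counts instead of replacing them
item_totals = Counter()
for _, items, _, _ in rows:
    item_totals.update(items)

# Ordinary sums for the scalar columns
total_items = sum(r[2] for r in rows)
col_a = sum(r[3] for r in rows)

print(dict(sorted(item_totals.items())), total_items, col_a)
```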
My approach:
df.rdd.map(lambda x: (1, x[1])).reduceByKey(lambda x, y: x + y).collect()[0][1]
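Python dicts do not support `+` (adding two dicts raises a `TypeError`), so an RDD reduce over the map values needs its own combine function. Below is a minimal sketch, assuming each row has already been turned into a `(item_map, total_items, col_a)` tuple; `add_rows` is a hypothetical helper name. In Spark the same function could presumably be used as `df.rdd.map(lambda r: (r['Item'], r['Total Items'], r['ColA'])).reduce(add_rows)`; here `functools.reduce` stands in for `RDD.reduce` so the sketch runs without a cluster.

```python
from functools import reduce

def add_rows(a, b):
    """Combine two (item_map, total_items, col_a) tuples.

    Hypothetical helper: map values are summed key by key and the scalar
    columns are added. New objects are returned, so inputs are not mutated.
    """
    items_a, total_a, col_a = a
    items_b, total_b, col_b = b
    merged = dict(items_a)  # copy so the caller's dict is left untouched
    for key, value in items_b.items():
        merged[key] = merged.get(key, 0) + value
    return merged, total_a + total_b, col_a + col_b

# Stand-in for the per-row tuples; in Spark these would come from df.rdd.map(...)
rows = [
    ({"Item_A": 3, "Item_B": 10, "Item_C": 2}, 15, 10),
    ({"Item_A": 1, "Item_D": 5, "Item_E": 7}, 13, 20),
    ({"Item_A": 8, "Item_E": 3, "Item_C": 1}, 12, 30),
]

result = reduce(add_rows, rows)
print(result)
```

`RDD.reduce` expects a commutative and associative function; key-wise addition satisfies both, and a plain `reduce` is enough for a single global total, so the constant key used with `reduceByKey` is not needed.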
Answer:
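Aggregate the map column and the other columns separately. Summing the map requires an explode on the Item column, which produces one row per map entry and copies the other columns into each of those rows, so summing Total Items or ColA after the explode would count each original row several times.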
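The overcounting can be seen with a small pure-Python simulation of what `explode` does to the rows: one output row per map entry, with the scalar columns repeated. This is only an illustration of the row duplication using the sample data, not Spark code.

```python
# Sample rows: (Item map, Total Items, ColA)
rows = [
    ({"Item_A": 3, "Item_B": 10, "Item_C": 2}, 15, 10),
    ({"Item_A": 1, "Item_D": 5, "Item_E": 7}, 13, 20),
    ({"Item_A": 8, "Item_E": 3, "Item_C": 1}, 12, 30),
]

# Mimic explode: one (key, value, Total Items, ColA) row per map entry
exploded = [
    (key, value, total, col_a)
    for items, total, col_a in rows
    for key, value in items.items()
]

# Summing a scalar column after the explode counts each original row once per key
total_before_explode = sum(r[1] for r in rows)
total_after_explode = sum(r[2] for r in exploded)
print(len(exploded), total_before_explode, total_after_explode)
```

Every sample row has three map entries, so each Total Items value is counted three times after the explode, which is why the scalar sums are computed before it.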
Example dataframe:
from pyspark.sql import functions as F
df = spark.createDataFrame(
[('2021-02-01', 'Item_A', 3, 'Item_B', 10, 'Item_C', 2, 15, 10),
('2021-02-02', 'Item_A', 1, 'Item_D', 5, 'Item_E', 7, 13, 20),
('2021-02-03', 'Item_A', 8, 'Item_E', 3, 'Item_C', 1, 12, 30)],
['Date', 'q', 'w', 'e', 'r', 't', 'y', 'Total Items', 'ColA'])
df = df.select('Date', F.create_map('q', 'w', 'e', 'r', 't', 'y').alias('Item'), 'Total Items', 'ColA')
df.show(truncate=0)
# +----------+----------------------------------------+-----------+----+
# |Date      |Item                                    |Total Items|ColA|
# +----------+----------------------------------------+-----------+----+
# |2021-02-01|{Item_A -> 3, Item_B -> 10, Item_C -> 2}|15         |10  |
# |2021-02-02|{Item_A -> 1, Item_D -> 5, Item_E -> 7} |13         |20  |
# |2021-02-03|{Item_A -> 8, Item_E -> 3, Item_C -> 1} |12         |30  |
# +----------+----------------------------------------+-----------+----+
Script:
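# Sum the scalar columns first, before the explode duplicates rows
aggs = df.agg(F.sum('Total Items'), F.sum('ColA')).head()
df = (df
      # explode turns each map entry into a row with 'key' and 'value' columns
      .select('*', F.explode('Item'))
      # sum the values per key across all dates
      .groupBy('key')
      .agg(F.sum('value').alias('value'))
      # rebuild one map from the per-key sums (entry order is not guaranteed)
      # and attach the scalar totals computed above
      .select(
          F.map_from_entries(F.collect_set(F.struct('key', 'value'))).alias('Item'),
          F.lit(aggs[0]).alias('Total Items'),
          F.lit(aggs[1]).alias('ColA'),
      )
)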
df.show(truncate=0)
# +--------------------------------------------------------------------+-----------+----+
# |Item                                                                |Total Items|ColA|
# +--------------------------------------------------------------------+-----------+----+
# |{Item_C -> 3, Item_E -> 10, Item_A -> 12, Item_B -> 10, Item_D -> 5}|40         |60  |
# +--------------------------------------------------------------------+-----------+----+