Get sum of columns from a dataframe including map column - PySpark


I have a PySpark dataframe that looks like the one below, with a map-typed column (Map<Str, Int>):

Date          Item (Map<Str, Int>)                      Total Items   ColA

2021-02-01    Item_A -> 3, Item_B -> 10, Item_C -> 2    15            10
2021-02-02    Item_A -> 1, Item_D -> 5, Item_E -> 7     13            20
2021-02-03    Item_A -> 8, Item_E -> 3, Item_C -> 1     12            30

I want to sum all the columns, including the map column. For the map column, the sums should be calculated per key.

I want something like this:

[[Item_A -> 12, Item_B -> 10, Item_C -> 3, Item_D -> 5, Item_E -> 10], 40, 60]

Not necessarily a list of lists, but I want the sum of the columns.

My approach:

df.rdd.map(lambda x: (1, x[1])).reduceByKey(lambda x, y: x + y).collect()[0][1]
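A reduceByKey like this only sums one numeric column at a time; for the map column, the reduce step would additionally need a key-wise merge of two maps. The merge semantics can be sketched in plain Python with collections.Counter (this is an illustration of the per-key sum, not PySpark code):

```python
from collections import Counter

def merge_maps(a, b):
    # Key-wise sum of two dicts: keys present in either dict appear
    # in the result, with their values added together.
    return dict(Counter(a) + Counter(b))

# The three Item maps from the sample data
rows = [
    {"Item_A": 3, "Item_B": 10, "Item_C": 2},
    {"Item_A": 1, "Item_D": 5, "Item_E": 7},
    {"Item_A": 8, "Item_E": 3, "Item_C": 1},
]

merged = {}
for m in rows:
    merged = merge_maps(merged, m)

print(merged)
# Item_A: 3+1+8 = 12, Item_C: 2+1 = 3, Item_E: 7+3 = 10, etc.
```

Note that Counter addition drops non-positive counts, which is fine here since all item counts are positive.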

CodePudding user response:

You can do the aggregation for the map column and for the other columns separately. Summing the map requires an explode on the Item column, and after exploding, the other columns are duplicated across the exploded rows, which makes them hard to sum correctly in the same pass.

Example dataframe:

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [('2021-02-01', 'Item_A', 3, 'Item_B', 10, 'Item_C', 2, 15, 10),
     ('2021-02-02', 'Item_A', 1, 'Item_D', 5, 'Item_E', 7, 13, 20),
     ('2021-02-03', 'Item_A', 8, 'Item_E', 3, 'Item_C', 1, 12, 30)],
    ['Date', 'q', 'w', 'e', 'r', 't', 'y', 'Total Items', 'ColA'])
df = df.select('Date', F.create_map('q', 'w', 'e', 'r', 't', 'y').alias('Item'), 'Total Items', 'ColA')
df.show(truncate=0)
# +----------+----------------------------------------+-----------+----+
# |Date      |Item                                    |Total Items|ColA|
# +----------+----------------------------------------+-----------+----+
# |2021-02-01|{Item_A -> 3, Item_B -> 10, Item_C -> 2}|15         |10  |
# |2021-02-02|{Item_A -> 1, Item_D -> 5, Item_E -> 7} |13         |20  |
# |2021-02-03|{Item_A -> 8, Item_E -> 3, Item_C -> 1} |12         |30  |
# +----------+----------------------------------------+-----------+----+

Script:

aggs = df.agg(F.sum('Total Items'), F.sum('ColA')).head()

df = (df
    .select('*', F.explode('Item'))
    .groupBy('key')
    .agg(F.sum('value').alias('value'))
    .select(
        F.map_from_entries(F.collect_set(F.struct('key', 'value'))).alias('Item'),
        F.lit(aggs[0]).alias('Total Items'),
        F.lit(aggs[1]).alias('ColA'),
    )
)
df.show(truncate=0)
# +--------------------------------------------------------------------+-----------+----+
# |Item                                                                |Total Items|ColA|
# +--------------------------------------------------------------------+-----------+----+
# |{Item_C -> 3, Item_E -> 10, Item_A -> 12, Item_B -> 10, Item_D -> 5}|40         |60  |
# +--------------------------------------------------------------------+-----------+----+
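As a sanity check on that result, the same aggregation can be done over the three sample rows in plain Python (not PySpark, just confirming the per-key sums and the totals 40 and 60):

```python
from collections import defaultdict

# Sample rows from the question: (Item map, Total Items, ColA)
rows = [
    ({"Item_A": 3, "Item_B": 10, "Item_C": 2}, 15, 10),
    ({"Item_A": 1, "Item_D": 5, "Item_E": 7}, 13, 20),
    ({"Item_A": 8, "Item_E": 3, "Item_C": 1}, 12, 30),
]

item_sums = defaultdict(int)  # per-key sums for the map column
total_items = 0
col_a = 0
for item_map, total, a in rows:
    for k, v in item_map.items():
        item_sums[k] += v
    total_items += total
    col_a += a

print(dict(item_sums), total_items, col_a)
```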