Home > Net >  How to create a map column with rolling window aggregates per each key
How to create a map column with rolling window aggregates per each key

Time:06-28

Problem description

I need help with a pyspark.sql function that will create a new variable aggregating records over a specified Window() into a map of key-value pairs.

Reproducible Data

# Sample data: one record per (state, day, region) with its count.
# `ds_num` is the Unix timestamp (in seconds) of midnight UTC on `ds`.
columns = ["state", "ds", "ds_num", "region", "count"]
rows = [
    ('AK', "2022-05-02", 1651449600, 'US', 3),
    ('AK', "2022-05-03", 1651536000, 'ON', 1),
    ('AK', "2022-05-04", 1651622400, 'CO', 1),
    ('AK', "2022-05-06", 1651795200, 'AK', 1),
    ('AK', "2022-05-06", 1651795200, 'US', 5),
]
df = spark.createDataFrame(rows, columns)

df.show()
# +-----+----------+----------+------+-----+
# |state|        ds|    ds_num|region|count|
# +-----+----------+----------+------+-----+
# |   AK|2022-05-02|1651449600|    US|    3|
# |   AK|2022-05-03|1651536000|    ON|    1|
# |   AK|2022-05-04|1651622400|    CO|    1|
# |   AK|2022-05-06|1651795200|    AK|    1|
# |   AK|2022-05-06|1651795200|    US|    5|
# +-----+----------+----------+------+-----+

Partial solutions

Sets of regions over a window frame:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

def days(i):
    """Convert a number of days to seconds (the unit of `ds_num`)."""
    return i * 86400


# Frame: rows of the same state whose ds_num lies within the 27 days
# before the current row's ds_num, up to and including the current value.
w_4w = Window().partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)

(df
 .withColumn('regions_4W', F.collect_set('region').over(w_4w))
 .sort('ds')
 .show())

# +-----+----------+----------+------+-----+----------------+
# |state|        ds|    ds_num|region|count|      regions_4W|
# +-----+----------+----------+------+-----+----------------+
# |   AK|2022-05-02|1651449600|    US|    3|            [US]|
# |   AK|2022-05-03|1651536000|    ON|    1|        [US, ON]|
# |   AK|2022-05-04|1651622400|    CO|    1|    [CO, US, ON]|
# |   AK|2022-05-06|1651795200|    AK|    1|[CO, US, ON, AK]|
# |   AK|2022-05-06|1651795200|    US|    5|[CO, US, ON, AK]|
# +-----+----------+----------+------+-----+----------------+

Map of counts per each state and ds

# Collect each day's (region, count) pairs into a single map column.
region_count = F.struct("region", "count")
per_day_map = F.map_from_entries(F.collect_list(region_count)).alias("count_rolling_4W")

(df
 .groupby('state', 'ds', 'ds_num')
 .agg(per_day_map)
 .sort('ds')
 .show())

# +-----+----------+----------+------------------+
# |state|        ds|    ds_num|  count_rolling_4W|
# +-----+----------+----------+------------------+
# |   AK|2022-05-02|1651449600|         {US -> 3}|
# |   AK|2022-05-03|1651536000|         {ON -> 1}|
# |   AK|2022-05-04|1651622400|         {CO -> 1}|
# |   AK|2022-05-06|1651795200|{AK -> 1, US -> 5}|
# +-----+----------+----------+------------------+

Desired Output

What I need is a map aggregating data per each key present in the specified window

+-----+----------+----------+------------------------------------+
|state|        ds|    ds_num|                    count_rolling_4W|
+-----+----------+----------+------------------------------------+
|   AK|2022-05-02|1651449600|                           {US -> 3}|
|   AK|2022-05-03|1651536000|                  {US -> 3, ON -> 1}|
|   AK|2022-05-04|1651622400|         {US -> 3, ON -> 1, CO -> 1}|
|   AK|2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+

CodePudding user response:

Assuming that the combination of the state, ds, ds_num and region columns in your source dataframe is unique (i.e. it can be seen as a primary key), this snippet would do the work:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

def days(i):
    """Convert a number of days to seconds (the unit of `ds_num`)."""
    return i * 86400


# One "anchor" row per (state, ds, ds_num). De-duplicating the left side here,
# instead of calling dropDuplicates() on the joined rows, keeps two distinct
# source rows that happen to share the same region and count (e.g. US -> 3 on
# two different days) from being collapsed into one and under-counted.
anchors = df.select('state', 'ds', 'ds_num').distinct()

(anchors.alias('a')
    .join(df.alias('b'), 'state')
    # Keep the source rows dated within the 27 days up to and including the anchor day.
    .where((F.col('a.ds_num') - F.col('b.ds_num')).between(0, days(27)))
    .select('state', 'a.ds', 'a.ds_num', 'b.region', 'b.count')
    # Total count per region inside each anchor's window.
    .groupBy('state', 'ds', 'ds_num', 'region')
    .agg(F.sum('count').alias('count'))
    # Fold the per-region totals into one map per anchor row.
    .groupBy('state', 'ds', 'ds_num')
    .agg(F.map_from_entries(F.collect_list(F.struct('region', 'count'))).alias('count_rolling_4W'))
    # After the aggregation the column is plain `ds`; the join alias is no longer needed.
    .orderBy('ds')
    .show(truncate=False))

Results:

+-----+----------+----------+------------------------------------+
|state|ds        |ds_num    |count_rolling_4W                    |
+-----+----------+----------+------------------------------------+
|AK   |2022-05-02|1651449600|{US -> 3}                           |
|AK   |2022-05-03|1651536000|{US -> 3, ON -> 1}                  |
|AK   |2022-05-04|1651622400|{US -> 3, ON -> 1, CO -> 1}         |
|AK   |2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+

It may seem complex, but it's just the window rewritten as a self-join (on `state`, filtered to the 27-day range) for better control over the results.

CodePudding user response:

You can use higher order functions transform and aggregate like this:

from pyspark.sql import Window, functions as F

def days(i):
    """Convert a number of days to seconds (the unit of `ds_num`)."""
    return i * 86400


# rangeBetween, not rowsBetween: the bounds are offsets on the ordering value
# (ds_num, in seconds), so the frame covers the 27 days up to and including the
# current row's day. With rowsBetween the same number would be read as a count
# of preceding rows.
w = Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)

df1 = (df.withColumn('regions', F.collect_set('region').over(w))
       # All (region, count) pairs inside the frame, duplicates included.
       .withColumn('counts', F.collect_list(F.struct('region', 'count')).over(w))
       # For every distinct region, sum the counts of the pairs that belong to it.
       .withColumn('counts',
                   F.transform(
                       'regions',
                       lambda x: F.aggregate(
                           F.filter('counts', lambda y: y['region'] == x),
                           # Seed typed as long so it matches the type of
                           # `acc + count` (Python ints are loaded as bigint).
                           F.lit(0).cast('long'),
                           lambda acc, v: acc + v['count']
                       )
                   ))
       # `regions` and `counts` are index-aligned, so they zip into a map.
       .withColumn('count_rolling_4W', F.map_from_arrays('regions', 'counts'))
       .drop('counts', 'regions')
       )

df1.show(truncate=False)

# Rows that share the same ds_num fall in the same range frame, so both
# 2022-05-06 rows get the same map (key order inside a map may vary).
# +-----+----------+----------+------+-----+------------------------------------+
# |state|ds        |ds_num    |region|count|count_rolling_4W                    |
# +-----+----------+----------+------+-----+------------------------------------+
# |AK   |2022-05-02|1651449600|US    |3    |{US -> 3}                           |
# |AK   |2022-05-03|1651536000|ON    |1    |{US -> 3, ON -> 1}                  |
# |AK   |2022-05-04|1651622400|CO    |1    |{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|AK    |1    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# |AK   |2022-05-06|1651795200|US    |5    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# +-----+----------+----------+------+-----+------------------------------------+

CodePudding user response:

Great question. This method will use 2 windows and 2 higher order functions (aggregate and map_zip_with), plus map_from_entries to turn each struct into a one-entry map.

from pyspark.sql import functions as F, Window

# w1: rows of one (state, region) pair in date order; used for a per-region running sum.
# NOTE(review): neither window sets an explicit frame and nothing here refers to
# the 27-day limit, so `acc_count` looks like an all-time running total rather
# than a rolling 4-week sum — confirm this matches the requirement.
w1 = Window.partitionBy('state', 'region').orderBy('ds')
# w2: all rows of one state in date order; used to gather every region seen so far.
w2 = Window.partitionBy('state').orderBy('ds')
# Running total of `count` per (state, region). Note: `df` is rebound from here on.
df = df.withColumn('acc_count', F.sum('count').over(w1))
# Distinct (region, acc_count) structs collected over the state's window.
df = df.withColumn('maps', F.collect_set(F.struct('region', 'acc_count')).over(w2))
# Collapse to one row per (state, ds, ds_num) and fold the collected structs into a map.
df = df.groupBy('state', 'ds', 'ds_num').agg(
    F.aggregate(
        F.first('maps'),
        # Seed map built from the group's first (region, acc_count) pair.
        F.create_map(F.first('region'), F.first('acc_count')),
        # Merge each struct in as a one-entry map; when a region appears more
        # than once, keep the greatest acc_count (greatest() skips the null
        # produced for a key that is missing on one side).
        lambda m, x: F.map_zip_with(m, F.map_from_entries(F.array(x)), lambda k, v1, v2: F.greatest(v2, v1))
    ).alias('count_rolling_4W')
)

df.show(truncate=0)
# +-----+----------+----------+------------------------------------+
# |state|ds        |ds_num    |count_rolling_4W                    |
# +-----+----------+----------+------------------------------------+
# |AK   |2022-05-02|1651449600|{US -> 3}                           |
# |AK   |2022-05-03|1651536000|{ON -> 1, US -> 3}                  |
# |AK   |2022-05-04|1651622400|{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|{AK -> 1, US -> 8, ON -> 1, CO -> 1}|
# +-----+----------+----------+------------------------------------+
  • Related