Problem description
I need help with a pyspark.sql function that will create a new column aggregating records over a specified Window() into a map of key-value pairs.
Reproducible Data
df = spark.createDataFrame(
    [
        ('AK', "2022-05-02", 1651449600, 'US', 3),
        ('AK', "2022-05-03", 1651536000, 'ON', 1),
        ('AK', "2022-05-04", 1651622400, 'CO', 1),
        ('AK', "2022-05-06", 1651795200, 'AK', 1),
        ('AK', "2022-05-06", 1651795200, 'US', 5)
    ],
    ["state", "ds", "ds_num", "region", "count"]
)
df.show()
# +-----+----------+----------+------+-----+
# |state|        ds|    ds_num|region|count|
# +-----+----------+----------+------+-----+
# |   AK|2022-05-02|1651449600|    US|    3|
# |   AK|2022-05-03|1651536000|    ON|    1|
# |   AK|2022-05-04|1651622400|    CO|    1|
# |   AK|2022-05-06|1651795200|    AK|    1|
# |   AK|2022-05-06|1651795200|    US|    5|
# +-----+----------+----------+------+-----+
Partial solutions
Sets of regions over a window frame:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
days = lambda i: i * 86400  # seconds in i days

df.withColumn(
    'regions_4W',
    F.collect_set('region').over(
        Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)
    )
).sort('ds').show()
# +-----+----------+----------+------+-----+----------------+
# |state|        ds|    ds_num|region|count|      regions_4W|
# +-----+----------+----------+------+-----+----------------+
# |   AK|2022-05-02|1651449600|    US|    3|            [US]|
# |   AK|2022-05-03|1651536000|    ON|    1|        [US, ON]|
# |   AK|2022-05-04|1651622400|    CO|    1|    [CO, US, ON]|
# |   AK|2022-05-06|1651795200|    AK|    1|[CO, US, ON, AK]|
# |   AK|2022-05-06|1651795200|    US|    5|[CO, US, ON, AK]|
# +-----+----------+----------+------+-----+----------------+
Map of counts per state and ds:
df.groupby('state', 'ds', 'ds_num') \
    .agg(F.map_from_entries(F.collect_list(F.struct("region", "count"))).alias("count_rolling_4W")) \
    .sort('ds') \
    .show()
# +-----+----------+----------+------------------+
# |state|        ds|    ds_num|  count_rolling_4W|
# +-----+----------+----------+------------------+
# |   AK|2022-05-02|1651449600|         {US -> 3}|
# |   AK|2022-05-03|1651536000|         {ON -> 1}|
# |   AK|2022-05-04|1651622400|         {CO -> 1}|
# |   AK|2022-05-06|1651795200|{AK -> 1, US -> 5}|
# +-----+----------+----------+------------------+
Desired Output
What I need is a map aggregating data for each key present in the specified window:
+-----+----------+----------+------------------------------------+
|state|        ds|    ds_num|                    count_rolling_4W|
+-----+----------+----------+------------------------------------+
|   AK|2022-05-02|1651449600|                           {US -> 3}|
|   AK|2022-05-03|1651536000|                  {US -> 3, ON -> 1}|
|   AK|2022-05-04|1651622400|         {US -> 3, ON -> 1, CO -> 1}|
|   AK|2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+
CodePudding user response:
Assuming that the combination of the state, ds, ds_num and region columns in your source dataframe is unique (they can be seen as a primary key), this snippet does the job:
import pyspark.sql.functions as F

days = lambda i: i * 86400  # seconds in i days

df.alias('a').join(df.alias('b'), 'state') \
    .where((F.col('a.ds_num') - F.col('b.ds_num')).between(0, days(27))) \
    .select('state', 'a.ds', 'a.ds_num', 'b.region', 'b.count') \
    .dropDuplicates() \
    .groupBy('state', 'ds', 'ds_num', 'region').sum('count') \
    .groupBy('state', 'ds', 'ds_num') \
    .agg(F.map_from_entries(F.collect_list(F.struct("region", "sum(count)"))).alias("count_rolling_4W")) \
    .orderBy('a.ds') \
    .show(truncate=False)
Results:
+-----+----------+----------+------------------------------------+
|state|ds        |ds_num    |count_rolling_4W                    |
+-----+----------+----------+------------------------------------+
|AK   |2022-05-02|1651449600|{US -> 3}                           |
|AK   |2022-05-03|1651536000|{US -> 3, ON -> 1}                  |
|AK   |2022-05-04|1651622400|{US -> 3, ON -> 1, CO -> 1}         |
|AK   |2022-05-06|1651795200|{US -> 8, ON -> 1, CO -> 1, AK -> 1}|
+-----+----------+----------+------------------------------------+
It may look complex, but it is just the windowing rewritten as a self-join, which gives better control over the results.
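For comparison, here is a sketch of the same rolling map built with the rangeBetween window from the question instead of a join: collect the (region, count) pairs over the frame, explode them, and re-aggregate per key. This is untested and assumes the dataframe from the question; the dropDuplicates step relies on rows that tie on ds_num seeing the identical frame, which a range frame guarantees.
import pyspark.sql.functions as F
from pyspark.sql.window import Window

days = lambda i: i * 86400  # seconds in i days
w = Window.partitionBy('state').orderBy('ds_num').rangeBetween(-days(27), 0)

rolled = (df
    # every row carries the (region, count) pairs of its 4-week frame
    .withColumn('pairs', F.collect_list(F.struct('region', 'count')).over(w))
    .select('state', 'ds', 'ds_num', 'pairs')
    # rows with the same ds_num share the same frame, so keep one of them
    .dropDuplicates(['state', 'ds', 'ds_num'])
    # unfold the pairs and sum the counts per region within each frame
    .select('state', 'ds', 'ds_num', F.explode('pairs').alias('pair'))
    .select('state', 'ds', 'ds_num', 'pair.region', 'pair.count')
    .groupBy('state', 'ds', 'ds_num', 'region')
    .agg(F.sum('count').alias('count'))
    # rebuild the map per state/ds/ds_num
    .groupBy('state', 'ds', 'ds_num')
    .agg(F.map_from_entries(F.collect_list(F.struct('region', 'count'))).alias('count_rolling_4W'))
)
rolled.orderBy('ds').show(truncate=False)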
CodePudding user response:
You can use the higher-order functions transform and aggregate like this:
from pyspark.sql import Window, functions as F

days = lambda i: i * 86400  # seconds in i days (as defined in the question)
w = Window.partitionBy('state').orderBy('ds_num').rowsBetween(-days(27), 0)

df1 = (df
    .withColumn('regions', F.collect_set('region').over(w))
    .withColumn('counts', F.collect_list(F.struct('region', 'count')).over(w))
    .withColumn('counts',
        F.transform(
            'regions',
            lambda x: F.aggregate(
                F.filter('counts', lambda y: y['region'] == x),
                F.lit(0),
                lambda acc, v: acc + v['count']
            )
        ))
    .withColumn('count_rolling_4W', F.map_from_arrays('regions', 'counts'))
    .drop('counts', 'regions')
)
df1.show(truncate=False)
# +-----+----------+----------+------+-----+------------------------------------+
# |state|ds        |ds_num    |region|count|count_rolling_4W                    |
# +-----+----------+----------+------+-----+------------------------------------+
# |AK   |2022-05-02|1651449600|US    |3    |{US -> 3}                           |
# |AK   |2022-05-03|1651536000|ON    |1    |{US -> 3, ON -> 1}                  |
# |AK   |2022-05-04|1651622400|CO    |1    |{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|AK    |1    |{CO -> 1, US -> 3, ON -> 1, AK -> 1}|
# |AK   |2022-05-06|1651795200|US    |5    |{CO -> 1, US -> 8, ON -> 1, AK -> 1}|
# +-----+----------+----------+------+-----+------------------------------------+
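To make the higher-order calls easier to follow, here is a tiny standalone illustration on made-up arrays (it assumes Spark >= 3.1 and an active spark session, and is not part of the answer's pipeline): filter keeps the structs whose region matches, aggregate folds their counts into a sum, and transform repeats that fold for every region.
import pyspark.sql.functions as F

# hypothetical demo data: a set of regions and the collected (region, count) structs
demo = spark.createDataFrame(
    [(['US', 'ON'], [('US', 3), ('ON', 1), ('US', 5)])],
    'regions array<string>, counts array<struct<region:string,count:int>>'
)
demo.select(
    F.transform(
        'regions',
        lambda x: F.aggregate(
            F.filter('counts', lambda y: y['region'] == x),  # structs for this region
            F.lit(0),                                        # accumulator starts at 0
            lambda acc, v: acc + v['count']                  # add up the counts
        )
    ).alias('summed')
).show()
# expected result:
# +------+
# |summed|
# +------+
# |[8, 1]|
# +------+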
CodePudding user response:
Great question. This method uses two windows and higher-order functions (aggregate and map_zip_with) together with map_from_entries:
from pyspark.sql import functions as F, Window

w1 = Window.partitionBy('state', 'region').orderBy('ds')
w2 = Window.partitionBy('state').orderBy('ds')

df = df.withColumn('acc_count', F.sum('count').over(w1))
df = df.withColumn('maps', F.collect_set(F.struct('region', 'acc_count')).over(w2))
df = df.groupBy('state', 'ds', 'ds_num').agg(
    F.aggregate(
        F.first('maps'),
        F.create_map(F.first('region'), F.first('acc_count')),
        lambda m, x: F.map_zip_with(m, F.map_from_entries(F.array(x)), lambda k, v1, v2: F.greatest(v2, v1))
    ).alias('count_rolling_4W')
)
df.show(truncate=0)
# +-----+----------+----------+------------------------------------+
# |state|ds        |ds_num    |count_rolling_4W                    |
# +-----+----------+----------+------------------------------------+
# |AK   |2022-05-02|1651449600|{US -> 3}                           |
# |AK   |2022-05-03|1651536000|{ON -> 1, US -> 3}                  |
# |AK   |2022-05-04|1651622400|{CO -> 1, US -> 3, ON -> 1}         |
# |AK   |2022-05-06|1651795200|{AK -> 1, US -> 8, ON -> 1, CO -> 1}|
# +-----+----------+----------+------------------------------------+
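The merge step is the heart of this approach. A small standalone illustration on made-up maps (it assumes Spark >= 3.1 and an active spark session) shows how map_zip_with pairs values by key while greatest keeps the larger running total and skips the NULL produced on the side where a key is missing.
import pyspark.sql.functions as F

# hypothetical demo data: two partial maps of running counts
pair = spark.createDataFrame(
    [({'US': 3, 'ON': 1}, {'US': 8, 'AK': 1})],
    'm1 map<string,int>, m2 map<string,int>'
)
pair.select(
    F.map_zip_with('m1', 'm2', lambda k, v1, v2: F.greatest(v2, v1)).alias('merged')
).show(truncate=False)
# expected result (key order may vary):
# +---------------------------+
# |merged                     |
# +---------------------------+
# |{US -> 8, ON -> 1, AK -> 1}|
# +---------------------------+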