I am trying to generate window aggregates for my data. However, it is taking too much time for lags > 20. I am running it in Databricks.
My data has columns: userid, date, orders, total_spend
+------+----------+------+-----------+
|userid|      date|orders|total_spend|
+------+----------+------+-----------+
|     1|2022-05-01|     2|       1000|
|     1|2022-05-02|     3|       2000|
|     2|2022-05-01|     1|       2000|
|     3|2022-05-01|     2|       2000|
|     3|2022-05-02|     4|       3000|
|     4|2022-05-01|     1|        400|
|     5|2022-05-01|     2|       2000|
|     5|2022-05-02|     4|       1500|
|     5|2022-05-02|     2|       6000|
+------+----------+------+-----------+
from pyspark.sql import functions as F

def getWindow(lag):
    return F.window(
        F.col("date"),
        windowDuration=f"{lag} days",
        slideDuration="1 days",
    ).alias("window")

def getAggregated(df, window, column, lag):
    return (
        df
        .groupBy(F.col("userid"), window)
        .agg(
            F.avg(F.col(column)).alias(f"mean_{column}_last{lag}days"),
            F.sum(F.col(column)).alias(f"sum_{column}_last{lag}days")
        )
        .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 0))
        .drop("window")
    )
LAGS = [1, 3, 10, 20, 40, 80, 180]
COLUMNS_TO_BE_AGGREGATED = [
    "orders",
    "total_spend",
]

df = spark.read.parquet("df_location")
df = df.orderBy("userid", "date")
df.persist()

for col in COLUMNS_TO_BE_AGGREGATED:
    for lag in LAGS:
        window = getWindow(lag)
        agg_df = getAggregated(df, window, col, lag)
        df = df.join(agg_df, ["userid", "date"], how="left")
Is there something I am doing incorrectly? Any suggestions on how I can optimize it?
CodePudding user response:
Due to Spark's lazy evaluation, joining multiple times within a for-loop is quite inefficient. I see you try to use persist; however, you are modifying df immediately in the for-loop, so I do not think the persist provides much benefit here.
If I understand your logic, instead of aggregating the entire dataframe once per lag, you can leverage rangeBetween on a window function to look back over the corresponding range of days and calculate the mean or sum.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

LAGS = [1, 3, 10, 20, 40, 80, 180]

# to use rangeBetween, the orderBy column needs a numeric (long) type,
# so convert the date to epoch seconds
df = df.withColumn('ts', F.to_timestamp('date').cast('long'))

w = Window.partitionBy('userid').orderBy('ts')
aday = 24 * 60 * 60  # one day in seconds

# the -1 upper bound excludes the current day, so each aggregate covers strictly prior days
df = (df.select('*',
                *[F.avg('orders').over(w.rangeBetween(-x * aday, -1)).alias(f'mean_orders_last_{x}days') for x in LAGS],
                *[F.sum('orders').over(w.rangeBetween(-x * aday, -1)).alias(f'sum_orders_last_{x}days') for x in LAGS],
                *[F.avg('total_spend').over(w.rangeBetween(-x * aday, -1)).alias(f'mean_total_spend_last_{x}days') for x in LAGS],
                *[F.sum('total_spend').over(w.rangeBetween(-x * aday, -1)).alias(f'sum_total_spend_last_{x}days') for x in LAGS]))
Additionally, you can nest the loop over orders and total_spend inside the list comprehension, but try to use select instead of join in this case.
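A minimal sketch of that nesting (reusing df, w, aday, and LAGS from the snippet above; COLUMNS is just the two column names from the question):

from pyspark.sql import functions as F

COLUMNS = ['orders', 'total_spend']

# w, aday, and LAGS come from the snippet above
df = df.select(
    '*',
    *[F.avg(c).over(w.rangeBetween(-x * aday, -1)).alias(f'mean_{c}_last_{x}days')
      for c in COLUMNS for x in LAGS],
    *[F.sum(c).over(w.rangeBetween(-x * aday, -1)).alias(f'sum_{c}_last_{x}days')
      for c in COLUMNS for x in LAGS],
)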
ref: rangeBetween vs rowsBetween - "What is the difference between rowsBetween and rangeBetween?"
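If it helps, here is a rough, hypothetical illustration of that difference (reusing df, ts, and aday from above): rowsBetween counts physical rows regardless of their dates, while rangeBetween selects rows by the value of the orderBy column, which is why it handles "last N days" even when some days have no rows.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# previous 2 physical rows per user, whatever dates they fall on
w_rows = Window.partitionBy('userid').orderBy('ts').rowsBetween(-2, -1)
# rows whose ts falls within the 2 days before the current row
w_range = Window.partitionBy('userid').orderBy('ts').rangeBetween(-2 * aday, -1)

df.select('userid', 'date', 'orders',
          F.sum('orders').over(w_rows).alias('sum_orders_prev_2_rows'),
          F.sum('orders').over(w_range).alias('sum_orders_last_2_days')).show()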
CodePudding user response:
If you want speed, you want to use the whole cluster and not just one node. When you run a window function, all of the rows for each partition are pulled onto a single node for processing, so a skewed or low-cardinality partition key can leave most of the work on one node. If you don't like the speed, get rid of the window and use join/groupBy instead. It will enable you to use the entire cluster.
I suggest using something like this to create your groups:
spark.sql("
with today as
(select
current_date() as now,
date_add(current_date(), 100) as onehundred
)
select
explode(
sequence(
today.now,
today.onehundred,
INTERVAL 1 DAY)
) as date,
100 as grouping
from today").show
+----------+--------+
|      date|grouping|
+----------+--------+
|2022-05-11|     100|
|2022-05-12|     100|
|2022-05-13|     100|
|2022-05-14|     100|
|2022-05-15|     100|
|2022-05-16|     100|
|2022-05-17|     100|
|2022-05-18|     100|
|2022-05-19|     100|
|2022-05-20|     100|
|2022-05-21|     100|
|2022-05-22|     100|
|2022-05-23|     100|
|2022-05-24|     100|
|2022-05-25|     100|
|2022-05-26|     100|
|2022-05-27|     100|
|2022-05-28|     100|
|2022-05-29|     100|
|2022-05-30|     100|
+----------+--------+
This could then be used to calculate your 100-day total with a join/groupBy, and similar tables can be used to create your other totals. I bet it will actually be a super fast join compared to using a window.
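A rough sketch of that join/groupBy step, as I read it (dates_df here is a hypothetical dataframe built from a query like the one above, with one row per reporting date and grouping tagging the 100-day window):

from pyspark.sql import functions as F

# dates_df: hypothetical calendar dataframe with columns (date, grouping) from the query above
totals_100d = (
    df.alias("f")
    .join(
        dates_df.alias("d"),
        # keep a fact row when it falls within the 100 days ending at the reporting date
        (F.col("f.date") > F.date_sub(F.col("d.date"), 100)) & (F.col("f.date") <= F.col("d.date")),
        "inner",
    )
    .groupBy(F.col("f.userid"), F.col("d.date"))
    .agg(
        F.sum("f.total_spend").alias("sum_total_spend_last_100days"),
        F.avg("f.total_spend").alias("mean_total_spend_last_100days"),
    )
)

The range join and the groupBy both shuffle by key, so the work spreads across the whole cluster instead of concentrating in a few window partitions.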