Optimizing Spark code for Window Aggregation

Time:05-12

I am trying to generate window aggregates for my data. However, it is taking too much time for lags > 20. I am running it in Databricks.

My data has columns: userid, date, orders, total_spend

+------+----------+------+-----------+
|userid|date      |orders|total_spend|
+------+----------+------+-----------+
|1     |2022-05-01| 2    | 1000      |
|1     |2022-05-02| 3    | 2000      |
|2     |2022-05-01| 1    | 2000      |
|3     |2022-05-01| 2    | 2000      |
|3     |2022-05-02| 4    | 3000      |
|4     |2022-05-01| 1    | 400       |
|5     |2022-05-01| 2    | 2000      |
|5     |2022-05-02| 4    | 1500      |
|5     |2022-05-02| 2    | 6000      |
+------+----------+------+-----------+

from pyspark.sql import functions as F

def getWindow(lag):
    return F.window(
        F.col("date"),
        windowDuration=f"{lag} days",
        slideDuration="1 days",
    ).alias("window")

def getAggregated(df, window, column, lag): 
    return (
      df
      .groupBy(F.col("userid"), window)
      .agg(
        F.avg(F.col(column)).alias(f"mean_{column}_last{lag}days"),
        F.sum(F.col(column)).alias(f"sum_{column}_last{lag}days")
      )
      .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 0))
      .drop("window")
    )

LAGS = [1, 3, 10, 20, 40, 80, 180]
COLUMNS_TO_BE_AGGREGATED = [
    "orders",
    "total_spend"
]

df = spark.read.parquet("df_location")
df = df.orderBy("userid", "date")
df.persist()

for col in COLUMNS_TO_BE_AGGREGATED:    
    for lag in LAGS:
        window = getWindow(lag)
        agg_df = getAggregated(df, window, col, lag)
        df = df.join(agg_df, ["userid", "date"], how="left")

Is there something I am doing incorrectly? Any suggestions on how to optimize it?

CodePudding user response:

Because of Spark's lazy evaluation, joining repeatedly inside the for loop is quite inefficient. I see you tried to use persist; however, you reassign df immediately inside the loop, so I don't think the persist provides much benefit here.

If I understand your logic correctly, instead of computing a lagged window over the entire dataframe once per lag, you can use a rangeBetween window frame to look up the corresponding range of rows and calculate the mean or sum.

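from pyspark.sql import Window
from pyspark.sql import functions as F

LAGS = [1, 3, 10, 20, 40, 80, 180]

# rangeBetween with numeric offsets needs a numeric ordering column,
# so convert the date to epoch seconds (long).
df = df.withColumn('ts', F.to_timestamp('date').cast('long'))

w = Window.partitionBy('userid').orderBy('ts')

aday = 24 * 60 * 60  # seconds in one day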

df = (df.select('*',
                *[F.avg('orders').over(w.rangeBetween(-x * aday, -1)).alias(f'mean_orders_last_{x}days') for x in LAGS],
                *[F.sum('orders').over(w.rangeBetween(-x * aday, -1)).alias(f'sum_orders_last_{x}days') for x in LAGS],
                *[F.avg('total_spend').over(w.rangeBetween(-x * aday, -1)).alias(f'mean_total_spend_last_{x}days') for x in LAGS],
                *[F.sum('total_spend').over(w.rangeBetween(-x * aday, -1)).alias(f'sum_total_spend_last_{x}days') for x in LAGS]))

Additionally, you can fold orders and total_spend into a nested list comprehension instead of repeating the lines; either way, use select rather than join here.
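
A minimal sketch of that nested comprehension, assuming the same 'userid' and 'ts' columns as in the code above. The names window_specs, add_window_aggregates, COLUMNS and AGGS are made up for illustration and are not part of PySpark. The specs are built in plain Python so they can be inspected without a cluster; the Spark import sits inside the function that needs it.

```python
LAGS = [1, 3, 10, 20, 40, 80, 180]
COLUMNS = ["orders", "total_spend"]
AGGS = {"mean": "avg", "sum": "sum"}  # output prefix -> name of the function in pyspark.sql.functions
A_DAY = 24 * 60 * 60  # seconds in one day


def window_specs(columns=COLUMNS, lags=LAGS):
    """Hypothetical helper: one (alias, function name, column, lower bound in seconds) tuple per output column."""
    return [
        (f"{prefix}_{col}_last_{lag}days", fn, col, -lag * A_DAY)
        for col in columns
        for prefix, fn in AGGS.items()
        for lag in lags
    ]


def add_window_aggregates(df):
    """Hypothetical helper: add every aggregate in a single select.

    Assumes df has 'userid' and a numeric 'ts' column holding epoch seconds.
    """
    # Imported here so that window_specs() can be used without Spark installed.
    from pyspark.sql import Window, functions as F

    w = Window.partitionBy("userid").orderBy("ts")
    return df.select(
        "*",
        *[
            getattr(F, fn)(col).over(w.rangeBetween(lower, -1)).alias(alias)
            for alias, fn, col, lower in window_specs()
        ],
    )
```

With two columns, two aggregates and seven lags this yields 28 new columns, all computed in one select over the same partitioning.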

ref (rangeBetween vs. rowsBetween): "What is the difference between rowsBetween and rangeBetween?"
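
To make the difference concrete, here is a plain-Python emulation of the two frame types on the rows for userid 5 from the question, which has two rows on 2022-05-02. This is not Spark code: range_sum and rows_sum are hypothetical helpers written only to mirror the frame semantics, and the ordering value is counted in days rather than the epoch seconds used in the answer (where the bounds are multiplied by 24 * 60 * 60).

```python
from datetime import date

# Rows for one user, already sorted by date: (date, orders)
rows = [(date(2022, 5, 1), 2), (date(2022, 5, 2), 4), (date(2022, 5, 2), 2)]


def range_sum(rows, i, days_back):
    """Emulates sum over rangeBetween(-days_back, -1) with the ordering value in days:
    every row whose date lies in [current - days_back, current - 1]."""
    cur = rows[i][0].toordinal()
    return sum(v for d, v in rows if cur - days_back <= d.toordinal() <= cur - 1)


def rows_sum(rows, i, n_back):
    """Emulates sum over rowsBetween(-n_back, -1): the n_back rows physically preceding row i."""
    return sum(v for _, v in rows[max(0, i - n_back):i])


range_result = [range_sum(rows, i, 1) for i in range(len(rows))]
rows_result = [rows_sum(rows, i, 1) for i in range(len(rows))]
print(range_result)  # [0, 2, 2]
print(rows_result)   # [0, 2, 4]
```

Both 2022-05-02 rows see the same "previous day" total under the range frame, while the rows frame lets the second of them pick up the first. Two caveats: Spark returns null rather than 0 for an empty frame, and with rowsBetween the order among rows with equal ordering values is not guaranteed.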

CodePudding user response:

If you want speed, you want to use the whole cluster and not just 1 node. When you run a window, all of the rows of a window partition are pulled onto 1 node for processing (and without a partitionBy, that is all of the data). If you don't like the speed, get rid of the window and use join/groupBy instead. It will let you use the entire cluster.

I suggest using something like this to create your groups:

# A multi-line query needs a triple-quoted string; show() prints the first 20 rows.
spark.sql("""
  with today as
    (select
       current_date() as now,
       date_add(current_date(), 100) as onehundred
    )
  select
    explode(
      sequence(
        today.now,
        today.onehundred,
        INTERVAL 1 DAY)
      ) as date,
    100 as grouping
  from today""").show()
+----------+--------+
|      date|grouping|
+----------+--------+
|2022-05-11|     100|
|2022-05-12|     100|
|2022-05-13|     100|
|2022-05-14|     100|
|2022-05-15|     100|
|2022-05-16|     100|
|2022-05-17|     100|
|2022-05-18|     100|
|2022-05-19|     100|
|2022-05-20|     100|
|2022-05-21|     100|
|2022-05-22|     100|
|2022-05-23|     100|
|2022-05-24|     100|
|2022-05-25|     100|
|2022-05-26|     100|
|2022-05-27|     100|
|2022-05-28|     100|
|2022-05-29|     100|
|2022-05-30|     100|
+----------+--------+

You could then use this to calculate your 100-day total with a join/groupBy, and build similar tables for your other totals. I bet it will actually be a very fast join compared to using a window.