Spark Window Function Null Skew


Recently I encountered an issue running one of our PySpark jobs. While analyzing the stages in the Spark UI I noticed that the longest-running stage takes 1.2 hours out of the 2.5 hours the entire process takes.

SparkUI stages tab sorted by longest duration

Once I took a look at the stage details it was clear that I was facing a severe data skew, causing a single task to run for the entire 1.2 hours while all the other tasks finished within 23 seconds.

Tasks distribution shows a very clear skew

The summary shows the big difference between the longest task and the vast majority

The DAG showed that this stage involves window functions, which helped me quickly narrow the problematic area down to a few queries and find the root cause: the column account, used in Window.partitionBy("account"), had 25% null values. I have no interest in calculating the sum for the null accounts, but I do need those rows for further calculations, so I can't filter them out before the window function.
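
A quick way to confirm the suspicion is to measure the null share and the key distribution directly. The snippet below is only a rough diagnostic sketch (it assumes sales_df, the dataframe used in the query further down, is already loaded):

from pyspark.sql.functions import col, count, when

# share of rows with a null account - roughly 25% in this case
null_ratio = (sales_df
    .select((count(when(col("account").isNull(), True)) / count("*")).alias("null_ratio"))
    .collect()[0]["null_ratio"])

# heaviest partition keys by row count - null shows up as the single largest key
sales_df.groupBy("account").count().orderBy(col("count").desc()).show(5)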

Here is my window function query:

from pyspark.sql import Window
from pyspark.sql.functions import col, sum  # pyspark.sql.functions.sum, not Python's builtin

problematic_account_window = Window.partitionBy("account")

sales_with_account_total_df = sales_df.withColumn("sum_sales_per_account", sum(col("price")).over(problematic_account_window))

So we've found the one to blame. What can we do now? How can we resolve the skew and the performance issue?

CodePudding user response:

We basically have 2 solutions for this issue:

  1. Split the initial dataframe into two dataframes: one with the null values filtered out, on which the sum is calculated, and one that contains only the null rows and takes no part in the calculation. Lastly we union the two back together.
  2. Apply a salting technique to the null values in order to spread the nulls across all partitions and stabilize the stage.

Solution 1:

account_window = Window.partitionBy("account")

# split to null and non null
non_null_accounts_df = sales_df.where(col("account").isNotNull())
only_null_accounts_df = sales_df.where(col("account").isNull())

# calculate the sum for the non null accounts
sales_with_non_null_accounts_df = non_null_accounts_df.withColumn("sum_sales_per_account", sum(col("price")).over(account_window))

# union the calculated result with the null-accounts df into the final result
sales_with_account_total_df = sales_with_non_null_accounts_df.unionByName(only_null_accounts_df, allowMissingColumns=True)
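
Note that allowMissingColumns=True requires Spark 3.1+; it fills the missing sum_sales_per_account column with nulls for the null-account rows. On older versions, a minimal workaround (a sketch of mine, assuming the sums are doubles) is to add the column explicitly before the union:

from pyspark.sql.functions import lit

# add the column explicitly so both dataframes share the same schema
# (cast to whatever type the price sums actually have)
only_null_accounts_df = only_null_accounts_df.withColumn("sum_sales_per_account", lit(None).cast("double"))

sales_with_account_total_df = sales_with_non_null_accounts_df.unionByName(only_null_accounts_df)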

Solution 2:

from pyspark.sql.functions import ceil, coalesce, rand

SPARK_SHUFFLE_PARTITIONS = int(spark.conf.get("spark.sql.shuffle.partitions"))

modified_sales_df = (sales_df
    # create a random partition value that spans the number of shuffle partitions
    .withColumn("random_salt_partition", ceil(rand() * SPARK_SHUFFLE_PARTITIONS))
    # use the random partition value only when the account value is null
    .withColumn("salted_account", coalesce(col("account"), col("random_salt_partition")))
    )

# modify the window to partition by the salted account
salted_account_window = Window.partitionBy("salted_account")

# use the salted account window to calculate the sum of sales
sales_with_account_total_df = modified_sales_df.withColumn("sum_sales_per_account", sum(col("price")).over(salted_account_window))
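
If the salting helper columns shouldn't leak into downstream code, they can be dropped once the aggregation is done; this is a small optional cleanup on top of the solution above:

# drop the salting helper columns after the window aggregation
sales_with_account_total_df = sales_with_account_total_df.drop("random_salt_partition", "salted_account")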

In the end I decided to go with solution 2, since it didn't force me to create additional dataframes just for the sake of the calculation, and here is the result:

Tasks are now fairly even

Max task duration now stands at 1.2 minutes instead of 1.2 hours

As seen above, the salting technique resolved the skew. The exact same stage now runs for a total of 5.5 minutes instead of 1.2 hours. The only code change was the salted column in the partitionBy. The comparison is based on the exact same cluster, node count, and configuration.
