OOM using Spark window function with 30 days interval

Time: 07-03

I have this data frame:

df = (
    spark
    .createDataFrame([
        [20210101, 'A', 103, "abc"],
        [20210101, 'A', 102, "def"],
        [20210101, 'A', 101, "def"],
        [20210102, 'A', 34, "ghu"],
        [20210101, 'B', 180, "xyz"],
        [20210102, 'B', 123, "kqt"],
    ])
    .toDF("txn_date", "txn_type", "txn_amount", "other_attributes")
)

Each date has multiple transactions of each type. My task is to compute, for each record, the standard deviation of the amount over the trailing 30 days for the same type.

The most obvious approach (which I tried) is to create a window partitioned by type that includes records from the past 30 days.

from pyspark.sql import functions as F
from pyspark.sql.types import LongType
from pyspark.sql.window import Window

days = lambda i: i * 86400
win = (
    Window
    .partitionBy("txn_type")
    .orderBy(F.col("txn_date").cast(LongType()))
    .rangeBetween(-days(30), 0)
)
df = df.withColumn("stddev_last_30days", F.stddev(F.col("txn_amount")).over(win))

Since some of the transaction types have millions of transactions per day, this runs into OOM.

I tried doing it in parts (taking only a few records per date at a time), but this leads to error-prone calculations since standard deviation is not additive.
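For what it's worth, while the standard deviation itself is not additive, it can be reconstructed exactly from per-chunk statistics that are: the count, the sum, and the sum of squares. A plain-Python sketch (function names are mine, not part of any library) using the sample amounts above:

```python
import math
import statistics

def chunk_stats(xs):
    # additive "sufficient statistics": count, sum, sum of squares
    return (len(xs), sum(xs), sum(x * x for x in xs))

def merge(a, b):
    # merging two chunks is element-wise addition -- this part IS additive
    return (a[0] + b[0], a[1] + b[1], a[2] + b[2])

def stddev_from_stats(n, s, ss):
    # sample standard deviation recovered from the merged statistics
    return math.sqrt((ss - s * s / n) / (n - 1))

amounts = [103, 102, 101, 34, 180, 123]  # the txn_amount values above
merged = merge(chunk_stats(amounts[:3]), chunk_stats(amounts[3:]))
print(stddev_from_stats(*merged))  # matches statistics.stdev(amounts)
```

The same idea scales in Spark: pre-aggregate these three values per type and day, run the window over the small daily aggregates, and join back.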

I also tried 'collect_set' for all records for a transaction type and date (so all amounts come in as an array in one column), but this runs into OOM as well.

I tried processing one month at a time (I need at minimum 2 months of data, since I need to go back 1 month), but even that overwhelms my executors.

What would be a scalable way to solve this problem?

Notes:

  • In the original data, column txn_date is stored as long in "yyyyMMdd" format.

  • There are other columns in the data frame that may or may not be the same for each date and type. I haven't included them in the sample code for simplicity.

CodePudding user response:

Filtering

It's always good to remove data that is not needed. You said you need at most the last 60 days, so you can filter out the rest.
This line keeps only rows whose date is no older than 60 days (up to today):

df = df.filter(F.to_date('txn_date', 'yyyyMMdd').between(F.current_date()-61, F.current_date()))

I won't apply it here, so the other issues are easier to illustrate.

Window

First, a small thing: since txn_date is already stored as a long, there is no need to cast it again, so we can remove .cast(LongType()).

The other, big thing is that your window's lower bound is wrong. To see it, let's add one more row to the input:

[19990101, 'B', 9999999, "xxxxxxx"],

This row has a date in the year 1999. After adding it and running your code, we get:

# +--------+--------+----------+----------------+------------------+
# |txn_date|txn_type|txn_amount|other_attributes|stddev_last_30days|
# +--------+--------+----------+----------------+------------------+
# |20210101|       A|       103|             abc|               1.0|
# |20210101|       A|       102|             def|               1.0|
# |20210101|       A|       101|             def|               1.0|
# |20210102|       A|        34|             ghu|34.009802508492555|
# |19990101|       B|   9999999|         xxxxxxx|              null|
# |20210101|       B|       180|             xyz|  7070939.82553808|
# |20210102|       B|       123|             kqt|  5773414.64605055|
# +--------+--------+----------+----------------+------------------+

You can see that the stddev of the 2021 rows for type B was also affected: the 30-day window does not work, and your window actually takes all the data it can reach. We can check what the lower bound is for date 20210101:

print(20210101-days(30))  # Returns 17618101 - I doubt you wanted this date as lower bound

This was probably your biggest problem. Never try to outsmart dates and times with plain integer arithmetic; always use functions specialized for dates and times.
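The pitfall is easy to reproduce in plain Python (no Spark needed): subtracting seconds from a yyyyMMdd integer yields a meaningless number, while date-aware arithmetic gives the bound that was actually intended.

```python
from datetime import datetime, timedelta

d = 20210101
print(d - 30 * 86400)  # 17618101 -- just an integer, not any real date

# date-aware arithmetic: parse, subtract 30 days, format back
parsed = datetime.strptime(str(d), "%Y%m%d")
print((parsed - timedelta(days=30)).strftime("%Y%m%d"))  # 20201202
```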

You can use this window:

days = lambda i: i * 86400
w = (
    Window
    .partitionBy('txn_type')
    .orderBy(F.unix_timestamp(F.col('txn_date').cast('string'), 'yyyyMMdd'))
    .rangeBetween(-days(30), 0)
)
df = df.withColumn('stddev_last_30days', F.stddev('txn_amount').over(w))

df.show()
# +--------+--------+----------+----------------+------------------+
# |txn_date|txn_type|txn_amount|other_attributes|stddev_last_30days|
# +--------+--------+----------+----------------+------------------+
# |20210101|       A|       103|             abc|               1.0|
# |20210101|       A|       102|             def|               1.0|
# |20210101|       A|       101|             def|               1.0|
# |20210102|       A|        34|             ghu|34.009802508492555|
# |19990101|       B|   9999999|         xxxxxxx|              null|
# |20210101|       B|       180|             xyz|              null|
# |20210102|       B|       123|             kqt| 40.30508652763321|
# +--------+--------+----------+----------------+------------------+

unix_timestamp transforms your 'yyyyMMdd' format into a proper long value (Unix time in seconds), from which you can then subtract 30 days' worth of seconds.
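As a sanity check of the corrected bound, here is the same arithmetic in plain Python (the calendar/time modules stand in for unix_timestamp; the helper name is mine):

```python
import calendar
import time

def epoch(yyyymmdd):
    # same conversion unix_timestamp performs: 'yyyyMMdd' -> seconds since epoch (UTC)
    return calendar.timegm(time.strptime(str(yyyymmdd), "%Y%m%d"))

days = lambda i: i * 86400
lower = epoch(20210102) - days(30)  # window lower bound for date 20210102

print(epoch(20210101) >= lower)  # True:  the 2021-01-01 rows are inside the window
print(epoch(19990101) >= lower)  # False: the 1999 row is correctly excluded
```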
