Get the average of the last N months in PySpark

Time:09-11

I have a PySpark DataFrame with a sales dataset covering 4 months, and I need to calculate each sales rep's average sales over the last 3 months (L3M).

For example, for September and a given sales rep, L3M should be the total sales of June, July and August divided by 3. If that rep has no entries for, say, July, it should just be (June + August) / 2.

What's the best way to do this in PySpark?
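To pin down the expected semantics before writing any Spark code, here is a minimal plain-Python sketch of the rule described above: total each rep's sales per month, then average the totals of whichever of the three preceding calendar months have data. The helper names `month_index` and `l3m` and the `(month, sales_rep, sales)` tuple layout are hypothetical, chosen only for illustration.

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional, Tuple


def month_index(yyyymm: str) -> int:
    """Map "YYYYMM" to a running month count, so December -> January is a step of 1."""
    return int(yyyymm[:4]) * 12 + int(yyyymm[4:])


def l3m(rows: Iterable[Tuple[str, int, float]]) -> Dict[Tuple[int, str], Optional[float]]:
    """Hypothetical helper. rows are (month "YYYYMM", sales_rep, sales).

    Returns {(sales_rep, month): L3M}, where L3M is the average of the monthly
    totals of the previous 3 calendar months that have data, or None if none do.
    """
    totals: Dict[Tuple[int, int], float] = defaultdict(float)
    months: Dict[Tuple[int, int], str] = {}
    for month, rep, sales in rows:
        idx = month_index(month)
        totals[(rep, idx)] += sales
        months[(rep, idx)] = month

    result: Dict[Tuple[int, str], Optional[float]] = {}
    for (rep, idx), month in months.items():
        # Only months that actually have entries for this rep count towards the divisor
        previous = [totals[(rep, idx - k)] for k in (1, 2, 3) if (rep, idx - k) in totals]
        result[(rep, month)] = round(sum(previous) / len(previous), 2) if previous else None
    return result


# Usage with the sample data from the answer below
rows = [
    ("202206", 211, 50), ("202206", 211, 10), ("202206", 211, 60),
    ("202207", 211, 0),
    ("202208", 211, 30), ("202208", 211, 10),
    ("202209", 211, 70), ("202209", 211, 20),
]
res = l3m(rows)
print(res[(211, "202209")])  # 53.33, i.e. (120 + 0 + 40) / 3
```

Dropping the July row makes September's L3M (120 + 40) / 2 = 80.0, which matches the "(June + August) / 2" case in the question.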


Answer:

Not sure what's the best solution, but I've tried something like this:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
data = [
    {"Date": "2022-06-03", "Month": "202206", "Sales": 50, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-06-19", "Month": "202206", "Sales": 10, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-06-27", "Month": "202206", "Sales": 60, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-07-02", "Month": "202207", "Sales": 0, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-08-05", "Month": "202208", "Sales": 30, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-08-09", "Month": "202208", "Sales": 10, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-09-03", "Month": "202209", "Sales": 70, "Store_ID": 555, "Sales_Rep": 211},
    {"Date": "2022-09-07", "Month": "202209", "Sales": 20, "Store_ID": 555, "Sales_Rep": 211},
]
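
df = spark.createDataFrame(data=data)

# One row per store, sales rep and month, holding that month's total sales
df_grouped = df.groupBy(["Store_ID", "Sales_Rep", "Month"]).agg(F.sum("Sales").alias("monthly_sales"))

# Running month number (year * 12 + month), so the window also works across year boundaries
month_date = F.to_date("Month", "yyyyMM")
df_grouped = df_grouped.withColumn("Month_idx", F.year(month_date) * 12 + F.month(month_date))

# Frame = the 3 months before the current one, per store and sales rep
month_window = Window.partitionBy(["Store_ID", "Sales_Rep"]).orderBy("Month_idx").rangeBetween(-3, -1)

# Divide by the number of months that have data, e.g. (June + August) / 2 when July is missing
df_grouped = df_grouped.withColumn("n_months", F.count("monthly_sales").over(month_window))
df_grouped = df_grouped.withColumn("L3M", F.round(F.sum("monthly_sales").over(month_window) / F.col("n_months"), 2))

# Attach L3M back to the daily rows, matching on store, sales rep and month
df = (
    df.join(df_grouped.select("Store_ID", "Sales_Rep", "Month", "L3M"), on=["Store_ID", "Sales_Rep", "Month"])
    .select("Date", "Month", "Sales", "Sales_Rep", "Store_ID", "L3M")
    .orderBy("Store_ID", "Sales_Rep", "Date")
)
df.show(10, False)
df.printSchema()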

Which gives:

+----------+------+-----+---------+--------+-----+
|Date      |Month |Sales|Sales_Rep|Store_ID|L3M  |
+----------+------+-----+---------+--------+-----+
|2022-06-03|202206|50   |211      |555     |null |
|2022-06-19|202206|10   |211      |555     |null |
|2022-06-27|202206|60   |211      |555     |null |
|2022-07-02|202207|0    |211      |555     |120.0|
|2022-08-05|202208|30   |211      |555     |60.0 |
|2022-08-09|202208|10   |211      |555     |60.0 |
|2022-09-03|202209|70   |211      |555     |53.33|
|2022-09-07|202209|20   |211      |555     |53.33|
+----------+------+-----+---------+--------+-----+

root
 |-- Date: string (nullable = true)
 |-- Month: string (nullable = true)
 |-- Sales: long (nullable = true)
 |-- Sales_Rep: long (nullable = true)
 |-- Store_ID: long (nullable = true)
 |-- L3M: double (nullable = true)

I'm also not sure how well this will perform on the full dataset, but it shows how to combine groupBy, a Window and rangeBetween to isolate the months you need.
