Calculate the product of a column based on previous rows in Spark


I have a Spark dataframe and I want to calculate the values of the next rows based on 2 columns of the previous rows. I know how to do it for just 1 row (using the lag() function), but I don't know how to carry these values over to all of the following rows.

id | month | value | monthly_increment
1  | 01    | 100   | 2
1  | 02    | 200   | 3
1  | 03    | 600   | 4
1  | 04    | 2400  | 2

As you can see, each "value" is the previous row's "value" multiplied by the previous row's "monthly_increment", so every multiplication carries over to all the following values for that particular "id".
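Put as a formula (with $v_n$ for "value" and $r_n$ for "monthly_increment" in the $n$-th month of one "id", notation introduced here only for illustration), each value for $n \ge 2$ is the first value times the product of all earlier increments:

```latex
v_n = v_{n-1}\, r_{n-1} = v_1 \prod_{k=1}^{n-1} r_k,
\qquad \text{e.g. } v_4 = 100 \cdot 2 \cdot 3 \cdot 4 = 2400
```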

How can this be done using PySpark?

Answer:

It's very important to provide an example input dataframe when asking Spark questions. You didn't, so I assumed that your input dataframe looks like this:

from pyspark.sql import SparkSession, functions as F, Window as W

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [('1', '01',  100, 2),
     ('1', '02', None, 3),
     ('1', '03', None, 4),
     ('1', '04', None, 2)],
    ['id', 'month', 'value', 'monthly_increment'])

Spark 3.2+

You can fill in the missing values of column "value" by combining the product, lag and first window functions:

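w = W.partitionBy('id').orderBy('month')  # frame: partition start to current row
# Running product of the previous rows' increments (for the sample: null, 2, 6, 24)
factor = F.product(F.lag('monthly_increment').over(w)).over(w)
# first value * factor; on the first row factor is null, so coalesce keeps 'value'
df = df.withColumn('value', F.coalesce(F.first('value').over(w) * factor, 'value'))
df.show()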

# +---+-----+------+-----------------+
# | id|month| value|monthly_increment|
# +---+-----+------+-----------------+
# |  1|   01| 100.0|                2|
# |  1|   02| 200.0|                3|
# |  1|   03| 600.0|                4|
# |  1|   04|2400.0|                2|
# +---+-----+------+-----------------+
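To see what each window function contributes, here is a plain-Python sketch (no Spark needed) that mimics the same steps for every id: sort by month, lag the increment, keep a running product, and multiply the first value by it. `fill_values` is a hypothetical helper for illustration only; it assumes rows are `(id, month, value, monthly_increment)` tuples and that months sort correctly as strings, like the zero-padded '01'..'04' above.

```python
from itertools import groupby
from operator import itemgetter


def fill_values(rows):
    """Hypothetical helper: plain-Python mimic of the Spark window expression.

    Each row is an (id, month, value, monthly_increment) tuple. Months are
    assumed to sort correctly as strings (zero-padded, like '01'..'12').
    """
    result = []
    # Equivalent of W.partitionBy('id').orderBy('month')
    ordered = sorted(rows, key=itemgetter(0, 1))
    for _, group in groupby(ordered, key=itemgetter(0)):
        group = list(group)
        first_value = group[0][2]  # F.first('value') over the partition
        prev_inc = None            # F.lag('monthly_increment'): None on the first row
        factor = None              # running F.product(...): None until a non-null input
        for row_id, month, value, inc in group:
            if prev_inc is not None:
                factor = prev_inc if factor is None else factor * prev_inc
            if factor is not None and first_value is not None:
                filled = first_value * factor
            else:
                filled = value  # F.coalesce(..., 'value') falls back to the original
            # Spark's product returns a double, so the filled column is a float
            result.append((row_id, month, None if filled is None else float(filled), inc))
            prev_inc = inc
    return result


rows = [('1', '01', 100, 2), ('1', '02', None, 3),
        ('1', '03', None, 4), ('1', '04', None, 2)]
for row in fill_values(rows):
    print(row)
```

Like Spark's product, the running factor here stays null (None) until it has seen a non-null input, which is why the first row of each id falls back to its own value through coalesce.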