I have a Spark DataFrame and I want to calculate the value of the next rows based on two columns in the previous row. I know how to do it for just one row back (using the lag() function), but I don't know how to carry those computed values forward through the next several rows.
id | month | value | monthly_increment
1  | 01    | 100   | 2
1  | 02    | 200   | 3
1  | 03    | 600   | 4
1  | 04    | 2400  | 2
As you can see, the column "value" gets multiplied by "monthly_increment", and the result keeps compounding through all the following rows for that particular "id".
How can this be done using PySpark?
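For illustration, a one-row-back version with lag() might look like the following sketch (the window spec and column names are assumed from the table above); it only propagates a single step, which is exactly the limitation:

from pyspark.sql import functions as F, Window as W

w = W.partitionBy('id').orderBy('month')

# lag() only reaches the immediately preceding row, so this fills a
# single step and does not compound across several rows.
df = df.withColumn(
    'value',
    F.coalesce(F.lag('value').over(w) * F.lag('monthly_increment').over(w),
               'value'))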
CodePudding user response:
It's very important to provide an example input DataFrame when asking Spark questions. You didn't, so I assumed your input DataFrame looks like this:
from pyspark.sql import functions as F, Window as W

df = spark.createDataFrame(
    [('1', '01', 100, 2),
     ('1', '02', None, 3),
     ('1', '03', None, 4),
     ('1', '04', None, 2)],
    ['id', 'month', 'value', 'monthly_increment'])
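Rendered with df.show(), that assumed input (only the first "value" per id is known) should look like this:

df.show()
# +---+-----+-----+-----------------+
# | id|month|value|monthly_increment|
# +---+-----+-----+-----------------+
# |  1|   01|  100|                2|
# |  1|   02| null|                3|
# |  1|   03| null|                4|
# |  1|   04| null|                2|
# +---+-----+-----+-----------------+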
Spark 3.2
You could fill in the missing "value" entries using a combination of the product, lag and first window functions:
w = W.partitionBy('id').orderBy('month')

# Running product of all *previous* increments: lag() excludes the current
# row's increment, and product() skips the leading null.
factor = F.product(F.lag('monthly_increment').over(w)).over(w)

# For the first row the factor is null, so coalesce falls back to the original value.
df = df.withColumn('value', F.coalesce(F.first('value').over(w) * factor, 'value'))
df.show()
# +---+-----+------+-----------------+
# | id|month| value|monthly_increment|
# +---+-----+------+-----------------+
# |  1|   01| 100.0|                2|
# |  1|   02| 200.0|                3|
# |  1|   03| 600.0|                4|
# |  1|   04|2400.0|                2|
# +---+-----+------+-----------------+
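If you are on a Spark version that doesn't have F.product yet (it was added in 3.1), a common workaround is to emulate the running product with exp(sum(log(x))), which holds for strictly positive increments. A sketch under that assumption:

from pyspark.sql import functions as F, Window as W

w = W.partitionBy('id').orderBy('month')

# product(x) == exp(sum(log(x))) for strictly positive x; lag() again
# shifts the increments so each row multiplies only the previous ones.
factor = F.exp(F.sum(F.log(F.lag('monthly_increment').over(w))).over(w))

# exp/log can introduce floating-point noise (e.g. 599.999...),
# so round the result if exact values matter.
df = df.withColumn(
    'value',
    F.coalesce(F.round(F.first('value').over(w) * factor, 2), 'value'))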