I have a DataFrame in PySpark like below:
df = spark.createDataFrame([
(124,10,8),
(124,20,7),
(125,30,6),
(125,40,5),
(126,50,4),
(126,60,3),
(126,70,2),
(127,80,1)],("ACC_KEY", "AMT", "value"))
df.show()
+-------+---+-----+
|ACC_KEY|AMT|value|
+-------+---+-----+
|    126| 70|    2|
|    126| 60|    3|
|    126| 50|    4|
|    124| 20|    7|
|    124| 10|    8|
|    127| 80|    1|
|    125| 40|    5|
|    125| 30|    6|
+-------+---+-----+
Expected result:
+-------+---+-----+-------+-----+-------+
|ACC_KEY|AMT|value|row_now|amt_c|lkp_rev|
+-------+---+-----+-------+-----+-------+
|    126| 70|    2|      1|   70|     72|
|    126| 60|    3|      2|   72|     75|
|    126| 50|    4|      3|   75|     79|
|    124| 20|    7|      1|   20|     27|
|    124| 10|    8|      2|   27|     35|
|    127| 80|    1|      1|   80|     81|
|    125| 40|    5|      1|   40|     45|
|    125| 30|    6|      2|   45|     51|
+-------+---+-----+-------+-----+-------+
Conditions
1) When row_now = 1, amt_c = AMT
2) When row_now != 1, amt_c = the previous row's lkp_rev (the lag of lkp_rev within the same ACC_KEY), as traced below
3) lkp_rev = amt_c + value
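To make the recurrence concrete, here is how the expected rows for ACC_KEY 124 work out (ordered by AMT descending):
row_now = 1: amt_c = AMT = 20, lkp_rev = 20 + 7 = 27
row_now = 2: amt_c = previous lkp_rev = 27, lkp_rev = 27 + 8 = 35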
I have tried the following:
import pyspark.sql.functions as f
from pyspark.sql import Window
# create row_number column
df1 = df.withColumn("row_now", f.row_number().over(Window.partitionBy("ACC_KEY").orderBy(f.col('AMT').desc())))
# amt_c column creation
df2 = df1.withColumn("amt_c", f.when(f.col("row_now") == 1, f.col("AMT")).otherwise(f.col("value") + f.col("AMT")))
How can I achieve what I want?
CodePudding user response:
I figure it'd be much easier if you separate out all rows that have row_now = 1 and make them a "reference" dataframe, i.e. a starting point for each acc_key. The key observation is that, because each row's amt_c is the previous row's lkp_rev and lkp_rev = amt_c + value, amt_c unrolls to the first AMT plus a running sum of the lagged value column, so a cumulative sum does the job.
First, add a row number so we can reuse it later:
from pyspark.sql import functions as F, Window as W

# rank rows within each acc_key by amt descending
df = df.withColumn('row_now', F.row_number().over(W.partitionBy('acc_key').orderBy(F.col('amt').desc())))
# +-------+---+-----+-------+
# |acc_key|amt|value|row_now|
# +-------+---+-----+-------+
# |    126| 70|    2|      1|
# |    126| 60|    3|      2|
# |    126| 50|    4|      3|
# |    124| 20|    7|      1|
# |    124| 10|    8|      2|
# |    127| 80|    1|      1|
# |    125| 40|    5|      1|
# |    125| 30|    6|      2|
# +-------+---+-----+-------+
We now need to make a "reference" dataframe which contains only the initial amount of each acc_key (i.e. the rows where row_now = 1):
ref = (df
.where(F.col('row_now') == 1)
.drop('row_now', 'value')
.withColumnRenamed('amt', 'init_amt')
)
# +-------+--------+
# |acc_key|init_amt|
# +-------+--------+
# |    126|      70|
# |    124|      20|
# |    127|      80|
# |    125|      40|
# +-------+--------+
Finally, join with the original dataframe so we have a starting point to apply the lag function:
(df
.join(ref, ['acc_key'])
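# temp holds the per-row increment: init_amt on the first row, then the previous row's value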
.withColumn('temp', F
.when(F.col('row_now') == 1, F.col('init_amt'))
.otherwise(F.lag('value').over(W.partitionBy('acc_key').orderBy('row_now')))
)
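# running sum of temp: with an orderBy, sum().over() uses a frame from the start of the partition to the current row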
.withColumn('amt_c', F.sum('temp').over(W.partitionBy('acc_key').orderBy('row_now')))
.withColumn('lkp_rev', F.col('amt_c') + F.col('value'))
.drop('init_amt', 'temp')
.show()
)
# +-------+---+-----+-------+-----+-------+
# |acc_key|amt|value|row_now|amt_c|lkp_rev|
# +-------+---+-----+-------+-----+-------+
# |    126| 70|    2|      1|   70|     72|
# |    126| 60|    3|      2|   72|     75|
# |    126| 50|    4|      3|   75|     79|
# |    124| 20|    7|      1|   20|     27|
# |    124| 10|    8|      2|   27|     35|
# |    127| 80|    1|      1|   80|     81|
# |    125| 40|    5|      1|   40|     45|
# |    125| 30|    6|      2|   45|     51|
# +-------+---+-----+-------+-----+-------+
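As a side note, since amt_c unrolls to the partition's first amt plus a running sum of the lagged value, you can get the same result without the join. A minimal sketch of that variant, assuming the same F and W aliases as above (the lag is materialized into its own column first, because Spark does not allow nesting one window function inside another):

w = W.partitionBy('acc_key').orderBy(F.col('amt').desc())
(df
    .withColumn('row_now', F.row_number().over(w))
    # materialize the lag; 0 on the first row so the running sum starts at the initial amt
    .withColumn('prev_value', F.coalesce(F.lag('value').over(w), F.lit(0)))
    # first amt in the partition plus the running sum of previous values
    .withColumn('amt_c', F.first('amt').over(w) + F.sum('prev_value').over(w))
    .withColumn('lkp_rev', F.col('amt_c') + F.col('value'))
    .drop('prev_value')
    .show()
)

Whether this is actually cheaper depends on your data; it trades the join for an extra window pass over the same partitioning.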