Create interdependent column in pyspark

I have a data frame in pyspark like below

df = spark.createDataFrame([
(124,10,8),
(124,20,7),
(125,30,6),
(125,40,5),
(126,50,4),
(126,60,3),
(126,70,2),
(127,80,1)],("ACC_KEY", "AMT", "value"))

df.show()

+-------+---+-----+
|ACC_KEY|AMT|value|
+-------+---+-----+
|    126| 70|    2|
|    126| 60|    3|
|    126| 50|    4|
|    124| 20|    7|
|    124| 10|    8|
|    127| 80|    1|
|    125| 40|    5|
|    125| 30|    6|
+-------+---+-----+

Expected result

+-------+---+-----+-------+-----+-------+
|ACC_KEY|AMT|value|row_now|amt_c|lkp_rev|
+-------+---+-----+-------+-----+-------+
|    126| 70|    2|      1|   70|     72|
|    126| 60|    3|      2|   72|     75|
|    126| 50|    4|      3|   75|     79|
|    124| 20|    7|      1|   20|     27|
|    124| 10|    8|      2|   27|     35|
|    127| 80|    1|      1|   80|     81|
|    125| 40|    5|      1|   40|     45|
|    125| 30|    6|      2|   45|     51|
+-------+---+-----+-------+-----+-------+

Conditions

1) When row_number = 1, the amt_c column = the AMT column
2) When row_number != 1, amt_c should be the lag of the lkp_rev column (i.e. the previous row's lkp_rev value)
3) The lkp_rev column = the amt_c column + the value column
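
In other words, each row's amt_c depends on the previous row's lkp_rev, so the three rules form a recurrence. A minimal pure-Python sketch (using the (AMT, value) pairs of the ACC_KEY = 126 group from the expected result) makes the interdependency concrete:

rows = [(70, 2), (60, 3), (50, 4)]  # ACC_KEY = 126, ordered by AMT descending

lkp_rev = None
for i, (amt, value) in enumerate(rows, start=1):
    amt_c = amt if i == 1 else lkp_rev  # rules 1 and 2
    lkp_rev = amt_c + value             # rule 3
    print(i, amt_c, lkp_rev)
# 1 70 72
# 2 72 75
# 3 75 79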

I have tried the following:

import pyspark.sql.functions as f
from pyspark.sql import Window

# create row_number column
df1 = df.withColumn("row_now", f.row_number().over(Window.partitionBy("ACC_KEY").orderBy(f.col('AMT').desc())))

# amt_c column creation
df2 = df1.withColumn("amt_c", f.when(f.col("row_now") == 1, f.col("AMT")).otherwise(f.col("value") + f.col("AMT")))

How can I achieve what I want?

CodePudding user response:

I figure it'd be much easier if you separate out all the rows that have row_now = 1 and make them a "reference" dataframe, i.e. the starting point for each acc_key. Spark can't lag a column that is still being computed, but the recurrence unrolls nicely: each row's amt_c is just the group's initial amt plus the sum of value over all earlier rows, which a cumulative sum can produce.

First, add the row number so we can reuse it later:

import pyspark.sql.functions as F
from pyspark.sql import Window as W

df = df.withColumn('row_now', F.row_number().over(W.partitionBy('acc_key').orderBy(F.col('amt').desc())))
# +-------+---+-----+-------+
# |acc_key|amt|value|row_now|
# +-------+---+-----+-------+
# |    126| 70|    2|      1|
# |    126| 60|    3|      2|
# |    126| 50|    4|      3|
# |    124| 20|    7|      1|
# |    124| 10|    8|      2|
# |    127| 80|    1|      1|
# |    125| 40|    5|      1|
# |    125| 30|    6|      2|
# +-------+---+-----+-------+

We now need to make a "reference" dataframe which contains only the initial amount for each group (i.e. the rows where row_now = 1):

ref = (df
    .where(F.col('row_now') == 1)
    .drop('row_now', 'value')
    .withColumnRenamed('amt', 'init_amt')
)
#  ------- -------- 
# |acc_key|init_amt|
#  ------- -------- 
# |    126|      70|
# |    124|      20|
# |    127|      80|
# |    125|      40|
#  ------- -------- 
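
An aside before the join: ref holds only one row per acc_key, so on large data it is usually worth hinting Spark to broadcast it (broadcast is a standard pyspark.sql.functions helper; whether it pays off depends on your data sizes):

# optional: broadcast the small reference side to avoid shuffling df
joined = df.join(F.broadcast(ref), ['acc_key'])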

Finally, join with the original dataframe so we have a starting point to apply the lag function:

(df
    .join(ref, ['acc_key'])
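    # temp = init_amt on the first row, previous row's value afterwards;
    # its running sum per group therefore reproduces amt_c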
    .withColumn('temp', F
        .when(F.col('row_now') == 1, F.col('init_amt'))
        .otherwise(F.lag('value').over(W.partitionBy('acc_key').orderBy('row_now')))
    )
    .withColumn('amt_c', F.sum('temp').over(W.partitionBy('acc_key').orderBy('row_now')))
    .withColumn('lkp_rev', F.col('amt_c') + F.col('value'))
    .drop('init_amt', 'temp')
    .show()
)

# +-------+---+-----+-------+-----+-------+
# |acc_key|amt|value|row_now|amt_c|lkp_rev|
# +-------+---+-----+-------+-----+-------+
# |    126| 70|    2|      1|   70|     72|
# |    126| 60|    3|      2|   72|     75|
# |    126| 50|    4|      3|   75|     79|
# |    124| 20|    7|      1|   20|     27|
# |    124| 10|    8|      2|   27|     35|
# |    127| 80|    1|      1|   80|     81|
# |    125| 40|    5|      1|   40|     45|
# |    125| 30|    6|      2|   45|     51|
# +-------+---+-----+-------+-----+-------+
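
For completeness: since amt_c unrolls to the group's first amt plus the running sum of the previous rows' value, the same result can be had with window functions alone, no join needed. A sketch of that variant (not the approach above, but it uses only standard window API):

from pyspark.sql import functions as F, Window as W

w = W.partitionBy('acc_key').orderBy(F.col('amt').desc())

result = (df
    .withColumn('row_now', F.row_number().over(w))
    # first amt in the group plus the sum of value over all earlier rows;
    # the windowed sum is null on the first row, hence the coalesce
    .withColumn('amt_c',
        F.first('amt').over(w)
        + F.coalesce(F.sum('value').over(w.rowsBetween(W.unboundedPreceding, -1)), F.lit(0)))
    .withColumn('lkp_rev', F.col('amt_c') + F.col('value'))
)
result.show()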