Cumulative sum of n values in pyspark dataframe

Time:06-25

I have this current table in databricks:

+--------------------+-------------------+-----+------+
|            card_uid|               date|local|amount|
+--------------------+-------------------+-----+------+
|card_001H4Mw1Ha0M...|2016-05-04 17:54:30|  100|  8.99|
|card_0026uGZQwZQd...|2016-05-06 12:16:18|    0|  8.99|
|card_0026uGZQwZQd...|2016-07-06 12:17:57|  100|  8.99|
|card_003STfrgB8SZ...|2016-12-04 10:05:21|  100|  8.99|
|card_005gBxyiDc6b...|2016-09-10 18:58:25|  100|  8.99|
|card_005gBxyiAc6b...|2016-11-13 11:18:29|  100|  8.99|
|card_003STfrgC8SZ...|2016-12-05 12:05:21|  100|  8.99|
|card_002gBxyiSc6b...|2016-09-14 11:58:25|  100|  8.99|
|card_005gBxyiZc6b...|2016-11-15 15:18:29|  100|  8.99|
+--------------------+-------------------+-----+------+

I would like to add a new column named SUM, where each row's value follows this condition: if amount >= 8.99 and the sum of the previous 3 local values equals 300, then insert the sum of the previous 3 amount values; otherwise insert 256. This is the resulting column:

+-----+
|  sum|
+-----+
|  256|
|  256|
|  256|
|  256|
|  256|
|26.97|
|26.97|
|26.97|
|26.97|
+-----+

How can I reproduce this in a pyspark dataframe? In a plain pandas dataframe I would loop with a condition, something like this:

for i in range(len(df)):
    if i < 3:
        continue
    prev = df.iloc[i - 3:i]  # the 3 previous rows
    df.loc[i, "SUM"] = prev["amount"].sum() if (
        df.loc[i, "amount"] >= 8.99 and prev["local"].sum() == 300) else 256

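As an aside, the loop can be vectorized in pandas with shift (to exclude the current row) and rolling (to span the 3 previous rows). A minimal sketch, using hypothetical data in which only the last row satisfies the condition:

```python
import pandas as pd

# Hypothetical miniature of the question's data: only the last row has
# three previous "local" values summing to 300.
df = pd.DataFrame({
    "local": [100, 0, 100, 100, 100, 100],
    "amount": [8.99] * 6,
})

# shift(1) excludes the current row; rolling(3).sum() then covers rows i-3..i-1
prev_local = df["local"].shift(1).rolling(3).sum()
prev_amount = df["amount"].shift(1).rolling(3).sum()

# Keep the rolling amount where the condition holds, otherwise fall back to 256
df["SUM"] = prev_amount.where(
    (df["amount"] >= 8.99) & (prev_local == 300), 256
)

sums = df["SUM"].tolist()
print(sums)  # only the last row meets the condition
```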
but how can I reproduce it using the window function?

CodePudding user response:

A sum over a Window bounded with rowsBetween(-3, -1) (the 3 previous rows), combined with a when expression, should do the job:

from pyspark.sql import Window, functions as F

w = Window.orderBy("date").rowsBetween(-3, -1)

df1 = df.withColumn(
    "SUM",
    F.when(
        (F.col("amount") >= 8.99) & (F.sum("local").over(w) == 300),
        F.sum("amount").over(w)
    ).otherwise(256)
)

df1.show()

# +-----------------+-------------------+-----+------+-----+
# |         card_uid|               date|local|amount|  SUM|
# +-----------------+-------------------+-----+------+-----+
# |card_001H4Mw1Ha0M|2016-05-04 17:54:30|  100|  8.99|256.0|
# |card_0026uGZQwZQd|2016-05-06 12:16:18|    0|  8.99|256.0|
# |card_0026uGZQwZQd|2016-07-06 12:17:57|  100|  8.99|256.0|
# |card_005gBxyiDc6b|2016-09-10 18:58:25|  100|  8.99|256.0|
# |card_002gBxyiSc6b|2016-09-14 11:58:25|  100|  8.99|256.0|
# |card_005gBxyiAc6b|2016-11-13 11:18:29|  100|  8.99|26.97|
# |card_005gBxyiZc6b|2016-11-15 15:18:29|  100|  8.99|26.97|
# |card_003STfrgB8SZ|2016-12-04 10:05:21|  100|  8.99|26.97|
# |card_003STfrgC8SZ|2016-12-05 12:05:21|  100|  8.99|26.97|
# +-----------------+-------------------+-----+------+-----+