How to do cumulative sum based on conditions in spark scala

Time:12-22

I have the data below, where final_column is the exact output I am trying to get. I want a cumulative sum of flag that resets to 0 whenever flag is 0:

cola date       flag final_column
a   2021-10-01  0   0
a   2021-10-02  1   1
a   2021-10-03  1   2
a   2021-10-04  0   0
a   2021-10-05  0   0
a   2021-10-06  0   0
a   2021-10-07  1   1
a   2021-10-08  1   2
a   2021-10-09  1   3
a   2021-10-10  0   0
b   2021-10-01  0   0
b   2021-10-02  1   1
b   2021-10-03  1   2
b   2021-10-04  0   0
b   2021-10-05  0   0
b   2021-10-06  1   1
b   2021-10-07  1   2
b   2021-10-08  1   3
b   2021-10-09  1   4
b   2021-10-10  0   0

This is what I have tried:

import org.apache.spark.sql.functions._

df.withColumn("final_column", expr("sum(flag) over (partition by cola order by date asc)"))

I also tried adding a condition such as case when flag = 0 then 0 else 1 end inside the sum function, but it did not work.

CodePudding user response:

You can define a group column as a running count of the rows where flag is 0, so that every reset row starts a new group. Then row_number over a Window partitioned by cola and group gives the result you want:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val result = df.withColumn(
    "group",
    sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
    "final_column",
    row_number().over(Window.partitionBy("cola", "group").orderBy("date")) - 1
).drop("group")

result.show

//+----+-----+----+------------+
//|cola| date|flag|final_column|
//+----+-----+----+------------+
//|   b|44201|   0|           0|
//|   b|44202|   1|           1|
//|   b|44203|   1|           2|
//|   b|44204|   0|           0|
//|   b|44205|   0|           0|
//|   b|44206|   1|           1|
//|   b|44207|   1|           2|
//|   b|44208|   1|           3|
//|   b|44209|   1|           4|
//|   b|44210|   0|           0|
//|   a|44201|   0|           0|
//|   a|44202|   1|           1|
//|   a|44203|   1|           2|
//|   a|44204|   0|           0|
//|   a|44205|   0|           0|
//|   a|44206|   0|           0|
//|   a|44207|   1|           1|
//|   a|44208|   1|           2|
//|   a|44209|   1|           3|
//|   a|44210|   0|           0|
//+----+-----+----+------------+
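
To see why the group column does the job, it can help to trace it by hand. The sketch below is plain Scala with no Spark dependency; it replays the two window steps on the flag values for cola = a from the question. The names GroupTrace, flags, groups and finalColumn are made up for illustration and are not part of the answer's code.

```scala
object GroupTrace {
  // Flags for cola = "a", already ordered by date (taken from the question).
  val flags: Seq[Int] = Seq(0, 1, 1, 0, 0, 0, 1, 1, 1, 0)

  // Step 1: running count of zero flags, mirroring
  // sum(when(flag === 0, 1).otherwise(0)) over (partition by cola order by date).
  // scanLeft emits the initial 0 first, so it is dropped with .tail.
  val groups: Seq[Int] =
    flags.scanLeft(0)((acc, f) => if (f == 0) acc + 1 else acc).tail

  // Step 2: zero-based position of each row inside its group, mirroring
  // row_number() - 1 over (partition by cola, group order by date).
  val finalColumn: Seq[Int] =
    groups.zipWithIndex.map { case (g, i) => groups.take(i).count(_ == g) }

  def main(args: Array[String]): Unit = {
    println(groups.mkString(" "))      // 1 1 1 2 3 4 4 4 4 5
    println(finalColumn.mkString(" ")) // 0 1 2 0 0 0 1 2 3 0
  }
}
```

Every flag = 0 row bumps the group id, so each group is one reset row followed by its run of 1s, and counting positions inside the group gives the running total.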

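Because flag is always 0 or 1, row_number() - 1 is equivalent here to sum(col("flag")) over the same window: each group is a flag = 0 row followed only by flag = 1 rows. The sum form also stays correct if a partition begins with flag = 1, where row_number() - 1 would give 0 for that first row. So final_column can also be written as follows, in place of the row_number step and before .drop("group"):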

.withColumn(
    "final_column",
    sum(col("flag")).over(Window.partitionBy("cola", "group").orderBy("date"))
)
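
For anyone who wants to run this end to end, here is a minimal self-contained sketch of the sum variant on the question's data. It assumes Spark SQL is on the classpath and creates a throwaway local SparkSession; the names CumulativeReset, withFinalColumn, byCola and byGroup are illustrative, not from the original code. Dates are kept as ISO strings, which sort chronologically.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, when}

object CumulativeReset {
  // Adds final_column: a running sum of flag that restarts at every flag = 0 row.
  def withFinalColumn(df: DataFrame): DataFrame = {
    val byCola  = Window.partitionBy("cola").orderBy("date")
    val byGroup = Window.partitionBy("cola", "group").orderBy("date")
    df
      // Running count of zero flags: each reset row opens a new group.
      .withColumn("group", sum(when(col("flag") === 0, 1).otherwise(0)).over(byCola))
      // Running sum of flag inside each group.
      .withColumn("final_column", sum(col("flag")).over(byGroup))
      .drop("group")
  }

  def main(args: Array[String]): Unit = {
    // Assumption: local session for experimenting; a real job already has one.
    val spark = SparkSession.builder()
      .master("local[*]").appName("cumulative-reset").getOrCreate()
    import spark.implicits._

    // Sample rows copied from the question.
    val df = Seq(
      ("a", "2021-10-01", 0), ("a", "2021-10-02", 1), ("a", "2021-10-03", 1),
      ("a", "2021-10-04", 0), ("a", "2021-10-05", 0), ("a", "2021-10-06", 0),
      ("a", "2021-10-07", 1), ("a", "2021-10-08", 1), ("a", "2021-10-09", 1),
      ("a", "2021-10-10", 0),
      ("b", "2021-10-01", 0), ("b", "2021-10-02", 1), ("b", "2021-10-03", 1),
      ("b", "2021-10-04", 0), ("b", "2021-10-05", 0), ("b", "2021-10-06", 1),
      ("b", "2021-10-07", 1), ("b", "2021-10-08", 1), ("b", "2021-10-09", 1),
      ("b", "2021-10-10", 0)
    ).toDF("cola", "date", "flag")

    // Explicit ordering so the printed rows line up with the question's table.
    withFinalColumn(df).orderBy("cola", "date").show(20)
    spark.stop()
  }
}
```

The printed final_column should match the one in the question. Note that sum over an integer column yields a long, so cast it if you need an int.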