I have the data below, and final_column is the exact output I am trying to get. I want a cumulative sum of flag that resets whenever flag is 0: on a row where flag is 0 the value is set back to 0, as shown here:
cola date flag final_column
a 2021-10-01 0 0
a 2021-10-02 1 1
a 2021-10-03 1 2
a 2021-10-04 0 0
a 2021-10-05 0 0
a 2021-10-06 0 0
a 2021-10-07 1 1
a 2021-10-08 1 2
a 2021-10-09 1 3
a 2021-10-10 0 0
b 2021-10-01 0 0
b 2021-10-02 1 1
b 2021-10-03 1 2
b 2021-10-04 0 0
b 2021-10-05 0 0
b 2021-10-06 1 1
b 2021-10-07 1 2
b 2021-10-08 1 3
b 2021-10-09 1 4
b 2021-10-10 0 0
I have tried:
import org.apache.spark.sql.functions._
df.withColumn("final_column", expr("sum(flag) over (partition by cola order by date asc)"))
I also tried adding a condition such as case when flag = 0 then 0 else 1 end inside the sum function, but it does not work.
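To pin down the target semantics, here is a minimal plain-Scala sketch (no Spark; the object and method names are hypothetical), assuming the rows of one cola partition are already sorted by date. It walks the flags in order and restarts the count at 0, which is inherently sequential and is why a plain windowed sum(flag) cannot express it directly.

```scala
// Hypothetical helper, not part of Spark: resetting running sum over one
// partition's flags, assumed to be sorted by date already.
object ResettingCumSum {
  // Add flag to the running total, but restart at 0 whenever flag == 0.
  def resettingCumSum(flags: Seq[Int]): Seq[Int] =
    flags.scanLeft(0)((acc, f) => if (f == 0) 0 else acc + f).tail

  def main(args: Array[String]): Unit = {
    val a = Seq(0, 1, 1, 0, 0, 0, 1, 1, 1, 0) // flags for cola = "a"
    val b = Seq(0, 1, 1, 0, 0, 1, 1, 1, 1, 0) // flags for cola = "b"
    println(resettingCumSum(a)) // List(0, 1, 2, 0, 0, 0, 1, 2, 3, 0)
    println(resettingCumSum(b)) // List(0, 1, 2, 0, 0, 1, 2, 3, 4, 0)
  }
}
```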
Answer:
You can define a "group" column using a conditional running sum on flag: it counts the rows with flag = 0 seen so far, so every 0 starts a new group. Then row_number over a Window partitioned by cola and group gives the result you want:
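import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val result = df.withColumn(
  // running count of zeros per cola: each flag = 0 row starts a new group
  "group",
  sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
  // position of the row within its group, starting at 0
  "final_column",
  row_number().over(Window.partitionBy("cola", "group").orderBy("date")) - 1
).drop("group")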
result.show
//+----+-----+----+------------+
//|cola| date|flag|final_column|
//+----+-----+----+------------+
//|   b|44201|   0|           0|
//|   b|44202|   1|           1|
//|   b|44203|   1|           2|
//|   b|44204|   0|           0|
//|   b|44205|   0|           0|
//|   b|44206|   1|           1|
//|   b|44207|   1|           2|
//|   b|44208|   1|           3|
//|   b|44209|   1|           4|
//|   b|44210|   0|           0|
//|   a|44201|   0|           0|
//|   a|44202|   1|           1|
//|   a|44203|   1|           2|
//|   a|44204|   0|           0|
//|   a|44205|   0|           0|
//|   a|44206|   0|           0|
//|   a|44207|   1|           1|
//|   a|44208|   1|           2|
//|   a|44209|   1|           3|
//|   a|44210|   0|           0|
//+----+-----+----+------------+
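To see why the two windows produce this output, here is a plain-Scala emulation (no Spark; the object and method names are hypothetical) for partition "a" of the question's data, assuming its rows are already sorted by date. The conditional sum becomes a running count of zeros, and row_number() - 1 becomes each row's offset from the first row of its group.

```scala
// Hypothetical emulation of the two window steps on one sorted partition.
object GroupIds {
  // "group": running count of flag == 0 rows seen so far (inclusive), like
  // sum(when(col("flag") === 0, 1).otherwise(0)) over an ordered window.
  def groupIds(flags: Seq[Int]): Seq[Int] =
    flags.scanLeft(0)((zeros, f) => if (f == 0) zeros + 1 else zeros).tail

  // row_number() - 1 within each group: groups are contiguous here, so the
  // offset is the row index minus the index of the group's first row.
  def rowNumberMinusOne(groups: Seq[Int]): List[Int] =
    groups.indices.map(i => i - groups.indexOf(groups(i))).toList

  def main(args: Array[String]): Unit = {
    val flags  = Seq(0, 1, 1, 0, 0, 0, 1, 1, 1, 0) // partition "a"
    val groups = groupIds(flags)
    println(groups)                    // List(1, 1, 1, 2, 3, 4, 4, 4, 4, 5)
    println(rowNumberMinusOne(groups)) // List(0, 1, 2, 0, 0, 0, 1, 2, 3, 0)
  }
}
```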
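Here row_number() - 1 gives the same values as a running sum(col("flag")) within each group, because flag is always 0 or 1 and every group starts with the row whose flag is 0. That holds as long as each cola partition starts with flag = 0, as in this data; if a partition starts with flag = 1, the rows before its first 0 fall into group 0 and row_number() - 1 is one too low for them. The running sum does not have that problem, so final_column can also be written as: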
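val result2 = df.withColumn(
  "group",
  sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
  // running sum of flag inside each group: restarts at every flag = 0 row
  "final_column",
  sum(col("flag")).over(Window.partitionBy("cola", "group").orderBy("date"))
).drop("group")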
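A plain-Scala check (no Spark; the object and method names are hypothetical) that emulates both final_column formulas on a single partition, assuming its rows are already sorted by date. The two agree when the partition starts with flag = 0, and diverge when it starts with flag = 1, where only the running sum matches the intended resetting count.

```scala
// Hypothetical emulation comparing the two final_column formulas.
object CompareVariants {
  // "group": running count of flag == 0 rows seen so far (inclusive).
  def groupIds(flags: Seq[Int]): Seq[Int] =
    flags.scanLeft(0)((zeros, f) => if (f == 0) zeros + 1 else zeros).tail

  // final_column as row_number() - 1 within each (contiguous) group.
  def viaRowNumber(flags: Seq[Int]): List[Int] = {
    val g = groupIds(flags)
    g.indices.map(i => i - g.indexOf(g(i))).toList
  }

  // final_column as a running sum of flag within each group.
  def viaSum(flags: Seq[Int]): List[Int] = {
    val g = groupIds(flags)
    g.indices.map(i => (g.indexOf(g(i)) to i).map(j => flags(j)).sum).toList
  }

  def main(args: Array[String]): Unit = {
    val startsWithZero = Seq(0, 1, 1, 0, 1)
    val startsWithOne  = Seq(1, 1, 0, 1)
    println(viaRowNumber(startsWithZero)) // List(0, 1, 2, 0, 1)
    println(viaSum(startsWithZero))       // List(0, 1, 2, 0, 1)
    println(viaRowNumber(startsWithOne))  // List(0, 1, 0, 1)
    println(viaSum(startsWithOne))        // List(1, 2, 0, 1)
  }
}
```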