How to calculate a cumulative sum on a column with a threshold condition in PySpark

Time:10-21

I have the DataFrame shown below. I want to compute a cumulative sum of V for each id. When the cumulative sum of the previous row is greater than or equal to the threshold of 25, the cumulative sum should reset to the current value, as shown in the image below. I tried writing a user-defined function to operate on V, but I received an error saying it is not iterable. I also tried using lag, but that did not work either. I need help!
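The rule is inherently sequential. Whether a row resets depends on the previous row's *computed* total, not on the previous row's value of V, which is why `lag` over V alone cannot express it. Below is a minimal plain-Python sketch of the rule for one id (no Spark). It assumes the rows are already in the intended order. `reset_cumsum` and its `threshold` parameter are hypothetical names, not part of any Spark API.

```python
from typing import Iterable, List


def reset_cumsum(values: Iterable[float], threshold: float = 25) -> List[float]:
    """Running total that restarts at the current value once the previous
    row's running total has reached `threshold` (rows assumed in order)."""
    out: List[float] = []
    for v in values:
        if out and out[-1] >= threshold:
            out.append(v)  # previous total reached the threshold: reset
        else:
            out.append((out[-1] if out else 0) + v)  # keep accumulating
    return out


mark = [0, 1, 1, 1, 25, 1, 1, 1, 20, 1, 1, 1, 1, 1]
print(reset_cumsum(mark))  # [0, 1, 2, 3, 28, 1, 2, 3, 23, 24, 25, 1, 2, 3]
```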

    # keep every V an int: mixing 0.0 with ints can make schema inference fail to merge DoubleType and LongType
    df = sqlContext.createDataFrame(
        [('Mark', 0), ('Mark', 1), ('Mark', 1),
         ('Mark', 1), ('Mark', 25), ('Mark', 1),
         ('Mark', 1), ('Mark', 1), ('Mark', 20),
         ('Mark', 1), ('Mark', 1), ('Mark', 1),
         ('Mark', 1), ('Mark', 1), ('John', 0),
         ('John', 1), ('John', 1), ('John', 1),
         ('John', 1), ('John', 1), ('John', 1),
         ('John', 1), ('John', 9), ('John', 1),
         ('John', 1), ('John', 1), ('John', 1),
         ('John', 1), ('John', 1), ('John', 1),
         ('John', 1), ('John', 1), ('John', 1),
         ('John', 7), ('John', 1)],
        ['id', 'V'])

[Image: expected output (not available)]

Answer:

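Maybe there is a prettier, faster way. This approach runs but isn't efficient. Window functions can be memory-intensive, because all rows that share a partition key are processed together in a single task, so be wary if you use it in production. If speed is required, a custom UDF used with a window might be a little faster than this, since it avoids the double windowing. Also note that the reset points come from bucketing the running total by multiples of 25. Once a bucket overshoots 25 (as Mark's 28 does), later resets can shift relative to a strict reset-after-25 rule, so check the output against what you expect: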

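    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import spark.implicits._ // assumes a SparkSession val named `spark`, as in spark-shell

    // Spark does not preserve row order, so record the input position in "idx"
    // (assumes the order of the Seq is the intended order)
    val df = Seq(("Mark", 0), ("Mark", 1), ("Mark", 1),
    ("Mark", 1), ("Mark", 25), ("Mark", 1),
    ("Mark", 1),("Mark", 1),("Mark", 20),
    ("Mark", 1),("Mark", 1),("Mark", 1),
    ("Mark", 1),("Mark", 1),("John", 0),
    ("John", 1),("John", 1),("John", 1),
    ("John", 1),("John", 1),("John", 1),
    ("John", 1),("John", 9),("John", 1),
    ("John", 1),("John", 1),("John", 1),
    ("John", 1),("John", 1),("John", 1),
    ("John", 1),("John", 1),("John", 1),
    ("John", 7),("John", 1)).toDF("id","V").withColumn("idx", monotonically_increasing_id())
    // order rows within each id by "idx"; ordering by the partition key "id" alone is arbitrary
    val windowSpecLag = Window.partitionBy("id").orderBy("idx")
    val windowSpec = Window.partitionBy("id").orderBy("idx").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    // add a running sum per id, then bucket it by multiples of 25
    val divis = df.withColumn("sum", sum("V").over(windowSpec)).withColumn("divis", floor(col("sum") / 25))
    // shift the bucket number down one row, so the row whose running sum crosses a multiple of 25
    // stays in the previous bucket and the restart begins on the next row
    val lagged = divis.withColumn("clag", lag("divis", 1, 0).over(windowSpecLag))
    // re-run the running total within each (id, clag) bucket
    val windowSpecFixed = Window.partitionBy("id", "clag").orderBy("idx").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    lagged.withColumn("runningTotalUnder25", sum("V").over(windowSpecFixed)).orderBy("id", "idx").show(100)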
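You can inspect where the buckets fall by emulating the window steps in plain Python for a single id (no Spark). This sketch assumes the rows are already in input order and V is non-negative. Under that assumption the lagged bucket number never decreases, so each `(id, clag)` partition is a contiguous run of rows. `bucketed_running_total` is a hypothetical name.

```python
from itertools import accumulate
from typing import List


def bucketed_running_total(values: List[int], threshold: int = 25) -> List[int]:
    """Emulate running sum -> floor(sum / threshold) -> lag -> per-bucket sum
    for one id. Assumes non-negative values in input order."""
    running = list(accumulate(values))          # sum("V") over windowSpec
    divis = [s // threshold for s in running]   # floor(sum / 25)
    clag = [0] + divis[:-1]                     # lag("divis", 1, 0)
    out: List[int] = []
    total = 0
    for i, v in enumerate(values):
        # a new clag value starts a new (id, clag) partition, so the sum restarts
        if i > 0 and clag[i] != clag[i - 1]:
            total = v
        else:
            total += v
        out.append(total)
    return out


mark = [0, 1, 1, 1, 25, 1, 1, 1, 20, 1, 1, 1, 1, 1]
print(bucketed_running_total(mark))  # [0, 1, 2, 3, 28, 1, 2, 3, 23, 1, 2, 3, 4, 5]
```

For Mark, the 28 overshoots the threshold by 3. The next restart is therefore triggered when the running sum crosses 50, not when the new segment reaches 25 again (at a running sum of 53). As a result, the total restarts two rows earlier than under a strict reset-after-25 rule, which would give 24, 25, 1, 2, 3 for the last five rows. For John's sample values, the first segment sums to exactly 25, and the two approaches agree.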

If you wanted to do this efficiently, I'd probably try to reframe the problem so that I could use a group by, or change the way the data is defined.
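One way to read the group-by reframing is to group the rows by id, order each group by an explicit position column, and make a single sequential pass per group. Here is a plain-Python sketch of that shape (no Spark), run on the question's sample data. `reset_cumsum_by_id`, `Row`, and the `idx` ordering column are hypothetical.

```python
from itertools import groupby
from operator import itemgetter
from typing import Dict, List, Tuple

# (id, idx, V): idx is a hypothetical explicit ordering column
Row = Tuple[str, int, float]


def reset_cumsum_by_id(rows: List[Row], threshold: float = 25) -> Dict[str, List[float]]:
    """Group rows by id, order each group by idx, then make one sequential
    pass per group, restarting at the current value once the previous
    running total has reached `threshold`."""
    result: Dict[str, List[float]] = {}
    # itertools.groupby only merges adjacent rows, so sort by (id, idx) first
    for key, group in groupby(sorted(rows, key=itemgetter(0, 1)), key=itemgetter(0)):
        totals: List[float] = []
        for _, _, v in group:
            if totals and totals[-1] >= threshold:
                totals.append(v)  # previous total reached the threshold: reset
            else:
                totals.append((totals[-1] if totals else 0) + v)
        result[key] = totals
    return result


mark = [0, 1, 1, 1, 25, 1, 1, 1, 20, 1, 1, 1, 1, 1]
john = [0, 1, 1, 1, 1, 1, 1, 1, 9] + [1] * 10 + [7, 1]
data = [("Mark", v) for v in mark] + [("John", v) for v in john]
rows = [(name, idx, v) for idx, (name, v) in enumerate(data)]
print(reset_cumsum_by_id(rows)["John"][-4:])  # [25, 1, 8, 9]
```

In PySpark 3.0 and later, the same per-group pass could presumably be run with `df.groupBy("id").applyInPandas(func, schema)` after adding an ordering column and sorting inside `func`. That wiring is an untested assumption here. `applyInPandas` loads all of a group's rows into memory, so each id must fit in one task's memory.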
