Find a sequence of consecutive rows with the same value in a PySpark dataframe

Time:10-18

I have a dataframe ordered by date. Each row contains a "flag" column whose value is 1 or 0. I want to find sequences of 3 or more consecutive rows where "flag" equals 1. The objective is to reset "flag" to 0 for every row that is not part of such a sequence of at least 3 consecutive 1s.

This is an example of the data:

date flag
01-01-2022 1
02-01-2022 1
03-01-2022 1
04-01-2022 1
05-01-2022 0
06-01-2022 0
07-01-2022 1
08-01-2022 1
09-01-2022 0
10-01-2022 1

We have to keep the value 1 only for the first 4 rows, because they form a sequence of four rows (at least 3) with a 1 in the flag. The desired output should be:

date flag
01-01-2022 1
02-01-2022 1
03-01-2022 1
04-01-2022 1
05-01-2022 0
06-01-2022 0
07-01-2022 0
08-01-2022 0
09-01-2022 0
10-01-2022 0
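To make the rule concrete, here is a minimal plain-Python reference (no Spark involved), assuming the flags are already in date order. `keep_long_runs` is a hypothetical helper name; it splits the list into maximal runs of equal values with `itertools.groupby` and keeps 1s only in runs of length 3 or more.

```python
from itertools import groupby


def keep_long_runs(flags, min_len=3):
    """Return a copy of `flags` where 1s survive only inside runs of at
    least `min_len` consecutive 1s; every other value becomes 0."""
    result = []
    # groupby yields one (value, iterator) pair per maximal run of equal values
    for value, run in groupby(flags):
        run_len = len(list(run))
        keep = value == 1 and run_len >= min_len
        result.extend([1 if keep else 0] * run_len)
    return result


flags = [1, 1, 1, 1, 0, 0, 1, 1, 0, 1]  # the sample data, in date order
print(keep_long_runs(flags))  # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
```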

I thought it might make sense to use the lag function to compare each row with the previous one, but I am not sure how efficient that is in PySpark.

Answer:

You will need several window functions. Count the flags in 3 different row frames relative to the current row: -2:0, -1:1 and 0:2. If the sum in at least one of these frames is 3, the current row is part of 3 consecutive 1's.
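The three frames can be mirrored in plain Python to see why together they cover every position in a run: frame -2:0 fires on the last row of three 1s, -1:1 on the middle row and 0:2 on the first. This is an illustrative sketch with a hypothetical `three_window_flags` helper, not Spark code. As with Spark's row frames, the frames are clipped at the edges of the data, so a clipped frame with fewer than 3 rows can never sum to 3.

```python
def three_window_flags(flags):
    """Plain-Python mirror of the rowsBetween frames (-2, 0), (-1, 1), (0, 2)."""
    n = len(flags)
    out = []
    for i in range(n):
        sums = []
        for start, end in [(-2, 0), (-1, 1), (0, 2)]:
            # Clip the frame to the available rows, like a Spark row frame
            lo = max(0, i + start)
            hi = min(n - 1, i + end)
            sums.append(sum(flags[lo:hi + 1]))
        # Keep the row if any of the three frames contains three 1s
        out.append(int(any(s == 3 for s in sums)))
    return out


print(three_window_flags([1, 1, 1, 1, 0, 0, 1, 1, 0, 1]))
# [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
```

Note that this trick is specific to a run length of 3: a minimum of N would need N frames.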

In the following script I assumed that your dates are stored as strings, so I used the column expression `date` (built with `F.to_date`) to parse them into real dates before ordering. If your column already has a date type, replace that line with `date = F.col('date')`.
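Why parse the strings at all: ordering raw 'dd-MM-yyyy' strings compares the day digits first, which is only chronological while every date falls in the same month, as in the sample. A small plain-Python illustration with made-up dates, using `datetime.strptime` with '%d-%m-%Y' as the counterpart of Spark's 'dd-MM-yyyy' pattern:

```python
from datetime import datetime

dates = ['10-01-2022', '01-02-2022', '02-01-2022']  # hypothetical dates

# Sorting the raw strings compares day digits first
by_string = sorted(dates)
# Parsing with the same day-month-year layout sorts chronologically
by_date = sorted(dates, key=lambda s: datetime.strptime(s, '%d-%m-%Y'))

print(by_string)  # ['01-02-2022', '02-01-2022', '10-01-2022']
print(by_date)    # ['02-01-2022', '10-01-2022', '01-02-2022']
```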

Input:

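from pyspark.sql import SparkSession, functions as F, Window as W

# Reuse the active Spark session, or start one if none exists
spark = SparkSession.builder.getOrCreate()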
df = spark.createDataFrame(
    [('01-01-2022', 1),
     ('02-01-2022', 1),
     ('03-01-2022', 1),
     ('04-01-2022', 1),
     ('05-01-2022', 0),
     ('06-01-2022', 0),
     ('07-01-2022', 1),
     ('08-01-2022', 1),
     ('09-01-2022', 0),
     ('10-01-2022', 1)],
    ['date', 'flag'])

Script:

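# Parse the 'dd-MM-yyyy' strings into real dates so rows sort chronologically
date = F.to_date('date', 'dd-MM-yyyy')
# Frame -2:0 holds three 1s when the row ends a run of three
cond1 = F.sum('flag').over(W.orderBy(date).rowsBetween(-2, 0)) == 3
# Frame -1:1 holds three 1s when the row is in the middle of a run of three
cond2 = F.sum('flag').over(W.orderBy(date).rowsBetween(-1, 1)) == 3
# Frame 0:2 holds three 1s when the row starts a run of three
cond3 = F.sum('flag').over(W.orderBy(date).rowsBetween(0, 2)) == 3
# Set flag to 1 if any frame matched, otherwise 0
df = df.withColumn('flag', (cond1 | cond2 | cond3).cast('long'))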

df.show()
# +----------+----+
# |      date|flag|
# +----------+----+
# |01-01-2022|   1|
# |02-01-2022|   1|
# |03-01-2022|   1|
# |04-01-2022|   1|
# |05-01-2022|   0|
# |06-01-2022|   0|
# |07-01-2022|   0|
# |08-01-2022|   0|
# |09-01-2022|   0|
# |10-01-2022|   0|
# +----------+----+