I have a dataframe ordered by date. Each row contains a "flag" column (values are 1 or 0). I want to find the sequences of 3 or more consecutive rows with the "flag" value equal to 1. The objective is to reset the "flag" value to 0 for every row that is not part of a sequence of 3 or more consecutive rows with "flag" equal to 1.
This is an example of the data:
date | flag |
---|---|
01-01-2022 | 1 |
02-01-2022 | 1 |
03-01-2022 | 1 |
04-01-2022 | 1 |
05-01-2022 | 0 |
06-01-2022 | 0 |
07-01-2022 | 1 |
08-01-2022 | 1 |
09-01-2022 | 0 |
10-01-2022 | 1 |
We have to keep the value 1 only for the first 4 rows, as they form a sequence of four consecutive rows (at least 3) with a 1 in the flag. The desired output should be:
date | flag |
---|---|
01-01-2022 | 1 |
02-01-2022 | 1 |
03-01-2022 | 1 |
04-01-2022 | 1 |
05-01-2022 | 0 |
06-01-2022 | 0 |
07-01-2022 | 0 |
08-01-2022 | 0 |
09-01-2022 | 0 |
10-01-2022 | 0 |
I thought it might make sense to use the lag function to look at the previous element, but I am not sure how efficient that is in PySpark.
CodePudding user response:
You will need to use several window functions, counting flags in 3 different windows given as row offsets relative to the current row: -2:0, -1:1 and 0:2. If the sum over at least one of these is 3, then the current row is part of 3 consecutive 1's.
In the following script I assumed that your dates are stored in the string data type, so I have used a column expression date to parse true dates from the strings. If your dates are not in string format, you should not use that line; order the windows by the date column directly instead.
Input:
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
[('01-01-2022', 1),
('02-01-2022', 1),
('03-01-2022', 1),
('04-01-2022', 1),
('05-01-2022', 0),
('06-01-2022', 0),
('07-01-2022', 1),
('08-01-2022', 1),
('09-01-2022', 0),
('10-01-2022', 1)],
['date', 'flag'])
Script:
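# Parse the string column into real dates so that rows sort chronologically
date = F.to_date('date', 'dd-MM-yyyy')
# Note: a window with no partitionBy moves all rows into a single partition
# Each condition sums the flag over one 3-row frame that contains the current row
cond1 = F.sum('flag').over(W.orderBy(date).rowsBetween(-2, 0)) == 3
cond2 = F.sum('flag').over(W.orderBy(date).rowsBetween(-1, 1)) == 3
cond3 = F.sum('flag').over(W.orderBy(date).rowsBetween(0, 2)) == 3
# Keep the flag only where at least one frame is made up entirely of 1's
df = df.withColumn('flag', (cond1 | cond2 | cond3).cast('long'))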
df.show()
# +----------+----+
# |      date|flag|
# +----------+----+
# |01-01-2022|   1|
# |02-01-2022|   1|
# |03-01-2022|   1|
# |04-01-2022|   1|
# |05-01-2022|   0|
# |06-01-2022|   0|
# |07-01-2022|   0|
# |08-01-2022|   0|
# |09-01-2022|   0|
# |10-01-2022|   0|
# +----------+----+
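To see why three frames are enough, the same check can be reproduced in plain Python without Spark. This is only an illustrative sketch, not part of the original answer: the function name keep_runs and its parameter n are hypothetical, and it assumes the flags are strictly 0 or 1 and already sorted by date. A row keeps its 1 if any frame of n consecutive rows that contains it sums to n; for n = 3 those frames are exactly -2:0, -1:1 and 0:2.

```python
def keep_runs(flags, n=3):
    """Keep a 1 only where it lies inside a run of at least n consecutive 1s.

    Hypothetical helper for illustration; assumes flags contains only 0 and 1.
    """
    size = len(flags)
    out = []
    for i in range(size):
        hit = False
        # Try every frame of n rows that contains row i.
        # For n=3 these are rows (i-2..i), (i-1..i+1) and (i..i+2).
        for start in range(i - n + 1, i + 1):
            end = start + n  # exclusive upper bound
            # A frame cut off by the edge of the data has fewer than n rows,
            # so its sum could never reach n; skip it.
            if start < 0 or end > size:
                continue
            if sum(flags[start:end]) == n:
                hit = True
                break
        out.append(1 if hit else 0)
    return out


flags = [1, 1, 1, 1, 0, 0, 1, 1, 0, 1]
result = keep_runs(flags)
print(result)  # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
```

The sketch also shows how the idea would scale: requiring n consecutive rows needs n frames, so for larger n the PySpark version would need one window condition per offset.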