I want to remove all rows that fall within x months of a given row (before or after, based on date) whenever that row's target equals 1.
E.g. given this PySpark df:
id | date | target |
---|---|---|
a | "2020-01-01" | 0 |
a | "2020-02-01" | 0 |
a | "2020-03-01" | 0 |
a | "2020-04-01" | 1 |
a | "2020-05-01" | 0 |
a | "2020-06-01" | 0 |
a | "2020-07-01" | 0 |
a | "2020-08-01" | 0 |
a | "2020-09-01" | 0 |
a | "2020-10-01" | 1 |
a | "2020-11-01" | 0 |
b | "2020-01-01" | 0 |
b | "2020-02-01" | 0 |
b | "2020-03-01" | 0 |
b | "2020-05-01" | 1 |
(Note that April does not exist for id b.)
If using an x value of 2, the resulting df would be:
id | date | target |
---|---|---|
a | "2020-01-01" | 0 |
a | "2020-04-01" | 1 |
a | "2020-07-01" | 0 |
a | "2020-10-01" | 1 |
b | "2020-01-01" | 0 |
b | "2020-02-01" | 0 |
b | "2020-05-01" | 1 |
I am able to remove the xth row before and after the row of interest using the code below, but I want to remove all rows between the current row and x months away, in both directions, based on date.
from pyspark.sql import Window
from pyspark.sql.functions import col, lag, lead
window = 2
windowSpec = Window.partitionBy("id").orderBy(['id','date'])
df= df.withColumn("lagvalue", lag('target', window).over(windowSpec))
df= df.withColumn("leadvalue", lead('target', window).over(windowSpec))
# Each comparison needs its own parentheses because & binds tighter than ==
df= df.where((col("lagvalue") == 0) & (col("leadvalue") == 0))
Answer:
In your case, rangeBetween can be very useful. It frames the window by the values of the ordering column rather than by row position, so only rows whose value falls into the range are included. E.g. rangeBetween(-2, 2) would take all rows whose ordering value lies from 2 below to 2 above the current row's value. As rangeBetween does not work with dates (or strings), I translated them into a numeric month count using months_between.
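To make the difference between a row-based and a value-based frame concrete, here is a minimal plain-Python sketch (no Spark involved) on the dates of id b, where April is missing. The helper month_index is a hypothetical stand-in for months_between against 1970-01-01 and is only assumed to match it for first-of-month dates like the ones in the question.

```python
from datetime import date

def month_index(d: date) -> int:
    # Hypothetical helper: whole months since 1970-01-01.
    # Assumed to agree with months_between only for first-of-month dates.
    return (d.year - 1970) * 12 + (d.month - 1)

# Dates for id "b" from the question; April is missing.
dates = [date(2020, 1, 1), date(2020, 2, 1), date(2020, 3, 1), date(2020, 5, 1)]
current = len(dates) - 1  # position of the 2020-05-01 row
x = 2

# Row-based frame: x rows before and after, by position.
rows_frame = dates[max(0, current - x): current + x + 1]

# Value-based frame: every row whose month index is within x of the current one.
centre = month_index(dates[current])
range_frame = [d for d in dates if abs(month_index(d) - centre) <= x]

print([d.isoformat() for d in rows_frame])
print([d.isoformat() for d in range_frame])
```

The row-based frame reaches back to 2020-02-01, which is three months away, while the value-based frame stops at 2020-03-01. That is why the gap in April does not distort the result when the frame is defined on the month count.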
from pyspark.sql import functions as F, Window
df = spark.createDataFrame(
[('a', '2020-01-01', 0),
('a', '2020-02-01', 0),
('a', '2020-03-01', 0),
('a', '2020-04-01', 1),
('a', '2020-05-01', 0),
('a', '2020-06-01', 0),
('a', '2020-07-01', 0),
('a', '2020-08-01', 0),
('a', '2020-09-01', 0),
('a', '2020-10-01', 1),
('a', '2020-11-01', 0),
('b', '2020-01-01', 0),
('b', '2020-02-01', 0),
('b', '2020-03-01', 0),
('b', '2020-05-01', 1)],
['id', 'date', 'target']
)
window = 2
windowSpec = Window.partitionBy('id').orderBy(F.months_between('date', F.lit('1970-01-01'))).rangeBetween(-window, window)
df = df.withColumn('to_remove', F.sum('target').over(windowSpec) - F.col('target'))
df = df.where(F.col('to_remove') == 0).drop('to_remove')
df.show()
# +---+----------+------+
# | id|      date|target|
# +---+----------+------+
# |  a|2020-01-01|     0|
# |  a|2020-04-01|     1|
# |  a|2020-07-01|     0|
# |  a|2020-10-01|     1|
# |  b|2020-01-01|     0|
# |  b|2020-02-01|     0|
# |  b|2020-05-01|     1|
# +---+----------+------+
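The to_remove rule (sum of target over the frame, minus the row's own target) can be cross-checked without Spark. The following is only a rough plain-Python sketch of that rule, not the Spark implementation: the names month_index and drop_near_targets are hypothetical, and the month arithmetic assumes first-of-month dates as in the sample data.

```python
from datetime import date

def month_index(iso: str) -> int:
    # Hypothetical helper: whole months since 1970-01 (first-of-month dates assumed).
    d = date.fromisoformat(iso)
    return (d.year - 1970) * 12 + (d.month - 1)

def drop_near_targets(rows, x):
    """Keep rows that have no *other* target == 1 row of the same id within x months."""
    kept = []
    for rid, iso, target in rows:
        # Sum of target over same-id rows within x months, current row included.
        frame_sum = sum(
            t for i, d, t in rows
            if i == rid and abs(month_index(d) - month_index(iso)) <= x
        )
        # Subtracting the row's own target lets a target row survive
        # unless another target row sits inside its frame.
        if frame_sum - target == 0:
            kept.append((rid, iso, target))
    return kept

a_months = ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']
rows = [('a', f'2020-{m}-01', 1 if m in ('04', '10') else 0) for m in a_months]
rows += [('b', '2020-01-01', 0), ('b', '2020-02-01', 0),
         ('b', '2020-03-01', 0), ('b', '2020-05-01', 1)]

result = drop_near_targets(rows, 2)
for row in result:
    print(row)
```

On the sample data this keeps the same seven rows as the expected result in the question. Subtracting the row's own target is what allows the rows with target 1 to stay while their neighbours are dropped.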