PySpark window function - within n months from current row

Time:05-08

For every row whose target equals 1, I want to remove all rows of the same id that fall within x months of it, both before and after, based on the date.

E.g. given this PySpark df:

id date target
a "2020-01-01" 0
a "2020-02-01" 0
a "2020-03-01" 0
a "2020-04-01" 1
a "2020-05-01" 0
a "2020-06-01" 0
a "2020-07-01" 0
a "2020-08-01" 0
a "2020-09-01" 0
a "2020-10-01" 1
a "2020-11-01" 0
b "2020-01-01" 0
b "2020-02-01" 0
b "2020-03-01" 0
b "2020-05-01" 1

(Note that there is no April row for id b.)

With x = 2, the resulting df would be:

id date target
a "2020-01-01" 0
a "2020-04-01" 1
a "2020-07-01" 0
a "2020-10-01" 1
b "2020-01-01" 0
b "2020-02-01" 0
b "2020-05-01" 1

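Using the code below, I can remove the row exactly x rows before and after the row of interest. What I want instead is to remove every row between the current row and x months on either side, based on the date rather than on row position.

from pyspark.sql import Window
from pyspark.sql.functions import col, lag, lead

window = 2
windowSpec = Window.partitionBy("id").orderBy("date")

# target of the row exactly `window` rows before / after (0 if there is none)
df = df.withColumn("lagvalue", lag("target", window, 0).over(windowSpec))
df = df.withColumn("leadvalue", lead("target", window, 0).over(windowSpec))
# parentheses are required: & binds tighter than == in Python
df = df.where((col("lagvalue") == 0) & (col("leadvalue") == 0))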
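The gap in id b shows why a row offset is not enough. The sketch below is plain Python, not Spark. It compares the row that `lag(..., 2)` would look at for the May row with the rows that actually fall within 2 months of it. The helper names `month_index`, `row_offset_neighbour` and `month_offset_rows` are hypothetical and exist only for illustration.

```python
from datetime import date

# Rows for id "b" from the example; note there is no April row.
rows_b = [
    (date(2020, 1, 1), 0),
    (date(2020, 2, 1), 0),
    (date(2020, 3, 1), 0),
    (date(2020, 5, 1), 1),
]

def month_index(d: date) -> int:
    """Whole-month position of a date (year * 12 + month)."""
    return d.year * 12 + d.month

def row_offset_neighbour(rows, i, n):
    """Row exactly n positions before row i, like lag(col, n); None if absent."""
    return rows[i - n] if i - n >= 0 else None

def month_offset_rows(rows, i, n):
    """All other rows whose date is within n months of row i's date."""
    centre = month_index(rows[i][0])
    return [r for j, r in enumerate(rows)
            if j != i and abs(month_index(r[0]) - centre) <= n]

may = 3  # position of the 2020-05-01 row in rows_b
print(row_offset_neighbour(rows_b, may, 2))  # the February row
print(month_offset_rows(rows_b, may, 2))     # only the March row
```

The row-based offset lands on February, but the requirement concerns March. The lag/lead filter also inspects just one row on each side rather than every row in between.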

Answer:

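In your case, rangeBetween is very useful. Unlike rowsBetween, which counts rows, it defines the window frame by the value of the orderBy expression. For example, rangeBetween(-2, 2) takes every row whose ordering value lies between 2 below and 2 above the current row's value, so missing months are handled correctly. Its offsets are plain numbers and the date column here is a string, so to measure the range in months I ordered by months_between(date, '1970-01-01'), which turns each date into a number of months. Summing target over that frame and subtracting the current row's own target gives, for each row, how many other rows within ±2 months have target 1. Only rows where this is 0 are kept.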
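To make the frame logic concrete, here is a plain-Python emulation of what the answer computes. It is not Spark itself, and the function names `month_index` and `rows_to_keep` are hypothetical. Each date is mapped to a whole-month number that stands in for `months_between(date, '1970-01-01')`. Only differences matter for a range frame, and for first-of-month dates the differences are identical. For each row, the code then sums `target` over the rows of the same id within ±x months and subtracts the row's own target, mirroring `F.sum('target').over(windowSpec) - F.col('target')`.

```python
from collections import defaultdict
from datetime import date

def month_index(d: date) -> int:
    # For first-of-month dates, differences of this value equal months_between.
    return d.year * 12 + d.month

def rows_to_keep(rows, x):
    """rows: list of (id, 'YYYY-MM-DD', target). Keep a row only if no *other*
    row of the same id within x months (inclusive) has target == 1."""
    by_id = defaultdict(list)  # partitionBy('id')
    for r in rows:
        by_id[r[0]].append(r)
    kept = []
    for r in rows:
        centre = month_index(date.fromisoformat(r[1]))
        # Sum of target over the range frame [centre - x, centre + x]...
        frame_sum = sum(
            t for (_, d, t) in by_id[r[0]]
            if abs(month_index(date.fromisoformat(d)) - centre) <= x
        )
        # ...minus the current row's own target, like the 'to_remove' column.
        if frame_sum - r[2] == 0:
            kept.append(r)
    return kept

data = [
    ('a', '2020-01-01', 0), ('a', '2020-02-01', 0), ('a', '2020-03-01', 0),
    ('a', '2020-04-01', 1), ('a', '2020-05-01', 0), ('a', '2020-06-01', 0),
    ('a', '2020-07-01', 0), ('a', '2020-08-01', 0), ('a', '2020-09-01', 0),
    ('a', '2020-10-01', 1), ('a', '2020-11-01', 0),
    ('b', '2020-01-01', 0), ('b', '2020-02-01', 0), ('b', '2020-03-01', 0),
    ('b', '2020-05-01', 1),
]
for row in rows_to_keep(data, 2):
    print(row)
```

Subtracting only the row's own target has one consequence: if two target-1 rows are within x months of each other, each sees the other in its frame, so both are removed.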

from pyspark.sql import SparkSession, functions as F, Window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [('a', '2020-01-01', 0),
     ('a', '2020-02-01', 0),
     ('a', '2020-03-01', 0),
     ('a', '2020-04-01', 1),
     ('a', '2020-05-01', 0),
     ('a', '2020-06-01', 0),
     ('a', '2020-07-01', 0),
     ('a', '2020-08-01', 0),
     ('a', '2020-09-01', 0),
     ('a', '2020-10-01', 1),
     ('a', '2020-11-01', 0),
     ('b', '2020-01-01', 0),
     ('b', '2020-02-01', 0),
     ('b', '2020-03-01', 0),
     ('b', '2020-05-01', 1)],
    ['id', 'date', 'target']
)
window = 2
# Order by months since 1970-01-01 so that rangeBetween's offsets are months, not rows.
months = F.months_between('date', F.lit('1970-01-01'))
windowSpec = Window.partitionBy('id').orderBy(months).rangeBetween(-window, window)
# Sum of target over the *other* rows within +/- `window` months; > 0 means a target == 1 row is near.
df = df.withColumn('to_remove', F.sum('target').over(windowSpec) - F.col('target'))
df = df.where(F.col('to_remove') == 0).drop('to_remove')
df.show()
# +---+----------+------+
# | id|      date|target|
# +---+----------+------+
# |  a|2020-01-01|     0|
# |  a|2020-04-01|     1|
# |  a|2020-07-01|     0|
# |  a|2020-10-01|     1|
# |  b|2020-01-01|     0|
# |  b|2020-02-01|     0|
# |  b|2020-05-01|     1|
# +---+----------+------+