Pyspark filter in Window function based on current rows value


I want to find the average of a column (hit) partitioned by id, but filtering out rows based on the current row's value; in this case, for each partition, I want to exclude every row whose date is greater than or equal to the current row's date.

data = [
("456",'0','14-02-2022'),
("456",'1','13-02-2022'),
("123",'0','12-02-2022'),
("456",'0','11-02-2022'),
("123",'1','15-02-2022')]
    
cols = ["id","hit","date"]
    
df = spark.createDataFrame(data = data, schema = cols)

Expected outcome:

+---+---+----------+---------+
| id|hit|      date| avg(hit)|
+---+---+----------+---------+
|456|  0|14-02-2022| 1/2       (both dates of id 456 are lower than this date)
|456|  1|13-02-2022| 0         (the other date of id 456 is lower, so avg is 0)
|123|  0|12-02-2022| None      (there is no record with a lower date for id 123)
|456|  0|11-02-2022| None      (there is no record with a lower date for id 456)
|123|  1|15-02-2022| 0         (the other date of id 123 is lower, so avg is 0)
+---+---+----------+---------+

How can I achieve that?

CodePudding user response:

An avg window function whose frame runs from unbounded preceding to 1 row before the current row will give you the required outcome.

import sys

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# convert date field to date type
data_sdf = spark.sparkContext.parallelize(data).toDF(['id', 'hit', 'dt']). \
    withColumn('dt', func.to_date('dt', 'dd-MM-yyyy'))

# +---+---+----------+
# | id|hit|        dt|
# +---+---+----------+
# |456|  0|2022-02-14|
# |456|  1|2022-02-13|
# |123|  0|2022-02-12|
# |456|  0|2022-02-11|
# |123|  1|2022-02-15|
# +---+---+----------+

data_sdf. \
    withColumn('hit_avg',
               func.avg('hit').over(wd.partitionBy('id').orderBy('dt').rowsBetween(-sys.maxsize, -1))
               ). \
    show()

# +---+---+----------+-------+
# | id|hit|        dt|hit_avg|
# +---+---+----------+-------+
# |456|  0|2022-02-11|   null|
# |456|  1|2022-02-13|    0.0|
# |456|  0|2022-02-14|    0.5|
# |123|  0|2022-02-12|   null|
# |123|  1|2022-02-15|    0.0|
# +---+---+----------+-------+
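As a side note, `-sys.maxsize` is just a stand-in for an unbounded lower frame boundary; if you prefer the named constant, the same frame can be written with `Window.unboundedPreceding`. A minimal equivalent sketch, reusing the `data_sdf` built above:

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# same frame as above: all preceding rows up to (but excluding) the current row
w_prev = wd.partitionBy('id').orderBy('dt').rowsBetween(wd.unboundedPreceding, -1)

data_sdf.withColumn('hit_avg', func.avg('hit').over(w_prev)).show()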

CodePudding user response:

Here's another way (using lag to get the shifted average you require):

from pyspark.sql.window import Window
from pyspark.sql import functions as F

w = Window.partitionBy("id").orderBy(F.col("date").asc())
df = df.withColumn("avg", F.lag(
    F.avg("hit").over(w)
).over(w))
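
Note that in this snippet `date` is still a `dd-MM-yyyy` string, so `orderBy` sorts it lexically; that only happens to give the right order here because all rows fall in the same month. A small sketch (assuming the original `df` with string dates) that converts the column inside the window ordering so it stays chronologically correct in general:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# order the window by a real date rather than the raw dd-MM-yyyy string
w = Window.partitionBy("id").orderBy(F.to_date("date", "dd-MM-yyyy").asc())

# the lag of the cumulative average gives the average over all earlier rows
df = df.withColumn("avg", F.lag(F.avg("hit").over(w)).over(w))
df.show()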