I want to find the average of a column (hit) partitioned by id, but filtering out rows based on the current row's value: in this case, for each row, I want to average only the rows in its partition whose date is strictly earlier than the current row's date.
data = [
("456",'0','14-02-2022'),
("456",'1','13-02-2022'),
("123",'0','12-02-2022'),
("456",'0','11-02-2022'),
("123",'1','15-02-2022')]
cols = ["id","hit","date"]
df = spark.createDataFrame(data = data, schema = cols)
Expected outcome:
+---+---+----------+--------+
| id|hit|      date|avg(hit)|
+---+---+----------+--------+
|456|  0|14-02-2022|     0.5|  (both other rows of id 456 have earlier dates: avg of 1 and 0)
|456|  1|13-02-2022|       0|  (the one earlier row of id 456 has hit 0, so avg is 0)
|123|  0|12-02-2022|    None|  (no row of id 123 has an earlier date)
|456|  0|11-02-2022|    None|  (no row of id 456 has an earlier date)
|123|  1|15-02-2022|       0|  (the one earlier row of id 123 has hit 0, so avg is 0)
+---+---+----------+--------+
How can I achieve that?
CodePudding user response:
An avg window function with a frame that runs from unbounded preceding to 1 row before the current row will give you the required outcome.
import sys

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# convert the date field (a dd-MM-yyyy string) to a proper date type
data_sdf = spark.sparkContext.parallelize(data).toDF(['id', 'hit', 'dt']). \
    withColumn('dt', func.to_date('dt', 'dd-MM-yyyy'))
# +---+---+----------+
# | id|hit|        dt|
# +---+---+----------+
# |456|  0|2022-02-14|
# |456|  1|2022-02-13|
# |123|  0|2022-02-12|
# |456|  0|2022-02-11|
# |123|  1|2022-02-15|
# +---+---+----------+
# average of all rows strictly before the current row (by date) within each id
data_sdf. \
    withColumn('hit_avg',
               func.avg('hit').over(wd.partitionBy('id').orderBy('dt').rowsBetween(-sys.maxsize, -1))
               ). \
    show()
# +---+---+----------+-------+
# | id|hit|        dt|hit_avg|
# +---+---+----------+-------+
# |456|  0|2022-02-11|   null|
# |456|  1|2022-02-13|    0.0|
# |456|  0|2022-02-14|    0.5|
# |123|  0|2022-02-12|   null|
# |123|  1|2022-02-15|    0.0|
# +---+---+----------+-------+
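For reference, the same frame can be written in Spark SQL with ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING. This is a sketch; the temp view name data_sdf_v is made up for illustration:

data_sdf.createOrReplaceTempView('data_sdf_v')

spark.sql("""
    SELECT id, hit, dt,
           AVG(hit) OVER (
               PARTITION BY id
               ORDER BY dt
               ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING
           ) AS hit_avg
    FROM data_sdf_v
""").show()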
CodePudding user response:
Here's another way (using lag to get the shifted average you require):
from pyspark.sql.window import Window
from pyspark.sql import functions as F

w = Window.partitionBy("id").orderBy(F.col("date").asc())

# running average up to and including each row, shifted back one row with lag,
# so each row gets the average of all strictly earlier rows in its partition
df = df.withColumn("avg", F.lag(F.avg("hit").over(w)).over(w))
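Note that this orders by the raw date strings; dd-MM-yyyy values only happen to sort chronologically here because all rows share the same month and year. A safer variant of the same idea (a sketch; df_dt is a made-up name) converts the column to a real date before ordering:

# convert the string column to a date so the ordering is chronological in general
df_dt = df.withColumn("date", F.to_date("date", "dd-MM-yyyy"))

w = Window.partitionBy("id").orderBy(F.col("date").asc())
df_dt.withColumn("avg", F.lag(F.avg("hit").over(w)).over(w)).show()

# with the sample data, this should reproduce the expected outcome (row order may vary):
# +---+---+----------+----+
# | id|hit|      date| avg|
# +---+---+----------+----+
# |123|  0|2022-02-12|null|
# |123|  1|2022-02-15| 0.0|
# |456|  0|2022-02-11|null|
# |456|  1|2022-02-13| 0.0|
# |456|  0|2022-02-14| 0.5|
# +---+---+----------+----+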