I have a table like the one below. I want to calculate the average of median, but only over the rows with Q=2 and Q=3. I don't want the other Qs included in the average, but I still want to keep their rows.
df = spark.createDataFrame([('2018-03-31',6,1),('2018-03-31',27,2),('2018-03-31',3,3),('2018-03-31',44,4),('2018-06-30',6,1),('2018-06-30',4,3),('2018-06-30',32,2),('2018-06-30',112,4),('2018-09-30',2,1),('2018-09-30',23,4),('2018-09-30',37,3),('2018-09-30',3,2)],['date','median','Q'])
+----------+------+---+
|      date|median|  Q|
+----------+------+---+
|2018-03-31|     6|  1|
|2018-03-31|    27|  2|
|2018-03-31|     3|  3|
|2018-03-31|    44|  4|
|2018-06-30|     6|  1|
|2018-06-30|     4|  3|
|2018-06-30|    32|  2|
|2018-06-30|   112|  4|
|2018-09-30|     2|  1|
|2018-09-30|    23|  4|
|2018-09-30|    37|  3|
|2018-09-30|     3|  2|
+----------+------+---+
Expected output:
+----------+------+---+------+
|      date|median|  Q|result|
+----------+------+---+------+
|2018-03-31|     6|  1|  null|
|2018-03-31|    27|  2|    15|
|2018-03-31|     3|  3|    15|
|2018-03-31|    44|  4|  null|
|2018-06-30|     6|  1|  null|
|2018-06-30|     4|  3|    18|
|2018-06-30|    32|  2|    18|
|2018-06-30|   112|  4|  null|
|2018-09-30|     2|  1|  null|
|2018-09-30|    23|  4|  null|
|2018-09-30|    37|  3|    20|
|2018-09-30|     3|  2|    20|
+----------+------+---+------+
OR
+----------+------+---+------+
|      date|median|  Q|result|
+----------+------+---+------+
|2018-03-31|     6|  1|    15|
|2018-03-31|    27|  2|    15|
|2018-03-31|     3|  3|    15|
|2018-03-31|    44|  4|    15|
|2018-06-30|     6|  1|    18|
|2018-06-30|     4|  3|    18|
|2018-06-30|    32|  2|    18|
|2018-06-30|   112|  4|    18|
|2018-09-30|     2|  1|    20|
|2018-09-30|    23|  4|    20|
|2018-09-30|    37|  3|    20|
|2018-09-30|     3|  2|    20|
+----------+------+---+------+
I tried the following code, but when I include the where statement it drops the Q=1 and Q=4 rows.
from pyspark.sql import Window
from pyspark.sql import functions as F

window = (
    Window
    .partitionBy("date")
    .orderBy("date")
)
df_avg = (
    df
    .where(
        (F.col("Q") == 2) |
        (F.col("Q") == 3)
    )
    .withColumn("result", F.avg("median").over(window))
)
Answer:
For both of your expected outputs you can use conditional aggregation: combine avg with when, leaving out otherwise so that rows outside Q=2 and Q=3 get null (and avg ignores nulls).
If you want the first expected output:
window = (
    Window
    .partitionBy("date", F.col("Q").isin([2, 3]))
)
df_avg = (
    df.withColumn(
        "result",
        F.when(F.col("Q").isin([2, 3]), F.avg("median").over(window)),
    )
)
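To see why this works, here is a plain-Python sketch (no Spark needed) that mimics the expression row by row: medians are grouped by the window key (date, Q in {2, 3}), and rows outside Q=2/Q=3 get None, standing in for the null that when() without otherwise() produces. The names rows and first_output are hypothetical; this models the semantics, not how Spark executes it.

```python
from collections import defaultdict

# The question's sample data as (date, median, Q) tuples.
rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]


def first_output(rows):
    """Model of when(Q in [2, 3], avg(median) over partitionBy(date, Q in [2, 3]))."""
    # Collect medians per window partition: the key is (date, is_q2_or_q3).
    partitions = defaultdict(list)
    for date, median, q in rows:
        partitions[(date, q in (2, 3))].append(median)

    result = []
    for date, median, q in rows:
        if q in (2, 3):
            # Average over this row's partition: the Q=2 and Q=3 rows of the same date.
            values = partitions[(date, True)]
            value = sum(values) / len(values)
        else:
            # when() without otherwise() gives null (None here) for Q=1 and Q=4.
            value = None
        result.append((date, median, q, value))
    return result


for row in first_output(rows):
    print(row)
```

Every input row survives; only the result column is null outside Q=2 and Q=3, which is why this avoids the row loss caused by filtering with where first.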
For the second expected output:
window = (
    Window
    .partitionBy("date")
)
df_avg = (
    df.withColumn(
        "result",
        F.avg(F.when(F.col("Q").isin([2, 3]), F.col("median"))).over(window),
    )
)
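The second version relies on avg skipping nulls. A plain-Python sketch of that behaviour, assuming SQL AVG semantics (nulls ignored, all-null input gives null); sql_avg and second_output are hypothetical helper names, not Spark APIs:

```python
from collections import defaultdict

# The question's sample data as (date, median, Q) tuples.
rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]


def sql_avg(values):
    # Like SQL AVG: None (null) entries are ignored; if nothing is left, the result is None.
    present = [v for v in values if v is not None]
    return sum(present) / len(present) if present else None


def second_output(rows):
    """Model of avg(when(Q in [2, 3], median)) over partitionBy(date)."""
    conditional = defaultdict(list)
    for date, median, q in rows:
        # when() without otherwise(): the median for Q=2/Q=3, None for Q=1/Q=4.
        conditional[date].append(median if q in (2, 3) else None)
    # Every row of a date gets the same average, so no rows are dropped.
    return [(date, median, q, sql_avg(conditional[date])) for date, median, q in rows]


for row in second_output(rows):
    print(row)
```

Because the nulls are simply skipped, the window can stay partitioned by date alone, and Q=1 and Q=4 rows still receive their date's average.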
Answer:
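Alternatively, since you are really aggregating a (possibly small) subset, you can replace the window with a self-join: average the Q=2/Q=3 rows per date, then left-join the result back onto the original rows. Group by date only; grouping by date and Q would just return each row's own median.
>>> from pyspark.sql.functions import avg, col, when
>>> df_avg = df.where(col("Q").isin([2, 3])).groupBy("date").agg(avg("median").alias("result"))
>>> df_result = df.join(df_avg, ["date"], "left")
This gives the second expected output. For the first one, null out the result outside Q=2 and Q=3:
>>> df_result = df_result.withColumn("result", when(col("Q").isin([2, 3]), col("result")))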
This might turn out to be faster than using a window.
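A plain-Python sketch of the aggregate-then-left-join idea, mainly to show why the grouping key matters: grouping by date alone yields one average per date, while grouping by both date and Q puts a single row in each group, so the "average" is just that row's own median. The helper names avg_by and left_join_on_date are hypothetical and only model what the DataFrame operations compute.

```python
from collections import defaultdict

# The question's sample data as (date, median, Q) tuples.
rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]


def avg_by(rows, key):
    # Like where(Q in [2, 3]).groupBy(...).agg(avg("median")): one average per group key.
    groups = defaultdict(list)
    for date, median, q in rows:
        if q in (2, 3):
            groups[key(date, q)].append(median)
    return {k: sum(v) / len(v) for k, v in groups.items()}


def left_join_on_date(rows):
    # Left join on date: every original row is kept; a date with no Q=2/Q=3 rows would get None.
    per_date = avg_by(rows, lambda date, q: date)
    return [(date, median, q, per_date.get(date)) for date, median, q in rows]


# Grouping by date alone: the average of the Q=2 and Q=3 medians for each date.
print(avg_by(rows, lambda date, q: date))
# Grouping by (date, Q): each group holds one row, so the "average" is just that median.
print(avg_by(rows, lambda date, q: (date, q)))
```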