window function on a subset of data

Time:06-28

I have a table like the one below. I want to calculate the average of `median`, but only over the rows with Q=2 and Q=3. I don't want the other Q values to enter the average, but I still want to preserve those rows in the output.

df = spark.createDataFrame(
    [
        ('2018-03-31', 6, 1), ('2018-03-31', 27, 2), ('2018-03-31', 3, 3), ('2018-03-31', 44, 4),
        ('2018-06-30', 6, 1), ('2018-06-30', 4, 3), ('2018-06-30', 32, 2), ('2018-06-30', 112, 4),
        ('2018-09-30', 2, 1), ('2018-09-30', 23, 4), ('2018-09-30', 37, 3), ('2018-09-30', 3, 2),
    ],
    ['date', 'median', 'Q'],
)
+----------+--------+---+
|      date| median | Q |
+----------+--------+---+
|2018-03-31|       6|  1|
|2018-03-31|      27|  2|
|2018-03-31|       3|  3|
|2018-03-31|      44|  4|
|2018-06-30|       6|  1|
|2018-06-30|       4|  3|
|2018-06-30|      32|  2|
|2018-06-30|     112|  4|
|2018-09-30|       2|  1|
|2018-09-30|      23|  4|
|2018-09-30|      37|  3|
|2018-09-30|       3|  2|
+----------+--------+---+

Expected output:

+----------+--------+---+------------+
|      date| median | Q |result      |
+----------+--------+---+------------+
|2018-03-31|       6|  1|        null|
|2018-03-31|      27|  2|          15|
|2018-03-31|       3|  3|          15|
|2018-03-31|      44|  4|        null|
|2018-06-30|       6|  1|        null|
|2018-06-30|       4|  3|          18|
|2018-06-30|      32|  2|          18|
|2018-06-30|     112|  4|        null|
|2018-09-30|       2|  1|        null|
|2018-09-30|      23|  4|        null|
|2018-09-30|      37|  3|          20|
|2018-09-30|       3|  2|          20|
+----------+--------+---+------------+

OR

+----------+--------+---+------------+
|      date| median | Q |result      |
+----------+--------+---+------------+
|2018-03-31|       6|  1|          15|
|2018-03-31|      27|  2|          15|
|2018-03-31|       3|  3|          15|
|2018-03-31|      44|  4|          15|
|2018-06-30|       6|  1|          18|
|2018-06-30|       4|  3|          18|
|2018-06-30|      32|  2|          18|
|2018-06-30|     112|  4|          18|
|2018-09-30|       2|  1|          20|
|2018-09-30|      23|  4|          20|
|2018-09-30|      37|  3|          20|
|2018-09-30|       3|  2|          20|
+----------+--------+---+------------+

I tried the following code, but the where clause filters out the Q=1 and Q=4 rows entirely:

from pyspark.sql import Window, functions as F

window = (
    Window
    .partitionBy("date")
    .orderBy("date")
)

df_avg = (
    df
    .where(
        (F.col("Q") == 2) |
        (F.col("Q") == 3)
    )
    .withColumn("result", F.avg("median").over(window))
)
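To see why the rows disappear: the where() runs before the window is applied, so the window never sees the Q=1 and Q=4 rows at all. A plain-Python sketch (not Spark) of what the attempted pipeline does:

```python
from collections import defaultdict

rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]

# the .where(...) step runs first and discards half the rows
kept = [r for r in rows if r[2] in (2, 3)]
print(len(rows), "->", len(kept))  # 12 -> 6

# the window average then only ever sees the surviving rows
by_date = defaultdict(list)
for date, median, q in kept:
    by_date[date].append(median)
result = [(d, m, q, sum(by_date[d]) / len(by_date[d])) for d, m, q in kept]
```

The averages (15, 18, 20) come out right, but the Q=1 and Q=4 rows are gone, which is exactly the reported problem.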

CodePudding user response:

For both of your expected outputs you can use conditional aggregation: combine avg with when (whose implicit otherwise yields null).

For the 1st expected output:

window = (
    Window
    .partitionBy("date", F.col("Q").isin([2, 3]))
)

df_avg = (
    df.withColumn(
        "result",
        F.when(F.col("Q").isin([2, 3]), F.avg("median").over(window))
    )
)
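The boolean second partition key splits each date into a Q-in-{2,3} group and a rest group; the when() then blanks the rest group to null. A plain-Python mirror (not Spark) of that logic, producing the 1st expected output:

```python
from collections import defaultdict

rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]

# partition by (date, Q in {2, 3}) -- two groups per date
parts = defaultdict(list)
for date, median, q in rows:
    parts[(date, q in (2, 3))].append(median)

# Q=2/3 rows get the average of their (date, True) partition; others get None
result = [
    (date, median, q,
     sum(parts[(date, True)]) / len(parts[(date, True)]) if q in (2, 3) else None)
    for date, median, q in rows
]
```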

For the 2nd expected output:

window = (
    Window
    .partitionBy("date")
)

df_avg = (
    df.withColumn(
        "result",
        F.avg(F.when(F.col("Q").isin([2, 3]), F.col("median"))).over(window)
    )
)
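Here the when() sits inside the avg: non-matching rows become null, and avg ignores nulls, so only the Q=2/3 medians enter the mean while every row in the date partition receives the result. A plain-Python mirror (not Spark) of that null-skipping behaviour:

```python
from collections import defaultdict

rows = [
    ("2018-03-31", 6, 1), ("2018-03-31", 27, 2), ("2018-03-31", 3, 3), ("2018-03-31", 44, 4),
    ("2018-06-30", 6, 1), ("2018-06-30", 4, 3), ("2018-06-30", 32, 2), ("2018-06-30", 112, 4),
    ("2018-09-30", 2, 1), ("2018-09-30", 23, 4), ("2018-09-30", 37, 3), ("2018-09-30", 3, 2),
]

# the when(...) step: map non-matching rows to None within each date partition
masked = defaultdict(list)
for date, median, q in rows:
    masked[date].append(median if q in (2, 3) else None)

# avg skips None, mirroring how Spark's avg ignores nulls
avgs = {
    d: sum(v for v in vals if v is not None) / sum(v is not None for v in vals)
    for d, vals in masked.items()
}

# every row gets its date's average
result = [(date, median, q, avgs[date]) for date, median, q in rows]
```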

CodePudding user response:

Alternatively, since you are really aggregating a (small?) subset, you can replace the window with an aggregate-and-join. Note that you must group by "date" alone: grouping by ("date", "Q") would put every row in its own group and simply hand each row its own median back.

>>> from pyspark.sql.functions import avg, col, when
>>> df_avg = df.where(col("Q").isin([2, 3])).groupBy("date").agg(avg("median").alias("result"))
>>> df_result = df.join(df_avg, ["date"], "left")

This gives the 2nd expected output; masking afterwards with when(col("Q").isin([2, 3]), col("result")) gives the 1st. It might turn out to be faster than using a window.
