How to use PySpark to efficiently keep only those groups from a dataframe that satisfy a certain group-dependent filter


Let us use the following dummy data:

df = spark.createDataFrame([(1,2),(1,3),(1,40),(1,0),(2,3),(2,1),(2,4),(3,2),(3,4)],['a','b'])
df.show()
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
|  2|  3|
|  2|  1|
|  2|  4|
|  3|  2|
|  3|  4|
+---+---+
  1. How to filter out the data groups that do not have average(b) > 6

Expected output:

+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+

How I am achieving it:

from pyspark.sql import functions as F

df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))

df_filter = df_filter.filter(F.col('avg') > 6.)

df.join(df_filter, 'a', 'inner').drop('avg').show()

Problem:

  1. The shuffle happens twice: once for computing df_filter and again for the join.
df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))

df_filter = df_filter.filter(F.col('avg') > 6.)

df.join(df_filter,'a','inner').drop('avg').explain()
== Physical Plan ==
*(5) Project [a#175L, b#176L]
+- *(5) SortMergeJoin [a#175L], [a#222L], Inner
   :- *(2) Sort [a#175L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(a#175L, 200), ENSURE_REQUIREMENTS, [plan_id=919]
   :     +- *(1) Filter isnotnull(a#175L)
   :        +- *(1) Scan ExistingRDD[a#175L,b#176L]
   +- *(4) Sort [a#222L ASC NULLS FIRST], false, 0
      +- *(4) Project [a#222L]
         +- *(4) Filter (isnotnull(avg#219) AND (avg#219 > 6.0))
            +- *(4) HashAggregate(keys=[a#222L], functions=[avg(b#223L)])
               +- Exchange hashpartitioning(a#222L, 200), ENSURE_REQUIREMENTS, [plan_id=925]
                  +- *(3) HashAggregate(keys=[a#222L], functions=[partial_avg(b#223L)])
                     +- *(3) Filter isnotnull(a#222L)
                        +- *(3) Scan ExistingRDD[a#222L,b#223L]

Intuitively, I should only need to shuffle the data once on the key a; after that, no further shuffles are needed, since every partition would be self-sufficient.
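One way to act on that intuition is a rough sketch like the following: repartition on a up front so that both the aggregation and the join can reuse the same data layout. Whether Spark actually collapses this into a single executed shuffle depends on exchange reuse in your version, so this is an assumption to verify with explain(); caching the repartitioned dataframe makes the reuse explicit.

from pyspark.sql import functions as F

# Sketch (not the accepted answer): shuffle once on 'a', then reuse that
# layout for both the aggregation and the join.
df_part = df.repartition('a').cache()   # one shuffle; cache keeps the layout

df_filter = (df_part.groupby('a')
             .agg(F.mean('b').alias('avg'))
             .filter(F.col('avg') > 6.))

# Both sides are already hash-partitioned on 'a', so the join should not
# need another full shuffle -- confirm with explain() on your Spark version.
df_part.join(df_filter, 'a', 'inner').drop('avg').explain()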

Question: In general, what is an efficient way to exclude the data groups that do not satisfy a group-dependent filter?

CodePudding user response:

You can use a window function instead of the groupBy + join:

from pyspark.sql import Window
from pyspark.sql.functions import avg, col

out = df.withColumn("avg", avg(col("b")).over(Window.partitionBy("a")))\
    .where("avg>6").drop("avg")

out.explain()
out.show()

== Physical Plan ==
Project [a#0L, b#1L]
+- Filter (isnotnull(avg#5) AND (avg#5 > 6.0))
   +- Window [avg(b#1L) windowspecdefinition(a#0L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg#5], [a#0L]
      +- Sort [a#0L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(a#0L, 200), ENSURE_REQUIREMENTS, [plan_id=16]
            +- Scan ExistingRDD[a#0L,b#1L]


+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+
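The same pattern generalizes to any group-dependent predicate: compute whatever group-level aggregates you need over Window.partitionBy("a"), filter on them, and drop the helper columns, all with a single shuffle on the grouping key. A sketch (the minimum-group-size condition and its threshold are made up purely for illustration):

from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy("a")

# Keep only groups whose average of b exceeds 6 AND that contain at least
# 3 rows (the second condition is a hypothetical example).
out = (df
       .withColumn("avg_b", F.avg("b").over(w))
       .withColumn("cnt", F.count("*").over(w))
       .where((F.col("avg_b") > 6) & (F.col("cnt") >= 3))
       .drop("avg_b", "cnt"))

out.show()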