Let us use the following dummy data:
df = spark.createDataFrame([(1,2),(1,3),(1,40),(1,0),(2,3),(2,1),(2,4),(3,2),(3,4)],['a','b'])
df.show()
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
|  2|  3|
|  2|  1|
|  2|  4|
|  3|  2|
|  3|  4|
+---+---+
- How do I filter out the groups whose average(b) is not greater than 6?
Expected output:
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+
How I am achieving it:
import pyspark.sql.functions as F

df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias('avg'))
df_filter = df_filter.filter(F.col('avg') > 6.0)
df.join(df_filter, 'a', 'inner').drop('avg').show()
Problem:
- The shuffle happens twice: once to compute df_filter and again for the join.
df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias('avg'))
df_filter = df_filter.filter(F.col('avg') > 6.0)
df.join(df_filter, 'a', 'inner').drop('avg').explain()
== Physical Plan ==
*(5) Project [a#175L, b#176L]
+- *(5) SortMergeJoin [a#175L], [a#222L], Inner
   :- *(2) Sort [a#175L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(a#175L, 200), ENSURE_REQUIREMENTS, [plan_id=919]
   :     +- *(1) Filter isnotnull(a#175L)
   :        +- *(1) Scan ExistingRDD[a#175L,b#176L]
   +- *(4) Sort [a#222L ASC NULLS FIRST], false, 0
      +- *(4) Project [a#222L]
         +- *(4) Filter (isnotnull(avg#219) AND (avg#219 > 6.0))
            +- *(4) HashAggregate(keys=[a#222L], functions=[avg(b#223L)])
               +- Exchange hashpartitioning(a#222L, 200), ENSURE_REQUIREMENTS, [plan_id=925]
                  +- *(3) HashAggregate(keys=[a#222L], functions=[partial_avg(b#223L)])
                     +- *(3) Filter isnotnull(a#222L)
                        +- *(3) Scan ExistingRDD[a#222L,b#223L]
If I think about it, I should only need to shuffle the data once on the key a; after that, every partition is self-sufficient and no further shuffle is needed.
Question: In general, what is an efficient way to exclude the groups that do not satisfy a group-dependent filter?
CodePudding user response:
You can use a Window function instead of the groupBy + join. The plan below contains a single Exchange, i.e. one shuffle on a:
from pyspark.sql import Window
from pyspark.sql.functions import avg, col

out = df.withColumn("avg", avg(col("b")).over(Window.partitionBy("a"))) \
    .where("avg > 6").drop("avg")
out.explain()
out.show()
== Physical Plan ==
Project [a#0L, b#1L]
+- Filter (isnotnull(avg#5) AND (avg#5 > 6.0))
   +- Window [avg(b#1L) windowspecdefinition(a#0L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg#5], [a#0L]
      +- Sort [a#0L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(a#0L, 200), ENSURE_REQUIREMENTS, [plan_id=16]
            +- Scan ExistingRDD[a#0L,b#1L]
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+
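As an aside (not part of the original answer), the semantics of this group-dependent filter can be sanity-checked in plain Python without a Spark session. This sketch recomputes each group's mean of b and keeps only rows from groups whose mean exceeds 6:

```python
from collections import defaultdict
from statistics import mean

# Same dummy data as the question's DataFrame.
rows = [(1, 2), (1, 3), (1, 40), (1, 0),
        (2, 3), (2, 1), (2, 4),
        (3, 2), (3, 4)]

# Collect the b values per group key a.
by_key = defaultdict(list)
for a, b in rows:
    by_key[a].append(b)

# Mean of b per group: {1: 11.25, 2: ~2.67, 3: 3.0}.
group_avg = {a: mean(bs) for a, bs in by_key.items()}

# Keep rows only from groups whose average exceeds the threshold.
kept = [(a, b) for a, b in rows if group_avg[a] > 6]
print(kept)  # [(1, 2), (1, 3), (1, 40), (1, 0)]
```

Only group a=1 survives (mean 11.25), matching the expected output above.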