I am performing an aggregation on the following dataframe to get a list of advertisers, each with an array of brands:
+------------+------+
|advertiser  |brand |
+------------+------+
|Advertiser 1|Brand1|
|Advertiser 1|Brand2|
|Advertiser 2|Brand3|
|Advertiser 2|Brand4|
|Advertiser 3|Brand5|
|Advertiser 3|Brand6|
+------------+------+
Here is my code:
import org.apache.spark.sql.functions.collect_list
df2
.groupBy("advertiser")
.agg(collect_list("brand").as("brands"))
That gives me the following dataframe:
+------------+----------------+
|advertiser  |brands          |
+------------+----------------+
|Advertiser 1|[Brand1, Brand2]|
|Advertiser 2|[Brand3, Brand4]|
|Advertiser 3|[Brand5, Brand6]|
+------------+----------------+
During the aggregation, I want to filter the list of brands using the following table of brands:
+------+------------+
|brand |brand name  |
+------+------------+
|Brand1|Brand_name_1|
|Brand3|Brand_name_3|
+------+------------+
In order to achieve the following result:
+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 1|[Brand1]|
|Advertiser 2|[Brand3]|
|Advertiser 3|null    |
+------------+--------+
Answer:
I see two solutions to your question, which I will call the Collect Solution and the Join Solution.
Collect Solution
If you can collect your brands dataframe to the driver, you can use the collected values to keep only the brands you want when performing collect_list, then flatten your array and replace empty arrays with null, as follows:
import org.apache.spark.sql.functions.{array, col, collect_list, flatten, size, when}

// Collect the brands to keep as a local array of strings on the driver
val filteredBrands = brands.select("brand").collect().map(_.getString(0))

val finalDataframe = df2
  .groupBy("advertiser")
  // For each row, build a one-element array if the brand is kept, else an empty array,
  // and collect those arrays per advertiser
  .agg(collect_list(when(col("brand").isin(filteredBrands: _*), array(col("brand"))).otherwise(array())).as("brands"))
  // Merge the collected array of arrays into a single flat array
  .withColumn("brands", flatten(col("brands")))
  // Replace empty arrays with null
  .withColumn("brands", when(size(col("brands")).equalTo(0), null).otherwise(col("brands")))
Join Solution
If your brands dataframe doesn't fit in memory, you can first left join df2 with brands to get a column containing the brand if it is present in the brands dataframe and null otherwise, then do your group by, and finally replace with null the empty arrays produced for advertisers that have none of the brands you want to keep:
import org.apache.spark.sql.functions.{col, collect_list, size, when}

val finalDataframe = df2
  // "filtered_brand" contains the brand when it exists in the brands dataframe, null otherwise
  .join(brands.select(col("brand").as("filtered_brand")), col("filtered_brand") === col("brand"), "left_outer")
  // collect_list ignores nulls, so only the kept brands end up in the list
  .groupBy("advertiser").agg(collect_list(col("filtered_brand")).as("brands"))
  // Replace empty arrays with null
  .withColumn("brands", when(size(col("brands")).equalTo(0), null).otherwise(col("brands")))
Details
So if we start with a df2 dataframe as follows:
+------------+------+
|advertiser  |brand |
+------------+------+
|Advertiser 1|Brand1|
|Advertiser 1|Brand2|
|Advertiser 2|Brand3|
|Advertiser 2|Brand4|
|Advertiser 3|Brand5|
|Advertiser 3|Brand6|
+------------+------+
And a brands dataframe as follows:
+------+------------+
|brand |brand name  |
+------+------------+
|Brand1|Brand_name_1|
|Brand3|Brand_name_3|
+------+------------+
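If you want to reproduce these two dataframes locally, here is a minimal sketch, assuming a SparkSession named spark is in scope:

import spark.implicits._

val df2 = Seq(
  ("Advertiser 1", "Brand1"),
  ("Advertiser 1", "Brand2"),
  ("Advertiser 2", "Brand3"),
  ("Advertiser 2", "Brand4"),
  ("Advertiser 3", "Brand5"),
  ("Advertiser 3", "Brand6")
).toDF("advertiser", "brand")

val brands = Seq(
  ("Brand1", "Brand_name_1"),
  ("Brand3", "Brand_name_3")
).toDF("brand", "brand name")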
After the first left outer join between the df2 and brands dataframes (the first line of the Join Solution code), you get the following dataframe:
+------------+------+--------------+
|advertiser  |brand |filtered_brand|
+------------+------+--------------+
|Advertiser 1|Brand1|Brand1        |
|Advertiser 1|Brand2|null          |
|Advertiser 2|Brand3|Brand3        |
|Advertiser 2|Brand4|null          |
|Advertiser 3|Brand5|null          |
|Advertiser 3|Brand6|null          |
+------------+------+--------------+
When you group this dataframe by advertiser, collecting the list of filtered brands (collect_list ignores null values), you get the following dataframe:
+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 2|[Brand3]|
|Advertiser 3|[]      |
|Advertiser 1|[Brand1]|
+------------+--------+
And finally, when you apply the last line, replacing empty arrays with null, you get your expected result:
+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 2|[Brand3]|
|Advertiser 3|null    |
|Advertiser 1|[Brand1]|
+------------+--------+
Conclusion
The Collect Solution creates only one expensive shuffle step (during the groupBy) and should be chosen in priority if your brands dataframe is small. The Join Solution works even if your brands dataframe is big, but it creates two expensive shuffle steps, one for the join and one for the groupBy.