Filter array values during aggregation in spark dataframe


I am performing an aggregation on the following dataframe to get, for each advertiser, an array of its brands:

+------------+------+
|advertiser  |brand |
+------------+------+
|Advertiser 1|Brand1|
|Advertiser 1|Brand2|
|Advertiser 2|Brand3|
|Advertiser 2|Brand4|
|Advertiser 3|Brand5|
|Advertiser 3|Brand6|
+------------+------+

Here is my code:

import org.apache.spark.sql.functions.collect_list

df2
  .groupBy("advertiser")
  .agg(collect_list("brand").as("brands"))

That gives me the following dataframe:

+------------+----------------+
|advertiser  |brands          |
+------------+----------------+
|Advertiser 1|[Brand1, Brand2]|
|Advertiser 2|[Brand3, Brand4]|
|Advertiser 3|[Brand5, Brand6]|
+------------+----------------+

During the aggregation, I want to filter the list of brands using the following table of brands:

+------+------------+
|brand |brand name  |
+------+------------+
|Brand1|Brand_name_1|
|Brand3|Brand_name_3|
+------+------------+

In order to achieve:

+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 1|[Brand1]|
|Advertiser 2|[Brand3]|
|Advertiser 3|null    |
+------------+--------+

CodePudding user response:

I see two solutions to your question, which I will call the Collect Solution and the Join Solution.

Collect solution

If your brands dataframe is small enough to collect, you can use the collected values to keep only the right brands while performing collect_list, then flatten the resulting array and replace empty arrays with null, as follows:

import org.apache.spark.sql.functions.{array, col, collect_list, flatten, size, when}

// Collect the allowed brands to the driver
val filteredBrands = brands.select("brand").collect().map(_.getString(0))

val finalDataframe = df2
  .groupBy("advertiser")
  // Collect a one-element array per kept brand, an empty array otherwise
  .agg(collect_list(when(col("brand").isin(filteredBrands: _*), array(col("brand"))).otherwise(array())).as("brands"))
  // Flatten the collected arrays into a single array of brands
  .withColumn("brands", flatten(col("brands")))
  // Replace empty arrays with null
  .withColumn("brands", when(size(col("brands")).equalTo(0), null).otherwise(col("brands")))
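The per-group logic above can be sketched in plain Scala (hypothetical sample data, no Spark session required): keep only whitelisted brands, and map an empty result to null, mirroring the when/size handling:

```scala
// Allowed brands, as filteredBrands would hold them after collect()
val allowedBrands = Set("Brand1", "Brand3")

// One entry per advertiser, mirroring the grouped df2 rows
val grouped = Map(
  "Advertiser 1" -> Seq("Brand1", "Brand2"),
  "Advertiser 2" -> Seq("Brand3", "Brand4"),
  "Advertiser 3" -> Seq("Brand5", "Brand6")
)

val result: Map[String, Seq[String]] = grouped.map { case (advertiser, brands) =>
  // Keep only allowed brands; an empty list becomes null
  val kept = brands.filter(allowedBrands.contains)
  advertiser -> (if (kept.isEmpty) null else kept)
}
```

This is only a model of the computation; in Spark the filtering happens distributed inside the collect_list aggregation.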

Join solution

If your brands dataframe doesn't fit in memory, you can first left-join df2 with brands to get a column containing the brand when it appears in the brands dataframe and null otherwise, then do your groupBy, and finally replace the empty arrays (from advertisers whose brands were all filtered out) with null:

import org.apache.spark.sql.functions.{col, collect_list, size, when}

val finalDataframe = df2
  // Left join: filtered_brand is the brand if present in brands, null otherwise
  .join(brands.select(col("brand").as("filtered_brand")), col("filtered_brand") === col("brand"), "left_outer")
  // collect_list drops nulls, so unmatched advertisers get an empty array
  .groupBy("advertiser").agg(collect_list(col("filtered_brand")).as("brands"))
  // Replace empty arrays with null
  .withColumn("brands", when(size(col("brands")).equalTo(0), null).otherwise(col("brands")))
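The reason this pipeline works is that collect_list silently drops nulls. A plain-Scala sketch (hypothetical sample data, no Spark needed) of the join-then-aggregate mechanics:

```scala
// Brands present in the brands dataframe
val brandTable = Set("Brand1", "Brand3")

// (advertiser, brand) pairs as in df2
val rows = Seq(
  ("Advertiser 1", "Brand1"), ("Advertiser 1", "Brand2"),
  ("Advertiser 2", "Brand3"), ("Advertiser 2", "Brand4"),
  ("Advertiser 3", "Brand5"), ("Advertiser 3", "Brand6")
)

// Left outer join: attach the matching brand, or null when there is none
val joined = rows.map { case (adv, b) => (adv, if (brandTable(b)) b else null) }

// groupBy + collect_list: nulls are dropped, then empty lists become null
val result: Map[String, Seq[String]] = joined.groupBy(_._1).map { case (adv, pairs) =>
  val collected = pairs.map(_._2).filter(_ != null)
  adv -> (if (collected.isEmpty) null else collected)
}
```

In real Spark both the join and the groupBy shuffle data across the cluster, which is the cost discussed in the conclusion below.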

Details

So if we start with a df2 dataframe as follows:

+------------+------+
|advertiser  |brand |
+------------+------+
|Advertiser 1|Brand1|
|Advertiser 1|Brand2|
|Advertiser 2|Brand3|
|Advertiser 2|Brand4|
|Advertiser 3|Brand5|
|Advertiser 3|Brand6|
+------------+------+

And a brands dataframe as follows:

+------+------------+
|brand |brand name  |
+------+------------+
|Brand1|Brand_name_1|
|Brand3|Brand_name_3|
+------+------------+

After the left outer join between the df2 and brands dataframes (the first step), you get the following dataframe:

+------------+------+--------------+
|advertiser  |brand |filtered_brand|
+------------+------+--------------+
|Advertiser 1|Brand1|Brand1        |
|Advertiser 1|Brand2|null          |
|Advertiser 2|Brand3|Brand3        |
|Advertiser 2|Brand4|null          |
|Advertiser 3|Brand5|null          |
|Advertiser 3|Brand6|null          |
+------------+------+--------------+

When you group this dataframe by advertiser, collecting the list of filtered brands (collect_list drops the nulls), you get the following dataframe:

+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 2|[Brand3]|
|Advertiser 3|[]      |
|Advertiser 1|[Brand1]|
+------------+--------+

And finally, when the last line replaces empty arrays with null, you get your expected result:

+------------+--------+
|advertiser  |brands  |
+------------+--------+
|Advertiser 2|[Brand3]|
|Advertiser 3|null    |
|Advertiser 1|[Brand1]|
+------------+--------+

Conclusion

The Collect Solution creates only one expensive shuffle step (during the groupBy) and should be preferred when your brands dataframe is small. The Join Solution works when your brands dataframe is big, but it creates two expensive shuffle steps: one for the groupBy and one for the join.
