Spark GroupBy to map and keep all fields

Time:07-11

I want to use groupBy to build a Map.

case class Foo(
  domain: String,
  site: Long,
  reason: String,
  list: Long,
  policy: Long,
  action: Long
)

val f = Seq(
  Foo("domain", 1, "reason1", 33, 44, 55),
  Foo("domain", 1, "reason2", 33, 44, 55),
  Foo("domain", 2, "reason3", 33, 66, 55),
  Foo("domain", 2, "reason4", 33, 66, 55),
  Foo("domain", 1, "reason5", 33, 88, 55)
)

// needed for toDS() and the $"col" syntax
import spark.implicits._

val ds = f.toDS()

I know I can list the fields explicitly, like this:

import org.apache.spark.sql.functions.{map, collect_list}

ds.groupBy($"site").agg(collect_list(map($"site", $"reason"))).collect()

But I want all the fields except the grouping key (site): basically, site should map to all the other fields in the current row. My real case class is much larger than this one and may change, so I would rather not list every column explicitly.

Is there a cleaner way to do this?

Using Spark 3, Scala 2.12
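For reference, the target shape (each site mapped to the full rows that share it) is what groupBy already returns on a plain Scala collection. The sketch below is local Scala only, with no Spark involved; the object name LocalGroupBy is illustrative, and it assumes the goal is a site-to-rows mapping.

```scala
// Plain Scala collections, no Spark: only illustrates the target shape.
case class Foo(domain: String, site: Long, reason: String,
               list: Long, policy: Long, action: Long)

object LocalGroupBy {
  val f: Seq[Foo] = Seq(
    Foo("domain", 1, "reason1", 33, 44, 55),
    Foo("domain", 1, "reason2", 33, 44, 55),
    Foo("domain", 2, "reason3", 33, 66, 55),
    Foo("domain", 2, "reason4", 33, 66, 55),
    Foo("domain", 1, "reason5", 33, 88, 55)
  )

  // site -> every row with that site; each group keeps the original row order
  val bySite: Map[Long, Seq[Foo]] = f.groupBy(_.site)

  def main(args: Array[String]): Unit =
    bySite.toSeq.sortBy(_._1).foreach { case (site, rows) =>
      println(s"$site -> ${rows.map(_.reason).mkString(", ")}")
    }
}
```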

Answer:

Since the set of fields has to be determined dynamically, one way to go is to read it from the dataset's schema, which Spark has already inferred:

import org.apache.spark.sql.functions.col

val allFields = ds.schema.fields.map(field => col(field.name))

And since site should map to all the fields, a simple way to hold them is to bundle them into a struct in the spot where $"reason" used to be:

import org.apache.spark.sql.functions.struct

ds.groupBy($"site").agg(collect_list(map($"site", struct(allFields:_*)))).show(truncate=false)

which yields:

+----+---------------------------------------------------------------------------------------------------------------------------+
|site|collect_list(map(site, struct(domain, site, reason, list, policy, action)))                                                |
+----+---------------------------------------------------------------------------------------------------------------------------+
|1   |[{1 -> {domain, 1, reason1, 33, 44, 55}}, {1 -> {domain, 1, reason2, 33, 44, 55}}, {1 -> {domain, 1, reason5, 33, 88, 55}}]|
|2   |[{2 -> {domain, 2, reason3, 33, 66, 55}}, {2 -> {domain, 2, reason4, 33, 66, 55}}]                                         |
+----+---------------------------------------------------------------------------------------------------------------------------+
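The question asks for every field besides site, but allFields still contains site, so it is repeated inside each struct. Below is a sketch that filters the grouping key out of the column list before building the struct. The helper name groupAllButKey and the output column name rows are hypothetical, and it assumes Spark 3 with spark-sql on the classpath and a local SparkSession.

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, struct}

// Mirrors the question's case class; defined at top level so Spark can derive an encoder.
case class Foo(domain: String, site: Long, reason: String,
               list: Long, policy: Long, action: Long)

object GroupWithoutKey {
  // Hypothetical helper: groups by `key` and collects every other column
  // into an array of structs, so the key is not repeated inside each struct.
  def groupAllButKey(ds: Dataset[_], key: String): DataFrame = {
    // Column names are read from the schema at runtime, so fields added to
    // the case class later are picked up without editing this code.
    val valueCols = ds.columns.filterNot(_ == key).map(col)
    ds.groupBy(col(key)).agg(collect_list(struct(valueCols: _*)).as("rows"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("group-all-but-key").getOrCreate()
    import spark.implicits._

    val ds = Seq(
      Foo("domain", 1, "reason1", 33, 44, 55),
      Foo("domain", 1, "reason2", 33, 44, 55),
      Foo("domain", 2, "reason3", 33, 66, 55),
      Foo("domain", 2, "reason4", 33, 66, 55),
      Foo("domain", 1, "reason5", 33, 88, 55)
    ).toDS()

    groupAllButKey(ds, "site").show(truncate = false)
    spark.stop()
  }
}
```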

That said, the site in the map is probably redundant, since it is already the group-by key. Here is a simpler version that just builds an array of structs:

ds.groupBy($"site").agg(collect_list(struct(allFields:_*)).as("foos")).show(truncate=false)

Which yields:

+----+------------------------------------------------------------------------------------------------------+
|site|foos                                                                                                  |
+----+------------------------------------------------------------------------------------------------------+
|1   |[{domain, 1, reason1, 33, 44, 55}, {domain, 1, reason2, 33, 44, 55}, {domain, 1, reason5, 33, 88, 55}]|
|2   |[{domain, 2, reason3, 33, 66, 55}, {domain, 2, reason4, 33, 66, 55}]                                  |
+----+------------------------------------------------------------------------------------------------------+
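If the end goal is a Scala Map on the driver, the array-of-structs result can be read back as typed tuples and collected. This is a sketch under the same assumptions (Spark 3, spark-sql on the classpath, a local SparkSession); TypedGrouping and groupBySite are hypothetical names. It relies on Dataset.as matching tuple fields by position and nested struct fields by name, which lines up here because the struct columns are read straight from Foo's schema.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, struct}

// Mirrors the question's case class; defined at top level so Spark can derive an encoder.
case class Foo(domain: String, site: Long, reason: String,
               list: Long, policy: Long, action: Long)

object TypedGrouping {
  // Builds the array-of-structs result and reads it back as (site, rows) tuples.
  def groupBySite(ds: Dataset[Foo]): Dataset[(Long, Seq[Foo])] = {
    import ds.sparkSession.implicits._
    val allFields = ds.columns.map(col)
    ds.groupBy(col("site"))
      .agg(collect_list(struct(allFields: _*)).as("foos"))
      // Tuple fields are matched by position; struct fields are matched to Foo by name.
      .as[(Long, Seq[Foo])]
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-grouping").getOrCreate()
    import spark.implicits._

    val ds = Seq(
      Foo("domain", 1, "reason1", 33, 44, 55),
      Foo("domain", 1, "reason2", 33, 44, 55),
      Foo("domain", 2, "reason3", 33, 66, 55),
      Foo("domain", 2, "reason4", 33, 66, 55),
      Foo("domain", 1, "reason5", 33, 88, 55)
    ).toDS()

    // collect() pulls every group to the driver, so this only suits small results.
    val bySite: Map[Long, Seq[Foo]] = groupBySite(ds).collect().toMap
    // collect_list does not guarantee element order, so sort before printing.
    bySite.toSeq.sortBy(_._1).foreach { case (site, foos) =>
      println(s"$site -> ${foos.map(_.reason).sorted.mkString(", ")}")
    }
    spark.stop()
  }
}
```

For large data it is usually better to keep working with the typed Dataset than to collect it into a driver-side Map.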