What's the best way to group and aggregate an array of objects in a dataframe in scala

Time:02-02

An example:
column _4 holds an array of (count, date, tag) structs that I want to group by tag and sum:

|_1 |_2   |_3|_4                                                            |
|100|Scrap|12|{[{1, 2022-12-05, A}, {1, 2022-12-05, B}]}                    |
|100|Scrap|12|{[{1, 2022-12-06, A}]}                                        |
|100|Scrap|15|{[{2, 2022-12-07, A}, {2, 2022-12-02, A}, {2, 2022-12-03, C}]}|
|100|Scrap|15|{[{5, 2022-12-05, A}, {3, 2022-12-05, A}, {5, 2022-12-05, D}]}|

The output I'm hoping for is something like the following, which groups by the first three columns and by the third element (tag) of each object, while summing the first element (count):

|UID|Title|Cell|Data                       |
|100|Scrap|12  |{[{2, A}, {1, B}]}         |
|100|Scrap|15  |{[{12, A}, {2, C}, {5, D}]}|
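To pin down the intended semantics, here is a minimal plain-Scala sketch (Scala 2.13+ collections, no Spark) that computes the desired per-tag totals from the example rows. The `Entry` and `InputRow` case classes and the `ExpectedOutputSketch` object are hypothetical names used only for illustration; the date is carried along but ignored, as in the desired output.

```scala
import java.time.LocalDate

// Hypothetical in-memory model of the example rows (not Spark types).
final case class Entry(count: Int, date: LocalDate, tag: String)
final case class InputRow(uid: Long, title: String, cell: Long, data: Seq[Entry])

object ExpectedOutputSketch {
  private def d(s: String): LocalDate = LocalDate.parse(s)

  val rows: Seq[InputRow] = Seq(
    InputRow(100L, "Scrap", 12L, Seq(Entry(1, d("2022-12-05"), "A"), Entry(1, d("2022-12-05"), "B"))),
    InputRow(100L, "Scrap", 12L, Seq(Entry(1, d("2022-12-06"), "A"))),
    InputRow(100L, "Scrap", 15L, Seq(Entry(2, d("2022-12-07"), "A"), Entry(2, d("2022-12-02"), "A"), Entry(2, d("2022-12-03"), "C"))),
    InputRow(100L, "Scrap", 15L, Seq(Entry(5, d("2022-12-05"), "A"), Entry(3, d("2022-12-05"), "A"), Entry(5, d("2022-12-05"), "D")))
  )

  // Key = (UID, Title, Cell); value = tag -> summed count across all rows with that key.
  val expected: Map[(Long, String, Long), Map[String, Long]] =
    rows
      .groupBy(r => (r.uid, r.title, r.cell))
      .view
      .mapValues(rs => rs.flatMap(_.data).groupMapReduce(_.tag)(_.count.toLong)(_ + _))
      .toMap
}
```

For cell 15 this yields A = 2 + 2 + 5 + 3 = 12, C = 2 and D = 5, matching the table above.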

The schema of the DataFrame looks like this:

 |-- _1: long (nullable = false)
 |-- _2: string (nullable = true)
 |-- _3: long (nullable = false)
 |-- _4: struct (nullable = true)
 |    |-- data: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- count: integer (nullable = false)
 |    |    |    |-- date: date (nullable = true)
 |    |    |    |-- tag: string (nullable = true)

Answer:

A straightforward approach is to flatten the array in column _4 with the inline generator function, then aggregate with two successive groupBy/agg steps, as shown below:

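import java.sql.Date
import org.apache.spark.sql.functions._  // expr, struct, sum, collect_list
import spark.implicits._                 // toDF and $"col"; assumes a SparkSession named `spark`, as in spark-shell

// Case classes mirroring the struct stored in column _4
case class Item(count: Int, date: Date, tag: String)
case class Items(data: Seq[Item])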

val df = Seq(
  (100L, "Scrap", 12L, Items(Seq(Item(1, Date.valueOf("2022-12-05"), "A"), Item(1, Date.valueOf("2022-12-05"), "B")))),
  (100L, "Scrap", 12L, Items(Seq(Item(1, Date.valueOf("2022-12-06"), "A")))),
  (100L, "Scrap", 15L, Items(Seq(Item(2, Date.valueOf("2022-12-07"), "A"), Item(2, Date.valueOf("2022-12-02"), "A"), Item(2, Date.valueOf("2022-12-03"), "C")))),
  (100L, "Scrap", 15L, Items(Seq(Item(5, Date.valueOf("2022-12-05"), "A"), Item(3, Date.valueOf("2022-12-05"), "A"), Item(5, Date.valueOf("2022-12-05"), "D"))))
).toDF("_1", "_2", "_3", "_4")

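df.
  // inline(_4.data) emits one row per array element, with columns count, date and tag
  select($"_1", $"_2", $"_3", expr("inline(_4.data)")).
  // 1st groupBy: sum count per key columns + tag
  groupBy($"_1".as("UID"), $"_2".as("Title"), $"_3".as("Cell"), $"tag").agg(
    struct(sum($"count").as("count"), $"tag").as("TagSum")
  ).
  // 2nd groupBy: collect the per-tag (count, tag) structs for each key
  groupBy("UID", "Title", "Cell").agg(
    collect_list("TagSum").as("Data")
  ).
  show(false)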
/*
+---+-----+----+-------------------------+
|UID|Title|Cell|Data                     |
+---+-----+----+-------------------------+
|100|Scrap|12  |[{1, B}, {2, A}]         |
|100|Scrap|15  |[{2, C}, {12, A}, {5, D}]|
+---+-----+----+-------------------------+
*/

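The first groupBy groups by the key columns plus the tag field of the flattened _4.data elements and sums count per tag; the second groupBy groups by the key columns only and collects the per-tag structs into the Data array. Note that collect_list does not guarantee element order, so the output may list {1, B} before {2, A}. If a stable order matters, sort the collected array, e.g. with sort_array, which orders structs field by field starting with the first field.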
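As a sanity check of the two-stage logic, here is a minimal plain-Scala mirror of the pipeline (Scala 2.13+ collections, no Spark). `TwoStageSketch` and its members are hypothetical names; dates are dropped because neither stage uses them, and the final sort by tag stands in for an explicit ordering step, since collect_list alone does not provide one.

```scala
// Hypothetical plain-Scala mirror of the Spark pipeline above (no Spark involved).
object TwoStageSketch {
  type Key = (Long, String, Long) // (UID, Title, Cell)

  // Each input row: key columns plus the (count, tag) pairs from _4.data.
  val input: Seq[(Key, Seq[(Int, String)])] = Seq(
    ((100L, "Scrap", 12L), Seq(1 -> "A", 1 -> "B")),
    ((100L, "Scrap", 12L), Seq(1 -> "A")),
    ((100L, "Scrap", 15L), Seq(2 -> "A", 2 -> "A", 2 -> "C")),
    ((100L, "Scrap", 15L), Seq(5 -> "A", 3 -> "A", 5 -> "D"))
  )

  // Like inline: one record per array element.
  val flattened: Seq[(Key, Int, String)] =
    for ((key, items) <- input; (count, tag) <- items) yield (key, count, tag)

  // Like the 1st groupBy: sum count per (key, tag).
  val tagSums: Map[(Key, String), Long] =
    flattened.groupMapReduce { case (key, _, tag) => (key, tag) } { case (_, count, _) => count.toLong }(_ + _)

  // Like the 2nd groupBy: collect (sum, tag) pairs per key, sorted by tag for a stable order.
  val data: Map[Key, Seq[(Long, String)]] =
    tagSums.toSeq
      .groupMap { case ((key, _), _) => key } { case ((_, tag), sum) => (sum, tag) }
      .view
      .mapValues(_.sortBy(_._2))
      .toMap
}
```

Sorting by tag here reproduces the order shown in the question's desired output; in Spark the equivalent choice is which field to put first in the struct before calling sort_array.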
