I have a dataframe that contains an "id" column and a "publication" column. The "id" column contains duplicates, and represents a researcher. The "publication" column contains some information about an academic work the researcher published.
I want to transform this dataframe to collect the publications into an array, reducing the number of rows. I can do this using groupBy and collect_list. This would make it so that the "id" column would only contain unique values.
myDataframe
  .groupBy("id")
  .agg(collect_list("publication").as("publications"))
  .select("id", "publications")
However, for my purposes, this is too much data for one row. I want to limit the number of publications that are collected, and split the data up across multiple rows.
Let's say my dataframe looks like this, where an id of 1 appears in 10 rows:
| id | publication |
| ----| -------------- |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 2 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
| 1 | "foobar" |
I want to groupBy id and collect publication into a list, but limit this to a maximum of 5 publications per group:
| id | publications |
| ----| -------------- |
| 1 | ["foobar",...] |
| 1 | ["foobar",...] |
| 2 | ["foobar"] |
How would I accomplish this in Spark Scala?
CodePudding user response:
Add a row_number() column in your df over a window with the same keys as your group by:
.withColumn("rn", row_number().over(Window.partitionBy("id").orderBy("id")))
Create a new grouping column from this row number, either modulo 5 or divided by 5 and truncated to an integer, then group by on id and this new column.
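A minimal sketch of the divide-and-truncate variant (the column names rn and bucket and the ordering by publication are illustrative choices, not part of the original answer):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

myDataframe
  // number the rows within each id; the ordering here is arbitrary
  .withColumn("rn", row_number().over(Window.partitionBy("id").orderBy("publication")))
  // 0-based chunk index: rows 1-5 -> 0, rows 6-10 -> 1, ...
  .withColumn("bucket", floor((col("rn") - 1) / 5))
  .groupBy("id", "bucket")
  .agg(collect_list("publication").as("publications"))
  .select("id", "publications")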
CodePudding user response:
If you want a fixed number of publications per row, you first have to calculate an intermediate number of buckets per researcher. For example, say you have two researchers, where the first one has 100 publications and the second has 2. To get five publications per list, the first researcher needs 20 buckets, but the second one only needs 1. You can determine the number of buckets by dividing the number of publications by five (or however many publications you want per list) and rounding up. You then take the row number modulo the number of buckets and group on that. Here's an example I ran in spark-shell:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val testDF = Seq((1, "pub1"),
(1, "pub1"),
(1, "pub2"),
(1, "pub3"),
(1, "pub4"),
(1, "pub5"),
(1, "pub6"),
(1, "pub7"),
(1, "pub8"),
(2, "pub9"),
(2, "pub10"),
(2, "pub11"),
(2, "pub12"),
(2, "pub13")).toDF("id", "publication")
testDF
  .withColumn("pubCount", count("*").over(Window.partitionBy("id")))             // total publications per id
  .withColumn("rn", row_number().over(Window.partitionBy("id").orderBy("id")))   // row number within each id
  .withColumn("bucket", col("rn") % ceil(col("pubCount") / 5))                   // one of ceil(pubCount / 5) buckets per row
  .groupBy("id", "bucket")
  .agg(collect_list("publication").as("publications"))
  .select("id", "publications")
  .show(false)
Output:
+---+----------------------------------+
|id |publications                      |
+---+----------------------------------+
|1  |[pub1, pub2, pub4, pub6, pub8]    |
|1  |[pub1, pub3, pub5, pub7]          |
|2  |[pub9, pub10, pub11, pub12, pub13]|
+---+----------------------------------+
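Note that the modulo assigns rows to buckets round-robin, so each list ends up with roughly, but not always exactly, five publications (here id 1 gets lists of 5 and 4). If you need contiguous chunks of at most five, the divide-and-truncate variant sketched under the first answer does that instead.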