case class Tomato(name:String, rank:Int)
case class Potato(..)
I have Spark 2.4 and a Dataset[(Tomato, Potato)] that I want to group by name and then get the topK ranks within each group.
Issue is that groupBy produces an iterator which is not sortable and iterator.toList explodes on large datasets.
Iterator solution:
data.groupByKey{ case (tomato,_) => tomato.name }
.flatMapGroups((k,it)=>it.toList.sortBy(_.rank).take(topK))
I've also tried aggregation functions, but I could not find a topK or firstK, only first and last. Another thing I hate about aggregation functions is that they convert the Dataset to a DataFrame (yuck), so all the types are gone.
Aggregation Fn solution syntax made up by me:
data.agg(row_number.over(Window.partitionBy("_1.name").orderBy("_1.rank").take(topK))
There are already several questions on SO that ask for groupBy followed by some other operation, but none of them want to sort by a key different from the groupBy key and then get the topK.
CodePudding user response:
You could go the iterator route without having to create a full list which indeed explodes with big datasets. Something like:
import spark.implicits._
import scala.util.Sorting
case class Tomato(name:String, rank:Int)
case class Potato(taste: String)
case class MyClass(tomato: Tomato, potato: Potato)
val ordering = Ordering.by[MyClass, Int](_.tomato.rank)
val ds = Seq(
(MyClass(Tomato("tomato1", 1), Potato("tasty"))),
(MyClass(Tomato("tomato1", 2), Potato("tastier"))),
(MyClass(Tomato("tomato2", 2), Potato("tastiest"))),
(MyClass(Tomato("tomato3", 2), Potato("yum"))),
(MyClass(Tomato("tomato3", 4), Potato("yummier"))),
(MyClass(Tomato("tomato3", 50), Potato("yummiest"))),
(MyClass(Tomato("tomato7", 50), Potato("yam")))
).toDS
val k = 2
val output = ds
  .groupByKey {
    case MyClass(tomato, potato) => tomato.name
  }
  .mapGroups(
    (name, iterator) => {
      val topK = iterator.foldLeft(Seq.empty[MyClass]) {
        (accumulator, element) => {
          // Append the new element, then evict the smallest rank once we hold more than k
          val newAccumulator = accumulator :+ element
          if (newAccumulator.length > k)
            newAccumulator.sorted(ordering).drop(1)
          else
            newAccumulator
        }
      }
      (name, topK)
    }
  )
output.show(false)
+-------+--------------------------------------------------------+
|_1     |_2                                                      |
+-------+--------------------------------------------------------+
|tomato7|[[[tomato7, 50], [yam]]]                                |
|tomato2|[[[tomato2, 2], [tastiest]]]                            |
|tomato1|[[[tomato1, 1], [tasty]], [[tomato1, 2], [tastier]]]    |
|tomato3|[[[tomato3, 4], [yummier]], [[tomato3, 50], [yummiest]]]|
+-------+--------------------------------------------------------+
So as you see, for each Tomato.name key, we're keeping the k elements with the largest Tomato.rank values. You get a Dataset[(String, Seq[MyClass])] as the result.
This is not really optimized for performance: for each group, we iterate over all of its elements and re-sort the accumulator every time it grows past k elements, which could become quite intensive computationally. But this all depends on the size of your actual case classes, the size of your data, your requirements, ...
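If the repeated sorting turns out to be a bottleneck, one possible variation is to keep the k largest elements in a bounded min-heap instead of a Seq. The sketch below is plain Scala with no Spark dependency; the names largestK and Item are made up for illustration and are not part of the code above. Assuming your group iterator yields MyClass values, the same function could be called inside mapGroups.

```scala
import scala.collection.mutable

// Hypothetical stand-in for the real case classes, used only for this sketch.
case class Item(name: String, rank: Int)

// Hypothetical helper: keeps the k elements with the largest keys
// in a min-heap that never grows beyond k entries.
def largestK[A, B](k: Int, it: Iterator[A])(key: A => B)(implicit ord: Ordering[B]): Seq[A] = {
  // PriorityQueue puts the greatest element (under its ordering) at the head,
  // so reversing the ordering keeps the smallest key at the head.
  val smallestFirst: Ordering[A] = Ordering.by(key)(ord).reverse
  val heap = mutable.PriorityQueue.empty[A](smallestFirst)
  it.foreach { x =>
    if (heap.size < k) heap.enqueue(x)
    else if (heap.nonEmpty && ord.gt(key(x), key(heap.head))) {
      heap.dequeue() // evict the smallest key currently kept
      heap.enqueue(x)
    }
  }
  // Sort the k survivors once at the end, smallest key first.
  heap.toList.sortBy(key)
}

val items = Iterator(Item("tomato3", 2), Item("tomato3", 4), Item("tomato3", 50))
val kept = largestK(2, items)(_.rank)
```

Each element costs at most one heap update, so the per-group work is roughly proportional to n log k rather than re-sorting a k+1 element Seq for every element.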
Hope this helps!
CodePudding user response:
Issue is that groupBy produces an iterator which is not sortable and iterator.toList explodes on large datasets.
What you could do is to come up with a topK() method that takes a parameter k, an Iterator[A] and an A => B mapping, and returns an Iterator[A] of the top k elements (the k elements with the smallest mapped values of type B, in ascending order) -- all without having to sort the entire iterator:
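def topK[A, B : Ordering](k: Int, iter: Iterator[A], f: A => B): Iterator[A] = {
  val orderer = implicitly[Ordering[B]]
  import orderer._  // enables < and > on values of type B
  // Seed with the first k elements; the fold below continues with the rest of iter
  val listK = iter.take(k).toList
  // Keep the list sorted in descending order so its head is the largest element kept so far
  iter.foldLeft(listK.sortWith(f(_) > f(_))){ (lsK, x) =>
    if (f(x) < f(lsK.head))
      (x :: lsK.tail).sortWith(f(_) > f(_))
    else
      lsK
  }.reverse.iterator
}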
Note that topK() only involves iterative sorting of lists of size k, with the assumption that k is small compared with the size of the input iterator. If necessary, it could be further optimized to eliminate the sorting of the k-element lists by only making the first element of each list the largest one while leaving the rest of the list unsorted.
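A rough sketch of that optimization, under a few assumptions: the name smallestK is made up here, and instead of literally moving the largest element to the head of a list it remembers the index of the largest kept element in an ArrayBuffer, which amounts to the same idea. The kept elements are sorted only once, at the very end.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical variant of topK(): keeps the k elements with the smallest keys.
// The buffer stays unsorted; only the position of its largest key is tracked.
def smallestK[A, B](k: Int, iter: Iterator[A], f: A => B)(implicit ord: Ordering[B]): Iterator[A] = {
  val buf = ArrayBuffer.empty[A]
  var maxIdx = -1 // index of the largest key currently in buf, -1 while buf is empty
  iter.foreach { x =>
    if (buf.length < k) {
      // Still filling up: append and update the max position if needed.
      buf += x
      if (maxIdx < 0 || ord.gt(f(x), f(buf(maxIdx)))) maxIdx = buf.length - 1
    } else if (k > 0 && ord.lt(f(x), f(buf(maxIdx)))) {
      // x beats the current largest: overwrite it, then rescan the k slots for the new largest.
      buf(maxIdx) = x
      maxIdx = buf.indices.maxBy(i => f(buf(i)))
    }
  }
  // One final sort of at most k elements, ascending by key.
  buf.sortBy(f).iterator
}

val smallestThree = smallestK[Int, Int](3, Iterator(4, 5, 1, 3, 2), x => x).toList
```

The rescan after a replacement is linear in k, so this trades the k log k sort per replacement for a single pass over the kept elements.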
Using your groupByKey approach, method topK() can be plugged in within flatMapGroups as shown below:
case class T(name: String, rank: Int)
case class P(name: String, rank: Int)
val ds = Seq(
(T("t1", 4), P("p1", 1)),
(T("t1", 5), P("p2", 2)),
(T("t1", 1), P("p3", 3)),
(T("t1", 3), P("p4", 4)),
(T("t1", 2), P("p5", 5)),
(T("t2", 4), P("p6", 6)),
(T("t2", 2), P("p7", 7)),
(T("t2", 6), P("p8", 8))
).toDF("tomato", "potato").as[(T, P)]
val k = 3
ds.
groupByKey{ case (tomato, _) => tomato.name }.
flatMapGroups((_, it) => topK[(T, P), Int](k, it, { case (t, p) => t.rank })).
show
/*
+-------+-------+
|     _1|     _2|
+-------+-------+
|{t1, 1}|{p3, 3}|
|{t1, 2}|{p5, 5}|
|{t1, 3}|{p4, 4}|
|{t2, 2}|{p7, 7}|
|{t2, 4}|{p6, 6}|
|{t2, 6}|{p8, 8}|
+-------+-------+
*/