Spark groupBy X then sortBy Y then get topK

Time:12-11

case class Tomato(name:String, rank:Int)
case class Potato(..)

I have Spark 2.4 and a Dataset[(Tomato, Potato)] that I want to groupBy name and then get the topK ranks. The issue is that groupBy gives me an iterator, which is not sortable, and iterator.toList explodes (runs out of memory) on large datasets.

Iterator solution:

  data.groupByKey { case (tomato, _) => tomato.name }
    .flatMapGroups((k, it) => it.toList.sortBy(_._1.rank).take(topK))

I've also tried aggregation functions, but I could not find a topK or firstK, only first and last. Another thing I hate about aggregation functions is that they convert the Dataset to a DataFrame (yuck), so all the types are gone.

Aggregation function solution, with syntax made up by me:

data.agg(row_number.over(Window.partitionBy("_1.name").orderBy("_1.rank").take(topK)))
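For reference, something close to this made-up syntax already exists in Spark 2.4 as a window function. You number the rows within each name with `row_number()` ordered by rank, keep the rows whose number is at most topK, and call `.as[...]` to get a typed Dataset back.

Below is a minimal sketch under a few assumptions:
- it runs in a spark-shell session, so `spark` is in scope;
- the sample rows are made up;
- `Potato(taste: String)` is hypothetical, since the question elides Potato's fields.

```scala
// Paste into spark-shell (Spark 2.4), where `spark` is already defined.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}
import spark.implicits._

case class Tomato(name: String, rank: Int)
case class Potato(taste: String) // hypothetical field: the question elides Potato's fields

// Made-up sample data
val data = Seq(
  (Tomato("t1", 4), Potato("a")),
  (Tomato("t1", 1), Potato("b")),
  (Tomato("t1", 3), Potato("c")),
  (Tomato("t2", 2), Potato("d")),
  (Tomato("t2", 7), Potato("e"))
).toDS()

val topK = 2

// Number the rows within each tomato name, smallest rank first
// (use col("_1.rank").desc to keep the largest ranks instead)
val byName = Window.partitionBy(col("_1.name")).orderBy(col("_1.rank"))

val result = data
  .withColumn("rn", row_number().over(byName))
  .filter(col("rn") <= topK)
  .drop("rn")
  .as[(Tomato, Potato)] // back to a typed Dataset[(Tomato, Potato)]
```

Here the sorting is done by Spark's own window operator rather than by materializing each group as a Scala List.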

There are already several questions on SO that ask for a groupBy followed by some other operation, but none of them sort by a key different from the groupBy key and then take the topK.

CodePudding user response:

You could go the iterator route without having to create a full list, which indeed explodes with big datasets. Something like:

import spark.implicits._

case class Tomato(name:String, rank:Int)
case class Potato(taste: String)
case class MyClass(tomato: Tomato, potato: Potato)
val ordering = Ordering.by[MyClass, Int](_.tomato.rank)

val ds = Seq(
  (MyClass(Tomato("tomato1", 1), Potato("tasty"))),
  (MyClass(Tomato("tomato1", 2), Potato("tastier"))),
  (MyClass(Tomato("tomato2", 2), Potato("tastiest"))),
  (MyClass(Tomato("tomato3", 2), Potato("yum"))),
  (MyClass(Tomato("tomato3", 4), Potato("yummier"))),
  (MyClass(Tomato("tomato3", 50), Potato("yummiest"))),
  (MyClass(Tomato("tomato7", 50), Potato("yam")))
).toDS

val k = 2
val output = ds
  .groupByKey{
    case MyClass(tomato, potato) => tomato.name
  }
  .mapGroups(
    (name, iterator)=> {
      val topK = iterator.foldLeft(Seq.empty[MyClass]){
        (accumulator, element) => {
          val newAccumulator = accumulator :+ element
          if (newAccumulator.length > k)
            // sort ascending by rank and drop the smallest, keeping the k largest
            newAccumulator.sorted(ordering).drop(1)
          else
            newAccumulator
        }
      }
      (name, topK)
    }
  )


output.show(false)
+-------+--------------------------------------------------------+
|_1     |_2                                                      |
+-------+--------------------------------------------------------+
|tomato7|[[[tomato7, 50], [yam]]]                                |
|tomato2|[[[tomato2, 2], [tastiest]]]                            |
|tomato1|[[[tomato1, 1], [tasty]], [[tomato1, 2], [tastier]]]    |
|tomato3|[[[tomato3, 4], [yummier]], [[tomato3, 50], [yummiest]]]|
+-------+--------------------------------------------------------+

As you can see, for each Tomato.name key we keep the k elements with the largest Tomato.rank values. To keep the k smallest instead, sort with ordering.reverse. The result is a Dataset[(String, Seq[MyClass])].

This is not really optimized for performance. For each element of a group, the accumulator (at most k + 1 elements) is appended to and re-sorted, so the work per group grows roughly as n · k log k, although memory stays bounded by k. Whether that matters depends on the size of your actual case classes, the size of your data, your requirements, and so on.
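If the repeated sorting turns out to be a bottleneck, one option is a bounded min-heap. This is a different technique from the fold above. A `scala.collection.mutable.PriorityQueue` with a reversed ordering keeps the k largest elements seen so far, at O(log k) per element.

This is a minimal sketch. The helper name `largestK` is hypothetical, and the comment at the end only suggests how it might plug into the mapGroups call above.

```scala
import scala.collection.mutable

// Hypothetical helper: returns the k elements of `it` with the largest key,
// ordered from smallest to largest key. Memory stays bounded by k.
def largestK[A, B](k: Int, it: Iterator[A])(key: A => B)(implicit ord: Ordering[B]): List[A] = {
  // PriorityQueue dequeues its maximum first; reversing the ordering makes it a
  // min-heap on key, so heap.head is the weakest of the current candidates
  val heap = mutable.PriorityQueue.empty[A](Ordering.by[A, B](key)(ord).reverse)
  if (k > 0) it.foreach { x =>
    if (heap.size < k) heap.enqueue(x)
    else if (ord.gt(key(x), key(heap.head))) {
      heap.dequeue() // evict the current smallest key
      heap.enqueue(x)
    }
  }
  // Dequeuing repeatedly yields the remaining elements in ascending key order
  List.fill(heap.size)(heap.dequeue())
}

// Possible use in the pipeline above (not run here):
// ds.groupByKey(_.tomato.name).mapGroups((name, it) => (name, largestK(k, it)(_.tomato.rank)))
```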

Hope this helps!

CodePudding user response:

Issue is that groupBy produces an iterator which is not sortable and iterator.toList explodes on large datasets.

What you could do is write a topK() method that takes k, an Iterator[A] and an A => B mapping, and returns an Iterator[A] of the top k elements (the k elements with the smallest values of type B, in ascending order), all without having to sort the entire iterator:

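def topK[A, B : Ordering](k: Int, iter: Iterator[A], f: A => B): Iterator[A] = {
  val orderer = implicitly[Ordering[B]]
  import orderer._
  if (k <= 0) Iterator.empty
  else {
    // Pull the first k elements by hand: reusing `iter` after iter.take(k) is not allowed by the Iterator contract
    val buf = scala.collection.mutable.ListBuffer.empty[A]
    while (buf.length < k && iter.hasNext) buf += iter.next()
    // Keep the k candidates in descending order, so the head is the largest
    iter.foldLeft(buf.toList.sortWith(f(_) > f(_))){ (lsK, x) =>
      if (f(x) < f(lsK.head))
        (x :: lsK.tail).sortWith(f(_) > f(_)) // replace the largest, restore the order
      else
        lsK
    }.reverse.iterator // ascending order, smallest value first
  }
}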

Note that topK() only repeatedly sorts lists of size k, on the assumption that k is small compared with the size of the input iterator. If necessary, it could be optimized further to avoid sorting the k-element lists altogether: keep only the largest element at the head of the list and leave the rest unsorted.
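Here is a minimal sketch of that further optimization, written as a separate hypothetical function `topKUnsorted` with the same parameters as topK(). Instead of re-sorting, each replacement does one O(k) scan to move the new largest element to the head. A single sort happens at the very end.

```scala
// Hypothetical variant of topK(): returns the k elements with the smallest f-value,
// ascending. The list's head is kept as the largest candidate; the tail stays unsorted.
def topKUnsorted[A, B](k: Int, iter: Iterator[A], f: A => B)(implicit ord: Ordering[B]): List[A] = {
  // One O(k) scan that moves the element with the largest f-value to the front
  def maxFirst(xs: List[A]): List[A] = {
    val i = xs.zipWithIndex.maxBy(p => f(p._1))(ord)._2
    xs(i) :: xs.patch(i, Nil, 1)
  }

  if (k <= 0) Nil
  else {
    var size = 0
    var acc: List[A] = Nil // once size == k, acc.head has the largest f-value
    iter.foreach { x =>
      if (size < k) {
        acc = x :: acc
        size += 1
        if (size == k) acc = maxFirst(acc)
      } else if (ord.lt(f(x), f(acc.head))) {
        // x beats the current worst candidate: drop it, then find the new maximum
        acc = maxFirst(x :: acc.tail)
      }
    }
    acc.sortBy(f)(ord) // a single sort of at most k elements at the very end
  }
}
```

Inside flatMapGroups it would presumably be called as `topKUnsorted[(T, P), Int](k, it, _._1.rank).iterator`.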

Using your groupByKey approach, topK() can be plugged into flatMapGroups as shown below:

case class T(name: String, rank: Int)
case class P(name: String, rank: Int)

val ds = Seq(
  (T("t1", 4), P("p1", 1)),
  (T("t1", 5), P("p2", 2)),
  (T("t1", 1), P("p3", 3)),
  (T("t1", 3), P("p4", 4)),
  (T("t1", 2), P("p5", 5)),
  (T("t2", 4), P("p6", 6)),
  (T("t2", 2), P("p7", 7)),
  (T("t2", 6), P("p8", 8))
).toDF("tomato", "potato").as[(T, P)]

val k = 3

ds.
  groupByKey{ case (tomato, _) => tomato.name }.
  flatMapGroups((_, it) => topK[(T, P), Int](k, it, { case (t, p) => t.rank })).
  show
/*
+-------+-------+
|     _1|     _2|
+-------+-------+
|{t1, 1}|{p3, 3}|
|{t1, 2}|{p5, 5}|
|{t1, 3}|{p4, 4}|
|{t2, 2}|{p7, 7}|
|{t2, 4}|{p6, 6}|
|{t2, 6}|{p8, 8}|
+-------+-------+
*/