Grouping by then applying custom function in Spark, using SparkSession in a Java stream?

Time:09-02

Let's assume that I have a use case where I want to groupBy then apply a custom function to the grouped values. In Python, I could accomplish this through:

df.groupby('id').apply(custom_function)

and

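import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from sklearn.ensemble import RandomForestRegressor

# train_features and test_features are defined elsewhere
@pandas_udf("id string, prediction double", PandasUDFType.GROUPED_MAP)
def custom_function(id, dataframe):
    # id is a tuple holding the grouping key; dataframe is that group as a pandas DataFrame
    rf = RandomForestRegressor(n_estimators=25, random_state=42)

    rf.fit(train_features, dataframe.quantity_sold)

    # assumes test_features is a single row, so there is one prediction per group
    prediction = rf.predict(test_features)

    return pd.DataFrame({'id': id, 'prediction': prediction}, index=[0])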

I could accomplish the same thing in Scala through:

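import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.Row

// modelColumns, features, oldForest, testAssembler and outputSchema are defined elsewhere
input.rdd.groupBy(row => row.get(0)).collect().map(data => {
  // rebuild a DataFrame from the rows of one group
  val df = sparkSession.createDataFrame(sparkContext.parallelize(data._2.toSeq), input.schema)

  (data._1.toString, df)
}).foldLeft(sparkSession.createDataFrame(sparkContext.emptyRDD[Row], outputSchema))((acc, next) => {
  val assembled = new VectorAssembler()
    .setInputCols(modelColumns)
    .setOutputCol(features)
    .transform(next._2)

  // fit on this group, predict on the test data, and append to the result
  val forest = oldForest
    .fit(assembled)
    .transform(testAssembler)

  acc.union(forest)
})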

Comparing these two workarounds, the Python one runs much faster than the Scala one. I tried to do this without collect, but then I get the error "RDD transformations and actions can only be invoked by the driver, not inside of other transformations."

I am aware that collect returns the results to the driver as a local array, which is why I am forced to use the Scala collections API (map and foldLeft) to continue my processing.

My questions are: once the data has been collected to the driver, shouldn't the job be spread to the executors again (since I keep using the Spark ML API)? Or is everything computed on the driver once collected, since the code goes back to where the main method is executed? Basically, why is this run so slow, and is there any way to improve it without using Python?

Thank you!

Answer:

The Python version actually runs on the executors and hence distributes the load. collect requires all executors to send their data to the driver for processing, which means the loop over the groups runs only on the threads provided by the driver. The Spark ML calls inside that loop do launch jobs on the executors again, but as one small job per group, executed one after another. You are also likely suffering from a lot of garbage collection, because you are creating a VectorAssembler over and over again (and immediately throwing it away).

If you want, you can do collect-like processing inside an executor by using mapPartitions:

import spark.implicits._ // Encoder for the tuple result below

val df4 = df2.mapPartitions(iterator => { // start of executor code
  // Do the heavy initialization here, once per partition,
  // e.g. database connections (Util stands for your own helper class)
  val util = new Util()
  val res = iterator.map(row => {
    val fullName = util.combine(row.getString(0), row.getString(1), row.getString(2))
    (fullName, row.getString(3), row.getInt(5))
  })
  res // end of executor code
})
val df4part = df4.toDF("fullName", "id", "salary")
df4part.printSchema()
df4part.show(false)

The catch is that you cannot use any feature that relies on the SparkContext (or SparkSession), because it only lives on the driver. Said another way: inside the executor code you can only use plain Scala (or Java library) code, not Spark's DataFrame or ML APIs. But if you can find a Scala library for random forests, that would be the answer. The iterator used inside is very memory-efficient and will run much faster than the collect you are doing now.

Presumably, though, you really want to use Spark's RandomForestRegressor?

It looks like you have a global oldForest, so I can't tell what you are using, but a global variable won't work with mapPartitions, so initialize it once per partition and reuse it many times (inside the executor code).
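To illustrate the "initialize once per partition, reuse for every row" idea outside Spark, here is a minimal plain-Scala sketch with no Spark dependency. It assumes a partition can be modelled as an iterator; ExpensiveModel, processPartition and run are hypothetical names, and the model is only a stand-in for whatever per-partition object (trained model, connection) you need. processPartition has the same iterator-in, iterator-out shape as the function you would pass to mapPartitions.

```scala
// Plain-Scala sketch of the mapPartitions pattern; no Spark involved.
object MapPartitionsSketch {
  // Hypothetical stand-in for an expensive per-partition object
  // (a trained model, a database connection, ...).
  final class ExpensiveModel(onInit: () => Unit) {
    onInit() // pretend this is the costly setup
    def predict(x: Double): Double = 2.0 * x // placeholder "prediction"
  }

  // Same shape as the function given to mapPartitions: Iterator in, Iterator out.
  def processPartition(rows: Iterator[Double], onInit: () => Unit): Iterator[(Double, Double)] = {
    val model = new ExpensiveModel(onInit) // created once per partition, not per row
    rows.map(x => (x, model.predict(x)))    // lazy: rows are streamed one at a time
  }

  // Runs every "partition" and reports how often the model was created.
  def run(partitions: Seq[Seq[Double]]): (Seq[(Double, Double)], Int) = {
    var inits = 0
    val out = partitions.flatMap(p => processPartition(p.iterator, () => inits += 1).toList)
    (out, inits)
  }

  def main(args: Array[String]): Unit = {
    val (out, inits) = run(Seq(Seq(1.0, 2.0, 3.0), Seq(4.0, 5.0)))
    println(out)
    println(s"model created $inits times")
  }
}
```

With two partitions the model is created twice, no matter how many rows each partition holds.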


Answer:

This is a good use case for a User-Defined Aggregate Function (UDAF).

What is a User-Defined Aggregate Function?

After grouping a DataFrame with groupBy, one or more aggregation functions like min, max or sum are usually used to aggregate all values that belong to one group of rows into a single value. If none of Spark's built-in functions suits your needs, you can write your own function that takes the data of one group and aggregates it into a new value.

Just as you can use

df.groupBy('myCol1).agg(sum('myCol2))

you can also use

df.groupBy('myCol1).agg(customFunction('myCol2))

where customFunction does whatever you need it to do, for example applying a RandomForestRegressor to all elements of one group of data.
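The same group-then-aggregate idea can be seen on an ordinary Scala collection. The plain-Scala sketch below (no Spark; customFunction here is just a hypothetical String-building fold, not a Spark function) groups the numbers 1 to 10 by value % 3 and reduces each group to one value, which is what agg does per group of a DataFrame.

```scala
object GroupAggSketch {
  // Hypothetical custom aggregation: fold all values of one group into a String
  def customFunction(values: Seq[Int]): String =
    values.foldLeft("START")((acc, v) => acc + "_" + v)

  // Local analogue of df.groupBy(expr("value % 3")).agg(customFunction('value)):
  // groupBy keeps the original order of the values inside each group.
  def run(): Map[Int, String] =
    (1 to 10).groupBy(_ % 3).map { case (key, values) => key -> customFunction(values) }

  def main(args: Array[String]): Unit =
    run().toSeq.sortBy(_._1).foreach { case (k, v) => println(s"$k -> $v") }
}
```

In Spark the groups would be spread over executors instead of sitting in one local Map, which is why a real UDAF also needs a way to combine partial results.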

How to create a User-Defined Aggregate Function?

Here is an (admittedly simplistic) example of a User-Defined Aggregate Function. This function collects all values of one group in a sequence and then concatenates all these values into a string.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import spark.implicits._

//some test data: 1,2,3,...,10
val df = (1 to 10).toDF()

//create the user defined aggregation function
object MyAgg extends Aggregator[Int, Seq[Int], String]{
  // empty buffer for a new group
  override def zero: Seq[Int] = Seq.empty[Int]

  // add one input value to the buffer
  override def reduce(b: Seq[Int], a: Int): Seq[Int] = b :+ a

  // combine two partial buffers
  override def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2

  // turn the complete buffer into the result
  override def finish(allInts: Seq[Int]): String = allInts.foldLeft("START")((s, b) => s + "_" + b)

  override def bufferEncoder: Encoder[Seq[Int]] = newSequenceEncoder[Seq[Int]]

  override def outputEncoder: Encoder[String] = Encoders.STRING
}
// udaf is available from Spark 3.0 on
val myAggFct = udaf(MyAgg).withName("myAgg")

//group the dataframe and apply myAggFct to each group separately
df.groupBy(expr("value % 3")).agg(myAggFct('value)).show

Output:

+-----------+--------------+
|(value % 3)|  myagg(value)|
+-----------+--------------+
|          1|START_1_4_7_10|
|          2|   START_2_5_8|
|          0|   START_3_6_9|
+-----------+--------------+

How does the User-Defined Aggregate Function work?

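zero creates an empty sequence for a group, reduce adds a single input value to such a sequence, and merge combines two partial sequences, for example ones built on different partitions. Together they collect all values of one group into one sequence.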
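To make the roles of zero, reduce, merge and finish concrete, here is a plain-Scala simulation of one group whose values are split across two partitions. The LocalAggregator trait, ConcatAgg and AggregatorLifecycle are hypothetical names: they mirror the method names of Spark's Aggregator but are not Spark's API, and in Spark the partitioning is decided by the engine, not by the caller.

```scala
// Plain-Scala mirror of the Aggregator contract (not Spark's class).
trait LocalAggregator[IN, BUF, OUT] {
  def zero: BUF
  def reduce(b: BUF, a: IN): BUF
  def merge(b1: BUF, b2: BUF): BUF
  def finish(b: BUF): OUT
}

object ConcatAgg extends LocalAggregator[Int, Seq[Int], String] {
  def zero: Seq[Int] = Seq.empty[Int]
  def reduce(b: Seq[Int], a: Int): Seq[Int] = b :+ a         // add one value to a partial buffer
  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2 // combine two partial buffers
  def finish(all: Seq[Int]): String = all.foldLeft("START")((s, v) => s + "_" + v)
}

object AggregatorLifecycle {
  // One group: each partition folds its values into its own buffer starting
  // from zero, then the partial buffers are merged and passed to finish.
  def aggregate[IN, BUF, OUT](agg: LocalAggregator[IN, BUF, OUT], partitions: Seq[Seq[IN]]): OUT = {
    val partials = partitions.map(_.foldLeft(agg.zero)(agg.reduce))
    agg.finish(partials.foldLeft(agg.zero)(agg.merge))
  }

  def main(args: Array[String]): Unit =
    println(aggregate(ConcatAgg, Seq(Seq(1, 4), Seq(7, 10))))
}
```

Because merge only concatenates, the result is the same however the group happens to be split, as long as the partitions are merged in order.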

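The central function is finish. Here the sequence of all collected values (allInts) is transformed into the result of the aggregation operation. This would be the place to apply, for example, a random-forest regressor. Because finish runs distributed on the executor nodes, any additional data it needs should be broadcast. Note that, as with mapPartitions, code inside finish cannot use the SparkSession, so Spark ML's RandomForestRegressor itself cannot be called here; a regressor from a local JVM library is needed.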
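As a rough illustration of putting model fitting into finish, the plain-Scala sketch below buffers all (x, y) pairs of a group and, in finish, fits an ordinary least-squares line and predicts at one test point. This rests on several assumptions: simple linear regression replaces the random forest only to keep the sketch dependency-free, the testX constructor parameter plays the role of data you would ship to the executors as a broadcast variable, and FitPredictAgg and FitPredictDemo are hypothetical names.

```scala
// Plain-Scala sketch: buffer all points of a group, fit a model in finish.
// Least squares is a hypothetical stand-in for a random forest.
final class FitPredictAgg(testX: Double) { // testX ~ a broadcast value
  def zero: Vector[(Double, Double)] = Vector.empty
  def reduce(b: Vector[(Double, Double)], point: (Double, Double)): Vector[(Double, Double)] = b :+ point
  def merge(b1: Vector[(Double, Double)], b2: Vector[(Double, Double)]): Vector[(Double, Double)] = b1 ++ b2

  // Needs the complete group: fit y = intercept + slope * x, then predict at testX.
  def finish(points: Vector[(Double, Double)]): Double = {
    val n = points.size.toDouble
    val meanX = points.map(_._1).sum / n
    val meanY = points.map(_._2).sum / n
    val cov = points.map { case (x, y) => (x - meanX) * (y - meanY) }.sum
    val varX = points.map { case (x, _) => (x - meanX) * (x - meanX) }.sum
    val slope = if (varX == 0.0) 0.0 else cov / varX // flat line if all x are equal
    val intercept = meanY - slope * meanX
    intercept + slope * testX
  }
}

object FitPredictDemo {
  // Aggregates one (non-empty) group and returns the prediction for it.
  def predictGroup(points: Seq[(Double, Double)], testX: Double): Double = {
    val agg = new FitPredictAgg(testX)
    agg.finish(points.foldLeft(agg.zero)(agg.reduce))
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical group: quantity sold grows by 2 per time step
    val group = Seq((1.0, 3.0), (2.0, 5.0), (3.0, 7.0))
    println(predictGroup(group, testX = 4.0))
  }
}
```

The point of the design is that finish sees the whole group at once, which is exactly what a per-group model needs.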

Note: the example above could also (better) be implemented as a simple reduction (for example with groupByKey(...).reduceGroups(...)), because we do not need the values as a sequence: we could simply append each value to the string as soon as we see it. For a regressor, however, we need the complete list of values, so a User-Defined Aggregate Function is reasonable here.
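One way to follow the note's idea for the concatenation example is to keep a String as the buffer and append each value as soon as it arrives, so no sequence is ever built. The plain-Scala sketch below mirrors the Aggregator method names but is not Spark code; StringConcatAgg is a hypothetical name, and merge assumes every partial buffer starts with the "START" prefix.

```scala
// Plain-Scala sketch: the buffer is the growing String itself.
object StringConcatAgg {
  val Prefix = "START"

  def zero: String = Prefix
  def reduce(b: String, a: Int): String = b + "_" + a
  // Drop the second buffer's prefix so it appears only once in the result.
  def merge(b1: String, b2: String): String = b1 + b2.stripPrefix(Prefix)
  def finish(b: String): String = b

  // One group split over several partitions
  def aggregate(partitions: Seq[Seq[Int]]): String =
    finish(partitions.map(_.foldLeft(zero)(reduce)).foldLeft(zero)(merge))

  def main(args: Array[String]): Unit =
    println(aggregate(Seq(Seq(1, 4), Seq(7, 10))))
}
```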
