ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast

Time:06-29

I am using an Aggregator to apply some custom merge on a DataFrame after grouping its records by their primary key:

case class Player(
  pk: String, 
  ts: String, 
  first_name: String, 
  date_of_birth: String
)

case class PlayerProcessed(
  var ts: String, 
  var first_name: String, 
  var date_of_birth: String
)

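import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Custom Aggregator. This one is just for the example; the actual one is more complex.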
object BatchDedupe extends Aggregator[Player, PlayerProcessed, PlayerProcessed] {

  def zero: PlayerProcessed = PlayerProcessed("0", null, null)

  def reduce(bf: PlayerProcessed, in : Player): PlayerProcessed = {
    bf.ts = in.ts
    bf.first_name = in.first_name
    bf.date_of_birth = in.date_of_birth
    bf
  }

  def merge(bf1: PlayerProcessed, bf2: PlayerProcessed): PlayerProcessed = {
    bf1.ts = bf2.ts
    bf1.first_name = bf2.first_name
    bf1.date_of_birth = bf2.date_of_birth
    bf1
  }

  def finish(reduction: PlayerProcessed): PlayerProcessed = reduction
  def bufferEncoder: Encoder[PlayerProcessed] = Encoders.product
  def outputEncoder: Encoder[PlayerProcessed] = Encoders.product
}


val ply1 = Player("12121212121212", "10000001", "Rogger", "1980-01-02")
val ply2 = Player("12121212121212", "10000002", "Rogg", null)
val ply3 = Player("12121212121212", "10000004", null, "1985-01-02")
val ply4 = Player("12121212121212", "10000003", "Roggelio", "1982-01-02")

val seq_users = sc.parallelize(Seq(ply1, ply2, ply3, ply4)).toDF.as[Player]

val grouped = seq_users.groupByKey(_.pk)

val non_sorted = grouped.agg(BatchDedupe.toColumn.name("deduped"))
non_sorted.show(false)

This returns:

+--------------+--------------------------------+
|key           |deduped                         |
+--------------+--------------------------------+
|12121212121212|{10000003, Roggelio, 1982-01-02}|
+--------------+--------------------------------+
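The result above is just the values of whichever row reduce happened to see last, because reduce overwrites every field unconditionally. The following plain-Scala sketch (no Spark involved; reduceRow is a hypothetical immutable stand-in for BatchDedupe.reduce) folds the four sample rows in their creation order and reproduces that output. Spark does not guarantee the order in which rows reach the Aggregator after the groupByKey shuffle, so with real data the winning row can differ.

```scala
// Plain Scala, no Spark: redefines the question's case classes (immutable here)
case class Player(pk: String, ts: String, first_name: String, date_of_birth: String)
case class PlayerProcessed(ts: String, first_name: String, date_of_birth: String)

// Same effect as BatchDedupe.reduce: every field of the buffer is overwritten,
// including with nulls, so the buffer's previous contents are ignored
def reduceRow(buf: PlayerProcessed, in: Player): PlayerProcessed =
  PlayerProcessed(in.ts, in.first_name, in.date_of_birth)

// The four sample rows from the question, in the order they were created
val rows = Seq(
  Player("12121212121212", "10000001", "Rogger", "1980-01-02"),
  Player("12121212121212", "10000002", "Rogg", null),
  Player("12121212121212", "10000004", null, "1985-01-02"),
  Player("12121212121212", "10000003", "Roggelio", "1982-01-02")
)

// Folding in input order: whichever row comes last wins (ts 10000003),
// even though ts 10000004 is the newest record
val inInputOrder = rows.foldLeft(PlayerProcessed("0", null, null))(reduceRow)
println(inInputOrder) // PlayerProcessed(10000003,Roggelio,1982-01-02)
```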

Now I would like to order the records by ts before aggregating them. From what I have read, .orderBy("ts") does not guarantee that the order is preserved after .groupByKey(_.pk), so I tried to apply the sorting between .groupByKey and .agg.

The output of .groupByKey(_.pk) is a KeyValueGroupedDataset[String, Player], and .mapGroups hands each group to my function as a key plus an Iterator[Player]. So, to apply some sorting logic there, I convert the iterator into a Seq:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts))}.agg(BatchDedupe.toColumn.name("deduped"))
sorted.show(false)

However, once the sorting logic is added, the output of .mapGroups is a Dataset[(String, Seq[Player])] instead of a KeyValueGroupedDataset[String, Player]. So when I invoke .agg on it, I get the following exception:

Caused by: ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line050e0d37885948cd91f7f7dd9e3b4da9311.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Player

How can I convert the output of my .mapGroups(...) back into a KeyValueGroupedDataset[String, Player]?

I tried converting the Seq back to an Iterator as follows:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts).toIterator)}.agg(BatchDedupe.toColumn.name("deduped"))

But this approach produced the following exception:

UnsupportedOperationException: No Encoder found for Iterator[Player]
- field (class: "scala.collection.Iterator", name: "_2")
- root class: "scala.Tuple2"

How else can I add the sort logic between the .groupByKey and .agg methods?
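A minimal sketch, assuming the goal is only to process each group in ts order: instead of calling .agg after mapGroups, both the sort and the merge logic can run inside the mapGroups function, for example grouped.mapGroups { case (k, iter) => (k, iter.toSeq.sortBy(_.ts).foldLeft(BatchDedupe.zero)(BatchDedupe.reduce)) }, which needs an Encoder for (String, PlayerProcessed) such as the one import spark.implicits._ provides. The block below checks the same logic with plain Scala collections (no Spark); reduceRow is a hypothetical immutable stand-in for BatchDedupe.reduce.

```scala
// Plain Scala, no Spark: same case classes and reduce logic as in the question
case class Player(pk: String, ts: String, first_name: String, date_of_birth: String)
case class PlayerProcessed(ts: String, first_name: String, date_of_birth: String)

// Overwrites every field, like BatchDedupe.reduce
def reduceRow(buf: PlayerProcessed, in: Player): PlayerProcessed =
  PlayerProcessed(in.ts, in.first_name, in.date_of_birth)

val rows = Seq(
  Player("12121212121212", "10000001", "Rogger", "1980-01-02"),
  Player("12121212121212", "10000002", "Rogg", null),
  Player("12121212121212", "10000004", null, "1985-01-02"),
  Player("12121212121212", "10000003", "Roggelio", "1982-01-02")
)

// Group by pk, sort each group by ts, then fold the reduce logic over it.
// ts is a String, so sortBy is lexicographic; that matches numeric order
// here only because every ts has the same number of digits.
val dedupedSorted: Map[String, PlayerProcessed] =
  rows.groupBy(_.pk).map { case (pk, group) =>
    pk -> group.sortBy(_.ts).foldLeft(PlayerProcessed("0", null, null))(reduceRow)
  }

// The newest row (ts 10000004) now wins, including its null first_name
println(dedupedSorted("12121212121212")) // PlayerProcessed(10000004,null,1985-01-02)
```

Two caveats on this design: it buffers each whole group in memory to sort it, and, as the Spark documentation for mapGroups notes, mapGroups does not support partial aggregation, so all rows are shuffled. Also, because reduce copies null fields too, the newest row's null first_name ends up in the result.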

Answer:

Based on the discussion in the comments, the purpose of the Aggregator is to get, for each pk, the latest value of every field by ts, ignoring null values.

This can be achieved fairly easily by aggregating each field individually with max_by. That way there is no need for a custom Aggregator or a mutable aggregation buffer.

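import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import spark.implicits._ // for .as[Player]; assumes a SparkSession named spark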

val players: Dataset[Player] = ...

// aggregate all columns except the key individually by ts
// NULLs will be ignored (SQL standard)
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => expr(s"max_by($colName, if(isNotNull($colName), ts, null))").as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]
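To make the semantics of the max_by expression above concrete, here is a plain-Scala sketch (no Spark; latestNonNull is a hypothetical helper, not a Spark API) of what each aggregated column computes per pk: the field's value from the row with the greatest ts among the rows where that field is not null. This relies on max_by skipping rows whose ordering value is null, which is what the if(isNotNull(...), ts, null) wrapper exploits.

```scala
// Plain Scala, no Spark: same case class as in the question
case class Player(pk: String, ts: String, first_name: String, date_of_birth: String)

val rows = Seq(
  Player("12121212121212", "10000001", "Rogger", "1980-01-02"),
  Player("12121212121212", "10000002", "Rogg", null),
  Player("12121212121212", "10000004", null, "1985-01-02"),
  Player("12121212121212", "10000003", "Roggelio", "1982-01-02")
)

// Mirrors max_by(field, if(isNotNull(field), ts, null)) for one group:
// rows where the field is null are skipped, then the greatest ts wins
def latestNonNull(group: Seq[Player], field: Player => String): String = {
  val candidates = group.filter(p => field(p) != null)
  if (candidates.isEmpty) null else field(candidates.maxBy(_.ts))
}

// One output row per pk, each field resolved independently
val aggregated: List[Player] =
  rows.groupBy(_.pk).toList.map { case (pk, group) =>
    Player(
      pk,
      latestNonNull(group, _.ts),
      latestNonNull(group, _.first_name),
      latestNonNull(group, _.date_of_birth)
    )
  }

aggregated.foreach(println) // Player(12121212121212,10000004,Roggelio,1985-01-02)
```

Because every field is resolved independently, the output row can mix values from different input rows: here ts and date_of_birth come from the ts 10000004 row, while first_name comes from the ts 10000003 row, since the newer row's first_name is null.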

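On Spark 3.3.0 and later you can also use the built-in max_by function from org.apache.spark.sql.functions instead of a SQL expression string (the SQL max_by used above is available since Spark 3.0.0):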

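import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import spark.implicits._ // for .as[Player]; assumes a SparkSession named spark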

val players: Dataset[Player] = ...

// aggregate all columns except the key individually by ts
// NULLs will be ignored (SQL standard)
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => max_by(col(colName), when(col(colName).isNotNull, col("ts"))).as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]