Scala - divide the dataset into dataset of arrays with a fixed size


I have a function whose purpose is to divide a dataset into arrays of a given size.
For example: I have a dataset with 123 objects of type Foo, and I pass arraysSize = 10 to the function, so the result is a Dataset[Array[Foo]] containing 12 arrays of 10 Foos and 1 array of 3 Foos. Right now the function works on collected data; I would like to make it dataset-based for performance reasons, but I don't know how.
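To illustrate the expected grouping, here is a minimal local sketch using Scala's grouped on plain collections (no Spark; Foo is assumed to be a simple case class):

case class Foo(id: Int)

val foos = (1 to 123).map(Foo(_))    // 123 objects
val groups = foos.grouped(10).toSeq  // chunks of at most 10
groups.map(_.size)                   // 12 groups of 10 plus one group of 3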
This is my current solution:

private def mapToFooArrays(data: Dataset[Foo],
                           arraysSize: Int): Dataset[Array[Foo]] = {
  data.collect().grouped(arraysSize).toSeq.toDS()
}

The reason for doing this transformation is that the data will be sent as events. Instead of sending 1 million events with information about 1 object each, I prefer to send, for example, 10 thousand events with information about 100 objects each.

CodePudding user response:

IMO, this is a weird use case. I cannot think of any efficient way to do this, as it is going to require a lot of shuffling no matter how we approach it.

But the following is still better, as it avoids collecting to the driver node and will thus be more scalable.

Things to keep in mind -

  • what is the value of data.count()?
  • what is the size of a single Foo?
  • what is the value of arraysSize?
  • what is your executor configuration?

Based on these factors you will be able to come up with the desiredArraysPerPartition value.
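For example, using the numbers from the question: data.count() = 1,000,000 and arraysSize = 100 give numArrays = 10,000; with desiredArraysPerPartition = 50, that works out to numPartitions = 200.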


val desiredArraysPerPartition = 50

private def mapToFooArrays(
    data: Dataset[Foo],
    arraysSize: Int
): Dataset[Array[Foo]] = {
  val size = data.count()
  val numArrays = (size.toDouble / arraysSize).ceil
  // repartition takes an Int, so round up and convert
  val numPartitions = (numArrays / desiredArraysPerPartition).ceil.toInt

  // requires an implicit Encoder[Array[Foo]] in scope,
  // e.g. import spark.implicits._ with Foo as a case class
  data
    .repartition(numPartitions)
    .mapPartitions(_.grouped(arraysSize).map(_.toArray))
}
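A hypothetical usage sketch, assuming a SparkSession named spark is in scope and Foo is a case class (so import spark.implicits._ supplies the required encoders, including Encoder[Array[Foo]]):

import spark.implicits._

val foos: Dataset[Foo] = (1 to 123).map(Foo(_)).toDS()
val batches: Dataset[Array[Foo]] = mapToFooArrays(foos, arraysSize = 10)

// inspect the batch sizes; each partition groups independently,
// so there can be one short array per partition tail
batches.map(_.length).collect()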

After reading the edited part, I think the exact figure of 100 in "10 thousand events with information about 100 objects" is not really important, since it is referred to as about 100. There can be more than one event with fewer than 100 Foos.

If we are not very strict about that size of 100, then there is no need to reshuffle.

We can locally group the Foos present in each of the partitions. Since this grouping is done locally rather than globally, it may produce more than one array (potentially one per partition) with fewer than 100 Foos.

private def mapToFooArrays(
    data: Dataset[Foo],
    arraysSize: Int
): Dataset[Array[Foo]] =
  data
    .mapPartitions(_.grouped(arraysSize).map(_.toArray))
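
As a rough illustration of those per-partition remainders (hypothetical numbers; assumes the same spark, import spark.implicits._, and case class Foo as above):

// force a known partition count purely for the demo
val ds = (1 to 123).map(Foo(_)).toDS().repartition(4)

// each of the 4 partitions groups its ~31 rows independently,
// so up to 4 arrays come out shorter than arraysSize
mapToFooArrays(ds, 10).map(_.length).collect()
// e.g. Array(10, 10, 10, 1, 10, 10, 10, 1, ...) depending on row distribution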