Scala: getting the average of the result

Time:10-28

Hi, I am trying to calculate the average running time per genre from this TSV data set:

running time    Genre
1               Documentary,Short
5               Animation,Short
4               Animation,Comedy,Romance

Each comma-separated value is a separate genre: Animation is one genre, and the same goes for Short, Comedy, and Romance.

I'm new to Scala and confused about how to get the average for each genre in Scala without using any mutable state.

I tried the snippet below just to iterate over the rows and collect the running times for each genre:

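// Attempt: fold the rows into a Map keyed by the genres string.
// Note that `map + (key -> value)` replaces any earlier value stored under the same key.
val a = list.foldLeft(Map[String, Int]()) {
  case (map, arr) =>
    map + (arr.genres.toString -> arr.runtimeMinutes)
}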

Is there a way to calculate the average?

Answer:

Assuming the data was already parsed into something like:

final case class Row(runningTime: Int, genres: List[String])
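If the TSV has not been parsed yet, here is a minimal sketch of turning the raw lines into `Row` values. It assumes (the question does not say) that the file starts with a header line, that the two columns are separated by a single tab, and that genres are comma-separated; `TsvParsing`, `parseLine` and `parseAll` are hypothetical names. Malformed lines are silently skipped, and `toIntOption` needs Scala 2.13+.

```scala
// Assumption: same Row shape as in this answer.
final case class Row(runningTime: Int, genres: List[String])

// Hypothetical helper object for illustration only.
object TsvParsing {
  // Parses one data line such as "5<TAB>Animation,Short"; None if the line is malformed.
  def parseLine(line: String): Option[Row] =
    line.split("\t").map(_.trim) match {
      case Array(time, genres) =>
        // toIntOption (Scala 2.13+) returns None instead of throwing on a bad number.
        time.toIntOption.map { t =>
          Row(t, genres.split(",").map(_.trim).filter(_.nonEmpty).toList)
        }
      case _ => None
    }

  // Drops the header line and keeps only the lines that parse successfully.
  def parseAll(lines: List[String]): List[Row] =
    lines.drop(1).flatMap(parseLine)

  def main(args: Array[String]): Unit = {
    // The sample data from the question, assumed to be tab-separated.
    val lines = List(
      "running time\tGenre",
      "1\tDocumentary,Short",
      "5\tAnimation,Short",
      "4\tAnimation,Comedy,Romance"
    )
    parseAll(lines).foreach(println)
  }
}
```

In a real program the lines could come from `scala.io.Source.fromFile(path).getLines().toList`, where `path` is whatever your file is called.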

Then you can follow a declarative approach to compute your desired result.

  1. Flatten a List[Row] into a list of pairs, where the first element is a genre and the second element is a running time.
  2. Collect all running times for the same genre.
  3. Reduce each group to compute its average.
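
// Requires Scala 2.13+ (groupMap and view.mapValues).
def computeAverageRunningTimePerGenre(data: List[Row]): Map[String, Double] =
  data.flatMap {
    // 1. Emit one (genre, runningTime) pair per genre of each row.
    case Row(runningTime, genres) =>
      genres.map(genre => genre -> runningTime)
  }.groupMap(_._1)(_._2).view.mapValues { runningTimes => // 2. group running times by genre
    // 3. Average of each group.
    runningTimes.sum.toDouble / runningTimes.size.toDouble
  }.toMap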

You can see the code running here.
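For comparison with the foldLeft attempt in the question: `map + (key -> value)` overwrites whatever was stored for that key, and `arr.genres.toString` turns the whole genre collection into one key instead of one key per genre. A minimal sketch that keeps the fold but accumulates a (sum, count) pair per individual genre in an immutable Map, assuming the same `Row` shape as above (`FoldAverage` and `averageByGenre` are hypothetical names). With the three sample rows it gives Documentary 1.0, Short (1 + 5) / 2 = 3.0, Animation (5 + 4) / 2 = 4.5, Comedy 4.0 and Romance 4.0.

```scala
// Assumption: same Row shape as in this answer.
final case class Row(runningTime: Int, genres: List[String])

// Hypothetical helper object for illustration only.
object FoldAverage {
  // Single pass with foldLeft: for every genre keep a running (sum, count) pair
  // in an immutable Map, then divide once at the end.
  def averageByGenre(rows: List[Row]): Map[String, Double] = {
    val totals = rows.foldLeft(Map.empty[String, (Int, Int)]) { (acc, row) =>
      // Each row contributes its running time to every one of its genres.
      row.genres.foldLeft(acc) { (m, genre) =>
        val (sum, count) = m.getOrElse(genre, (0, 0))
        m.updated(genre, (sum + row.runningTime, count + 1))
      }
    }
    // Convert each (sum, count) pair into an average.
    totals.map { case (genre, (sum, count)) => genre -> sum.toDouble / count }
  }

  def main(args: Array[String]): Unit = {
    // The sample data from the question.
    val rows = List(
      Row(1, List("Documentary", "Short")),
      Row(5, List("Animation", "Short")),
      Row(4, List("Animation", "Comedy", "Romance"))
    )
    println(averageByGenre(rows))
  }
}
```

Compared with the groupMap version, this avoids building the intermediate list of pairs, at the cost of slightly more manual bookkeeping.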

Answer:

If the data is in a Spark DataFrame, you can do this simply with the built-in Spark SQL functions. First explode each row so that it contains a single genre, then group by genre and calculate the average.

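import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, explode}

// Local session for trying this out; reuse your existing SparkSession if you already have one.
val sparkSession = SparkSession.builder().master("local[*]").appName("average-per-genre").getOrCreate()
import sparkSession.implicits._

// One row per movie: its running time and the list of its genres.
val xs =
  Seq((1, Seq("Documentary", "Short")), (5, Seq("Animation", "Short")), (4, Seq("Animation", "Comedy", "Romance")))
    .toDF("runningTime", "genres")

// explode emits one row per (runningTime, genre) pair; then group by genre and average.
val ys = xs
  .select($"runningTime", explode($"genres").as("genre"))
  .groupBy($"genre")
  .agg(avg($"runningTime").as("averageRunningTime"))

xs.show(false)
ys.show(false)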