Hi, I am trying to calculate the average running time per genre from this TSV data set:
running time    genre
1               Documentary,Short
5               Animation,Short
4               Animation,Comedy,Romance
Animation is one genre, and likewise Short, Comedy, and Romance; a single row can belong to several genres at once.
I'm new to Scala and I'm confused about how to compute the average per genre without using any mutable state. I tried the snippet below to iterate over the rows and collect the runtimes per genre:
val a = list.foldLeft(Map[String, Int]()) {
  case (map, arr) =>
    map + (arr.genres.toString -> arr.runtimeMinutes)
}
Is there any way to calculate the average per genre?
CodePudding user response:
Assuming the data was already parsed into something like:
final case class Row(runningTime: Int, genres: List[String])
Then you can follow a declarative approach to compute your desired result.
- Flatten the List[Row] into a list of pairs, where the first element is a genre and the second element is a running time.
- Collect all running times for the same genre.
- Reduce each group to compute its average.
def computeAverageRunningTimePerGenre(data: List[Row]): Map[String, Double] =
  data.flatMap {
    case Row(runningTime, genres) =>
      genres.map(genre => genre -> runningTime)
  }.groupMap(_._1)(_._2).view.mapValues { runningTimes =>
    runningTimes.sum.toDouble / runningTimes.size.toDouble
  }.toMap
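For example, applied to the question's data it produces one average per genre (Row and the function are repeated here so the snippet runs standalone):

```scala
final case class Row(runningTime: Int, genres: List[String])

def computeAverageRunningTimePerGenre(data: List[Row]): Map[String, Double] =
  data.flatMap { case Row(rt, gs) => gs.map(_ -> rt) }
    .groupMap(_._1)(_._2)
    .view
    .mapValues(ts => ts.sum.toDouble / ts.size)
    .toMap

// The three rows from the question
val sample = List(
  Row(1, List("Documentary", "Short")),
  Row(5, List("Animation", "Short")),
  Row(4, List("Animation", "Comedy", "Romance"))
)

val averages = computeAverageRunningTimePerGenre(sample)
// averages("Documentary") == 1.0, averages("Short") == 3.0,
// averages("Animation") == 4.5, averages("Comedy") == 4.0, averages("Romance") == 4.0
```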
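Since the question specifically tried foldLeft: the same result can also be computed in a single pass by accumulating a (sum, count) pair per genre and dividing at the end. This is only a sketch; `averagesByFold` is a made-up name, and the Row class is repeated so the snippet is self-contained:

```scala
final case class Row(runningTime: Int, genres: List[String])

// Single pass over the data: fold each row's genres into a Map of (sum, count),
// then turn every (sum, count) pair into an average. No mutable state is used.
def averagesByFold(data: List[Row]): Map[String, Double] = {
  val sumsAndCounts = data.foldLeft(Map.empty[String, (Int, Int)]) { (acc, row) =>
    row.genres.foldLeft(acc) { (m, genre) =>
      val (sum, count) = m.getOrElse(genre, (0, 0))
      m.updated(genre, (sum + row.runningTime, count + 1))
    }
  }
  sumsAndCounts.map { case (genre, (sum, count)) => genre -> sum.toDouble / count }
}
```

The key difference from the question's attempt is that the accumulator keeps a running (sum, count) per genre instead of overwriting a single value per key.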
CodePudding user response:
You can do this pretty simply using Spark SQL functions. First explode each row so that it contains a single genre per row, then group by genre and compute the average.
import sparkSession.implicits._
import org.apache.spark.sql.functions._

val xs = Seq(
  (1, Seq("Documentary", "Short")),
  (5, Seq("Animation", "Short")),
  (4, Seq("Animation", "Comedy", "Romance"))
).toDF("runningTime", "genres")

val ys = xs
  .select('runningTime, explode('genres) as "genre")
  .groupBy('genre)
  .agg(avg('runningTime) as "averageRunningTime")

xs.show(false)
ys.show(false)
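For intuition, explode emits one output row per element of the genres array. In plain Scala collections the equivalent step is a flatMap; a sketch with the question's data (no Spark required):

```scala
val rows = Seq(
  (1, Seq("Documentary", "Short")),
  (5, Seq("Animation", "Short")),
  (4, Seq("Animation", "Comedy", "Romance"))
)

// One (runningTime, genre) pair per genre, mirroring explode('genres)
val exploded = rows.flatMap { case (rt, genres) => genres.map(g => (rt, g)) }
// exploded holds 7 pairs: (1,Documentary), (1,Short), (5,Animation),
// (5,Short), (4,Animation), (4,Comedy), (4,Romance)
```

The groupBy/avg step then averages the first element of each pair within each genre group, which is exactly what the aggregation above does.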