Spark Scala: how to group data based on a field with a certain range?


I have the below data frame in Spark Scala:

DF_City

+-------------+-------------+----------+------------+
|         city|     state   |year_month|saleCount   |
+-------------+-------------+----------+------------+
|  Bangalore  |   Karnataka |   2020-01|          10|
|  Bangalore  |   Karnataka |   2020-02|          10|
|  Bangalore  |   Karnataka |   2021-03|          10|
|  Bangalore  |   Karnataka |   2021-04|          10|
|  Bangalore  |   Karnataka |   2021-05|          10|
|  Bangalore  |   Karnataka |   2021-06|          10|
|  Bangalore  |   Karnataka |   2021-07|          10|
|  Bangalore  |   Karnataka |   2021-08|          10|
|  Bangalore  |   Karnataka |   2021-09|          10|
|  Bangalore  |   Karnataka |   2021-10|          10|

|  Chennai    |   Tamil Nadu|   2020-05|          20|
|  Chennai    |   Tamil Nadu|   2020-06|          20|
|  Chennai    |   Tamil Nadu|   2020-07|          20|
|  Chennai    |   Tamil Nadu|   2020-08|          20|
|  Chennai    |   Tamil Nadu|   2020-09|          20|
|  Chennai    |   Tamil Nadu|   2020-10|          20|
|  Chennai    |   Tamil Nadu|   2020-11|          20|
+-------------+-------------+----------+------------+

From the above, I want the output below, using Scala Spark (DataFrame or RDD).

Target DF:

+-------------+-------------+----------+------------+------------------------+
|         city|     state   |year_month|saleCount   | last12MonthsSellCount  |
+-------------+-------------+----------+------------+------------------------+
|  Bangalore  |   Karnataka |   2020-01|          10| 10                     |
|  Bangalore  |   Karnataka |   2020-02|          10| 20                     | 
|  Bangalore  |   Karnataka |   2021-03|          10| 30                     |
|  Bangalore  |   Karnataka |   2021-04|          10| 40                     |
|  Bangalore  |   Karnataka |   2021-05|          10| 50                     |
|  Bangalore  |   Karnataka |   2021-06|          10| 60                     |
|  Bangalore  |   Karnataka |   2021-07|          10| 70                     |
|  Bangalore  |   Karnataka |   2021-08|          10| 80                     |
|  Bangalore  |   Karnataka |   2021-09|          10| 90                     |
|  Bangalore  |   Karnataka |   2021-10|          10|100                     |

|  Chennai    |   Tamil Nadu|   2020-05|          20| 20                     |
|  Chennai    |   Tamil Nadu|   2020-06|          20| 40                     |
|  Chennai    |   Tamil Nadu|   2020-07|          20| 60                     |
|  Chennai    |   Tamil Nadu|   2020-08|          20| 80                     |
|  Chennai    |   Tamil Nadu|   2020-09|          20|100                     |
|  Chennai    |   Tamil Nadu|   2020-10|          20|120                     |
|  Chennai    |   Tamil Nadu|   2020-11|          20|140                     |
+-------------+-------------+----------+------------+------------------------+

It will be a group inside a group, meaning a sub-group has to be created on the field 'year_month' based on a date range. This date range is the main bottleneck; I need help with this part.

I have tried the below, but it is not working:

val cityStateMonthlyYearlycount = lastCityCountDF
  .withColumn("yearMonth", col("year_month"))
  .groupBy(col("city"), col("state"), col("year_month"))
  .agg(
    sum(when(datediff(col("year_month"), col("yearMonth")).leq(12), col("monthlyCount")).otherwise(0))
      .as("lastOneYearCount"))
  .filter(col("lastOneYearCount") === 0)
  .select("city", "state", "year_month", "lastOneYearCount")

What should be the correct approach for this? I want to achieve this in Scala Spark; if it is possible with DataFrames only, that would be much better.

CodePudding user response:

Here is how you can do it with a window function, ordering by year_month so the running sum accumulates chronologically:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val window = Window.partitionBy("city", "state")
  .orderBy("year_month")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val resultDF = df.withColumn("sum", sum("saleCount").over(window))
  .orderBy("city", "sum")

resultDF.show(false)

Output:

+---------+---------+----------+---------+-----+
|city     |state    |year_month|saleCount|sum  |
+---------+---------+----------+---------+-----+
|Bangalore|Karnataka|2020-01   |10       |10.0 |
|Bangalore|Karnataka|2020-02   |10       |20.0 |
|Bangalore|Karnataka|2021-03   |10       |30.0 |
|Bangalore|Karnataka|2021-04   |10       |40.0 |
|Bangalore|Karnataka|2021-05   |10       |50.0 |
|Bangalore|Karnataka|2021-06   |10       |60.0 |
|Bangalore|Karnataka|2021-07   |10       |70.0 |
|Bangalore|Karnataka|2021-08   |10       |80.0 |
|Bangalore|Karnataka|2021-09   |10       |90.0 |
|Bangalore|Karnataka|2021-10   |10       |100.0|
|Chennai  |TamilNadu|2020-05   |20       |20.0 |
|Chennai  |TamilNadu|2020-06   |20       |40.0 |
|Chennai  |TamilNadu|2020-07   |20       |60.0 |
|Chennai  |TamilNadu|2020-08   |20       |80.0 |
|Chennai  |TamilNadu|2020-09   |20       |100.0|
|Chennai  |TamilNadu|2020-10   |20       |120.0|
|Chennai  |TamilNadu|2020-11   |20       |140.0|
+---------+---------+----------+---------+-----+
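
Note that this is a running total over all of a city's rows (which matches the sample output), not a window capped at 12 months. If a true rolling 12-month sum is needed, a minimal DataFrame-only sketch could look like the following, assuming year_month is always formatted as yyyy-MM; monthIdx and last12MonthsSellCount are names introduced here for illustration:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// turn "yyyy-MM" into a month index so rangeBetween can count in months
val withIdx = df.withColumn("monthIdx",
  (substring(col("year_month"), 1, 4).cast("int") - 1970) * 12 +
    substring(col("year_month"), 6, 2).cast("int"))

// the current month plus the 11 preceding ones; adjust the lower bound
// if the window should reach further back
val last12 = Window.partitionBy("city", "state")
  .orderBy("monthIdx")
  .rangeBetween(-11, Window.currentRow)

val rollingDF = withIdx
  .withColumn("last12MonthsSellCount", sum("saleCount").over(last12))
  .drop("monthIdx")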

CodePudding user response:

In the common case RDDs are, as far as I know, slower than DataFrames, but they are more powerful. So it is not that hard to create a custom groupBy over the last 12 months.

koiralo's answer is great, but it does not restrict the sum to the last 12 months.

Data initialization:

val monthlySalesList: Array[Sales] =
    """|Bangalore     |   Karnataka |   2010-01|          10|
       ||  Bangalore  |   Karnataka |   2020-01|          10|
       ||  Bangalore  |   Karnataka |   2020-02|          10|
       ||  Bangalore  |   Karnataka |   2021-03|          10|
       ||  Bangalore  |   Karnataka |   2021-04|          10|
       ||  Bangalore  |   Karnataka |   2021-05|          10|
       ||  Bangalore  |   Karnataka |   2021-06|          10|
       ||  Bangalore  |   Karnataka |   2021-07|          10|
       ||  Bangalore  |   Karnataka |   2021-08|          10|
       ||  Bangalore  |   Karnataka |   2021-09|          10|
       ||  Bangalore  |   Karnataka |   2021-10|          10|
       ||  Chennai    |   Tamil Nadu|   2019-05|          20|
       ||  Chennai    |   Tamil Nadu|   2020-05|          20|
       ||  Chennai    |   Tamil Nadu|   2020-06|          20|
       ||  Chennai    |   Tamil Nadu|   2020-07|          20|
       ||  Chennai    |   Tamil Nadu|   2020-08|          20|
       ||  Chennai    |   Tamil Nadu|   2020-09|          20|
       ||  Chennai    |   Tamil Nadu|   2020-10|          20|
       ||  Chennai    |   Tamil Nadu|   2020-11|          20|""".stripMargin
      .split("\n").map(_.split("\\|").map(_.trim).filter(_.nonEmpty)).map(ar =>
      Sales(ar(0), ar(1), ar(2), ar(3).toInt)
    )

Code:

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

case class Sales(city: String, state: String, yearMonth: String, count: Int) {
  val Array(year, month) = yearMonth.split("-").map(_.toInt)
  // months elapsed since January 1970, so month distances compare as plain Ints
  val monthsSince1970: Int = (year - 1970) * 12 + month - 1
}

case class Place(city: String, state: String)

object So73002153 extends App {
  val spark: SparkSession = SparkSession.builder.config("spark.master", "local").getOrCreate()
  val sc: SparkContext = spark.sparkContext
  val monthlySalesDf: RDD[Sales] = sc.parallelize(monthlySalesList)
  val salesGroupedByPlace: RDD[(Place, Iterable[Sales])] = monthlySalesDf.groupBy(s => Place(s.city, s.state))
  val salesWithSum: RDD[(Place, Array[(String, Int, Int)])] = salesGroupedByPlace.map { case (place, list) =>
    val salesForPlace: Array[Sales] = list.toArray.sortBy(_.monthsSince1970)
    val salesFor12Months: Array[Int] = Array.ofDim(salesForPlace.length)
    var lastFrom = 0
    for (i <- salesForPlace.indices) {
      // find the first index whose month is within 12 months of row i's month
      val firstIndex: Int = salesForPlace.indexWhere(s => s.monthsSince1970 + 12 >= salesForPlace(i).monthsSince1970, lastFrom)
      lastFrom = firstIndex
      var sum = 0
      // this inner sum can be optimized (see the sliding-window sketch after the results)
      for (j <- firstIndex to i) sum += salesForPlace(j).count
      salesFor12Months(i) = sum
    }
    place -> salesForPlace.zip(salesFor12Months).map { case (s, sum) =>
      (s.yearMonth, s.count, sum)
    }
  }
  val resultRdd: RDD[(String, String, String, Int, Int)] = salesWithSum.flatMap { case (place, list) =>
    list.map(el => (place.city, place.state, el._1, el._2, el._3))
  }
  resultRdd.foreach(println)
}

Result (not formatted):

(Bangalore,Karnataka,2010-01,10,10)
(Bangalore,Karnataka,2020-01,10,10)
(Bangalore,Karnataka,2020-02,10,20)
(Bangalore,Karnataka,2021-03,10,10)
(Bangalore,Karnataka,2021-04,10,20)
(Bangalore,Karnataka,2021-05,10,30)
(Bangalore,Karnataka,2021-06,10,40)
(Bangalore,Karnataka,2021-07,10,50)
(Bangalore,Karnataka,2021-08,10,60)
(Bangalore,Karnataka,2021-09,10,70)
(Bangalore,Karnataka,2021-10,10,80)

and

(Chennai,Tamil Nadu,2019-05,20,20)
(Chennai,Tamil Nadu,2020-05,20,40)
(Chennai,Tamil Nadu,2020-06,20,40)
(Chennai,Tamil Nadu,2020-07,20,60)
(Chennai,Tamil Nadu,2020-08,20,80)
(Chennai,Tamil Nadu,2020-09,20,100)
(Chennai,Tamil Nadu,2020-10,20,120)
(Chennai,Tamil Nadu,2020-11,20,140)
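
The nested loop above is quadratic in the worst case. Since salesForPlace is already sorted by monthsSince1970, a two-pointer sliding window can compute the same sums in O(n); rollingSums below is a hypothetical helper, sketched under the same inclusion rule (monthsSince1970 + 12 >= current month):

// hypothetical O(n) replacement for the nested loop above
def rollingSums(sales: Array[Sales]): Array[Int] = {
  val out = Array.ofDim[Int](sales.length)
  var from = 0      // oldest index still inside the window
  var running = 0   // sum of counts over sales(from .. i)
  for (i <- sales.indices) {
    running += sales(i).count
    // drop entries that fell out of the 12-month range
    while (sales(from).monthsSince1970 + 12 < sales(i).monthsSince1970) {
      running -= sales(from).count
      from += 1
    }
    out(i) = running
  }
  out
}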