I have the below DataFrame in Spark (Scala):
DF_City
+---------+----------+----------+---------+
|     city|     state|year_month|saleCount|
+---------+----------+----------+---------+
|Bangalore| Karnataka|   2020-01|       10|
|Bangalore| Karnataka|   2020-02|       10|
|Bangalore| Karnataka|   2021-03|       10|
|Bangalore| Karnataka|   2021-04|       10|
|Bangalore| Karnataka|   2021-05|       10|
|Bangalore| Karnataka|   2021-06|       10|
|Bangalore| Karnataka|   2021-07|       10|
|Bangalore| Karnataka|   2021-08|       10|
|Bangalore| Karnataka|   2021-09|       10|
|Bangalore| Karnataka|   2021-10|       10|
|  Chennai|Tamil Nadu|   2020-05|       20|
|  Chennai|Tamil Nadu|   2020-06|       20|
|  Chennai|Tamil Nadu|   2020-07|       20|
|  Chennai|Tamil Nadu|   2020-08|       20|
|  Chennai|Tamil Nadu|   2020-09|       20|
|  Chennai|Tamil Nadu|   2020-10|       20|
|  Chennai|Tamil Nadu|   2020-11|       20|
+---------+----------+----------+---------+
From the above, I want the following output using Scala Spark (DataFrame or RDD).
Target DF:
+---------+----------+----------+---------+---------------------+
|     city|     state|year_month|saleCount|last12MonthsSellCount|
+---------+----------+----------+---------+---------------------+
|Bangalore| Karnataka|   2020-01|       10|                   10|
|Bangalore| Karnataka|   2020-02|       10|                   20|
|Bangalore| Karnataka|   2021-03|       10|                   30|
|Bangalore| Karnataka|   2021-04|       10|                   40|
|Bangalore| Karnataka|   2021-05|       10|                   50|
|Bangalore| Karnataka|   2021-06|       10|                   60|
|Bangalore| Karnataka|   2021-07|       10|                   70|
|Bangalore| Karnataka|   2021-08|       10|                   80|
|Bangalore| Karnataka|   2021-09|       10|                   90|
|Bangalore| Karnataka|   2021-10|       10|                  100|
|  Chennai|Tamil Nadu|   2020-05|       20|                   20|
|  Chennai|Tamil Nadu|   2020-06|       20|                   40|
|  Chennai|Tamil Nadu|   2020-07|       20|                   60|
|  Chennai|Tamil Nadu|   2020-08|       20|                   80|
|  Chennai|Tamil Nadu|   2020-09|       20|                  100|
|  Chennai|Tamil Nadu|   2020-10|       20|                  120|
|  Chennai|Tamil Nadu|   2020-11|       20|                  140|
+---------+----------+----------+---------+---------------------+
It is a group within a group: inside each (city, state) group, a sub-group has to be formed on the 'year_month' field based on a date range (the trailing 12 months). This date range is the main bottleneck, and I need help with this part.
I have tried the below, but it is not working:
val cityStateMonthlyYearlycount = lastCityCountDF
  .withColumn("yearMonth", col("year_month"))
  .groupBy(col("city"), col("state"), col("year_month"))
  .agg(
    sum(
      when(datediff(col("year_month"), col("yearMonth")).leq(12), col("monthlyCount"))
        .otherwise(0)
    ).as("lastOneYearCount")
  )
  .filter(col("lastOneYearCount") === 0)
  .select("city", "state", "year_month", "lastOneYearCount")
What should be the correct approach for this? I want to achieve this in Scala Spark; if it is possible with the DataFrame API alone, that would be much better.
CodePudding user response:
Here is how you can do it with the Window function:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// running total per (city, state), accumulated in month order
val window = Window.partitionBy("city", "state")
  .orderBy("year_month")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val resultDF = df.withColumn("sum", sum("saleCount").over(window))
  .orderBy("city", "sum")

resultDF.show(false)
Output:
+---------+---------+----------+---------+-----+
|city     |state    |year_month|saleCount|sum  |
+---------+---------+----------+---------+-----+
|Bangalore|Karnataka|2020-01   |10       |10.0 |
|Bangalore|Karnataka|2020-02   |10       |20.0 |
|Bangalore|Karnataka|2021-03   |10       |30.0 |
|Bangalore|Karnataka|2021-04   |10       |40.0 |
|Bangalore|Karnataka|2021-05   |10       |50.0 |
|Bangalore|Karnataka|2021-06   |10       |60.0 |
|Bangalore|Karnataka|2021-07   |10       |70.0 |
|Bangalore|Karnataka|2021-08   |10       |80.0 |
|Bangalore|Karnataka|2021-09   |10       |90.0 |
|Bangalore|Karnataka|2021-10   |10       |100.0|
|Chennai  |TamilNadu|2020-05   |20       |20.0 |
|Chennai  |TamilNadu|2020-06   |20       |40.0 |
|Chennai  |TamilNadu|2020-07   |20       |60.0 |
|Chennai  |TamilNadu|2020-08   |20       |80.0 |
|Chennai  |TamilNadu|2020-09   |20       |100.0|
|Chennai  |TamilNadu|2020-10   |20       |120.0|
|Chennai  |TamilNadu|2020-11   |20       |140.0|
+---------+---------+----------+---------+-----+
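Note that this frame sums every row from the start of the partition, not only the trailing 12 months; it matches the expected output here only because the expected output itself is a plain running total. A minimal sketch of an actual 12-month frame using rangeBetween, assuming year_month is a "yyyy-MM" string (the monthsSinceEpoch helper column and the last12MonthsSellCount name are illustrative, not from the original post):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// convert "yyyy-MM" to a month index so the window can range over months
val withMonths = df.withColumn(
  "monthsSinceEpoch",
  (substring(col("year_month"), 1, 4).cast("int") - 1970) * 12 +
    (substring(col("year_month"), 6, 2).cast("int") - 1)
)

// current month plus the 11 months before it
val last12 = Window.partitionBy("city", "state")
  .orderBy("monthsSinceEpoch")
  .rangeBetween(-11, 0)

val result = withMonths
  .withColumn("last12MonthsSellCount", sum("saleCount").over(last12))
  .drop("monthsSinceEpoch")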
CodePudding user response:
An RDD is generally slower than a DataFrame, as far as I know, but it is more powerful, so it is not hard to create a custom groupBy over the last 12 months.
koiralo's answer is great, but it does not restrict the grouping to the last 12 months.
Data initialization:
val monthlySalesList: Array[Sales] =
"""|Bangalore | Karnataka | 2010-01| 10|
|| Bangalore | Karnataka | 2020-01| 10|
|| Bangalore | Karnataka | 2020-02| 10|
|| Bangalore | Karnataka | 2021-03| 10|
|| Bangalore | Karnataka | 2021-04| 10|
|| Bangalore | Karnataka | 2021-05| 10|
|| Bangalore | Karnataka | 2021-06| 10|
|| Bangalore | Karnataka | 2021-07| 10|
|| Bangalore | Karnataka | 2021-08| 10|
|| Bangalore | Karnataka | 2021-09| 10|
|| Bangalore | Karnataka | 2021-10| 10|
|| Chennai | Tamil Nadu| 2019-05| 20|
|| Chennai | Tamil Nadu| 2020-05| 20|
|| Chennai | Tamil Nadu| 2020-06| 20|
|| Chennai | Tamil Nadu| 2020-07| 20|
|| Chennai | Tamil Nadu| 2020-08| 20|
|| Chennai | Tamil Nadu| 2020-09| 20|
|| Chennai | Tamil Nadu| 2020-10| 20|
|| Chennai | Tamil Nadu| 2020-11| 20|""".stripMargin
.split("\n").map(_.split("\\|").map(_.trim).filter(_.nonEmpty)).map(ar =>
Sales(ar(0), ar(1), ar(2), ar(3).toInt)
)
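For reference, the parsing above splits away both the stripMargin '|' and the column separators, so the first line yields (a quick sanity check against the Sales case class defined below):

assert(monthlySalesList.head == Sales("Bangalore", "Karnataka", "2010-01", 10))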
Code:
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
case class Sales(city: String, state: String, yearMonth: String, count: Int) {
  val Array(year, month) = yearMonth.split("-").map(_.toInt)
  // e.g. "2021-03" -> (2021 - 1970) * 12 + 3 - 1 = 614
  val monthsSince1970: Int = (year - 1970) * 12 + month - 1
}
case class Place(city: String, state: String)
object So73002153 extends App {
  val spark: SparkSession = SparkSession.builder.config("spark.master", "local").getOrCreate()
  val sc: SparkContext = spark.sparkContext

  val monthlySalesRdd: RDD[Sales] = sc.parallelize(monthlySalesList)
  val salesGroupedByPlace: RDD[(Place, Iterable[Sales])] = monthlySalesRdd.groupBy(s => Place(s.city, s.state))

  val salesWithSum: RDD[(Place, Array[(String, Int, Int)])] = salesGroupedByPlace.map { case (place, list) =>
    val salesForPlace: Array[Sales] = list.toArray.sortBy(_.monthsSince1970)
    val salesFor12Months: Array[Int] = Array.ofDim(salesForPlace.length)
    var lastFrom = 0
    for (i <- salesForPlace.indices) {
      // first sale that is at most 12 months older than sale i
      val firstIndex: Int = salesForPlace.indexWhere(s => s.monthsSince1970 + 12 >= salesForPlace(i).monthsSince1970, lastFrom)
      lastFrom = firstIndex
      var sum = 0
      // sum can be optimized, but it's not so easy (see the sketch after this block)
      for (j <- firstIndex to i) sum += salesForPlace(j).count
      salesFor12Months(i) = sum
    }
    place -> salesForPlace.zip(salesFor12Months).map { case (s, sum) =>
      (s.yearMonth, s.count, sum)
    }
  }

  val resultRdd: RDD[(String, String, String, Int, Int)] = salesWithSum.flatMap { case (place, list) =>
    list.map(el => (place.city, place.state, el._1, el._2, el._3))
  }

  resultRdd.foreach(println)
}
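The inner summation loop above makes the whole pass O(n²) in the worst case, since the window is re-summed for every row. A sketch of the optimization hinted at in the comment, keeping a running sum and evicting entries as they age out of the 12-month window (last12MonthSums is a hypothetical helper, not part of the original answer):

def last12MonthSums(sorted: Array[Sales]): Array[Int] = {
  val out = Array.ofDim[Int](sorted.length)
  var from = 0    // index of the oldest sale still inside the window
  var running = 0 // sum of counts in sorted(from .. i)
  for (i <- sorted.indices) {
    running += sorted(i).count
    // evict sales more than 12 months older than sale i
    while (sorted(from).monthsSince1970 + 12 < sorted(i).monthsSince1970) {
      running -= sorted(from).count
      from += 1
    }
    out(i) = running
  }
  out
}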
Result (not formatted):
(Bangalore,Karnataka,2010-01,10,10)
(Bangalore,Karnataka,2020-01,10,10)
(Bangalore,Karnataka,2020-02,10,20)
(Bangalore,Karnataka,2021-03,10,10)
(Bangalore,Karnataka,2021-04,10,20)
(Bangalore,Karnataka,2021-05,10,30)
(Bangalore,Karnataka,2021-06,10,40)
(Bangalore,Karnataka,2021-07,10,50)
(Bangalore,Karnataka,2021-08,10,60)
(Bangalore,Karnataka,2021-09,10,70)
(Bangalore,Karnataka,2021-10,10,80)
and
(Chennai,Tamil Nadu,2019-05,20,20)
(Chennai,Tamil Nadu,2020-05,20,40)
(Chennai,Tamil Nadu,2020-06,20,40)
(Chennai,Tamil Nadu,2020-07,20,60)
(Chennai,Tamil Nadu,2020-08,20,80)
(Chennai,Tamil Nadu,2020-09,20,100)
(Chennai,Tamil Nadu,2020-10,20,120)
(Chennai,Tamil Nadu,2020-11,20,140)