I have a Spark SQL DataFrame containing a unique code, a monthdate, and a turnover figure. For each monthdate I want the rolling sum of turnover over 12 months. For example, if the monthdate is January 2022, the sum covers January 2021 through January 2022. Here is a sample of the dataframe:
+------+------------+----------+
| code | monthdate  | turnover |
+------+------------+----------+
| AA1  | 2021-01-01 | 10       |
| AA1  | 2021-02-01 | 20       |
| AA1  | 2021-03-01 | 30       |
| AA1  | 2021-04-01 | 40       |
| AA1  | 2021-05-01 | 50       |
| AA1  | 2021-06-01 | 60       |
| AA1  | 2021-07-01 | 70       |
| AA1  | 2021-08-01 | 80       |
| AA1  | 2021-09-01 | 90       |
| AA1  | 2021-10-01 | 100      |
| AA1  | 2021-11-01 | 101      |
| AA1  | 2021-12-01 | 102      |
| AA1  | 2022-01-01 | 103      |
| AA1  | 2022-02-01 | 104      |
| AA1  | 2022-03-01 | 105      |
| AA1  | 2022-04-01 | 106      |
| AA1  | 2022-05-01 | 107      |
| AA1  | 2022-06-01 | 108      |
| AA1  | 2022-07-01 | 109      |
| AA1  | 2022-08-01 | 110      |
| AA1  | 2022-09-01 | 111      |
| AA1  | 2022-10-01 | 112      |
| AA1  | 2022-11-01 | 113      |
| AA1  | 2022-12-01 | 114      |
+------+------------+----------+
I'm very new to Spark and Scala, and it's confusing for me to solve this the Spark/Scala way. I have worked out the logic but have difficulty translating it to Spark Scala. I'm running in cluster mode. Here's my logic (pseudocode):
val listkey = df.select("code").distinct.map(r => r(0)).collect()
listkey.foreach(key =>
  df.select("*").filter(s"code == '$key'").orderBy("monthdate").foreach(row =>
    var monthdate = row.monthdate
    var turnover = row.turnover
    var sum = 0
    sum = sum + turnover
    var n = 1
    var i = 1
    while (n < 12) {
      var monthdate_temp = monthdate - i  // i months back
      var turnover_temp =
        df.select("turnover").filter(s"monthdate = '$monthdate_temp' and code = '$key'").collect()
      sum = sum + turnover_temp
      n = n + 1
      i = i + 1
    }
    row = row.withColumn("turnover_per_year", sum)
  )
)
Any help will be appreciated, thanks in advance
CodePudding user response:
Each row in the original dataframe can be expanded to 13 rows (month offsets 0 through 12 back) with the "explode" function, the result joined back to the original dataframe, and then grouped:
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("AA1", "2021-01-01", 25),
  ("AA1", "2022-01-01", 103)
)
  .toDF("code", "monthdate", "turnover")
  .withColumn("monthdate", to_date($"monthdate", "yyyy-MM-dd"))

// Offsets 0, -1, ..., -12: the current month plus the twelve months before it
val oneYearBackMonths = (0 to 12).map(n => lit(-n))

// One row per (original row, offset), with the shifted month materialized
val explodedWithBackMonths = df
  .withColumn("shift", explode(array(oneYearBackMonths: _*)))
  .withColumn("rangeMonth", expr("add_months(monthdate, shift)"))

// Match each shifted month back to an actual turnover row of the same code
val joinCondition = $"exploded.code" === $"original.code" &&
  $"exploded.rangeMonth" === $"original.monthdate"

explodedWithBackMonths.alias("exploded")
  .join(df.alias("original"), joinCondition)
  .groupBy($"exploded.code", $"exploded.monthdate")
  .agg(sum($"original.turnover").alias("oneYearTurnover"))
Result:
+----+----------+---------------+
|code|monthdate |oneYearTurnover|
+----+----------+---------------+
|AA1 |2021-01-01|25             |
|AA1 |2022-01-01|128            |
+----+----------+---------------+
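A variant of the same idea, as a sketch: skip the explode step and join the dataframe to itself with a non-equi condition on a 12-month date range. This assumes the same `df` and column names as above; `add_months` and `between` are standard Spark SQL functions. Note that non-equi joins can be more expensive than the equi-join above, since Spark cannot use a plain hash join for them.

import org.apache.spark.sql.functions._

// For each current row, collect every row of the same code whose monthdate
// falls within the 12 months up to and including the current monthdate
val oneYearSums = df.alias("cur")
  .join(
    df.alias("hist"),
    $"cur.code" === $"hist.code" &&
      $"hist.monthdate".between(add_months($"cur.monthdate", -12), $"cur.monthdate")
  )
  .groupBy($"cur.code", $"cur.monthdate")
  .agg(sum($"hist.turnover").alias("oneYearTurnover"))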
CodePudding user response:
You can use Spark's window functions with a range frame:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val raw = Seq(
  ("AA1", "2019-01-01", 25),
  ("AA1", "2021-01-01", 25),
  ("AA1", "2021-08-01", 80),
  ("AA1", "2021-09-01", 90),
  ("AA1", "2022-01-01", 103),
  ("AA2", "2022-01-01", 10)
).toDF("code", "monthdate", "turnover")

// Note the uppercase MM: lowercase mm would parse minutes, not months
val df = raw.withColumn("monthdate", to_timestamp($"monthdate", "yyyy-MM-dd"))

// Range frame over epoch seconds: the 365 days up to and including the current row
val pw = Window.partitionBy($"code").orderBy($"monthdate".cast("long")).rangeBetween(-(86400 * 365), 0)

df.withColumn("sum", sum($"turnover").over(pw)).show()
+----+-------------------+--------+---+
|code|          monthdate|turnover|sum|
+----+-------------------+--------+---+
| AA1|2019-01-01 00:00:00|      25| 25|
| AA1|2021-01-01 00:00:00|      25| 25|
| AA1|2021-08-01 00:00:00|      80|105|
| AA1|2021-09-01 00:00:00|      90|195|
| AA1|2022-01-01 00:00:00|     103|298|
| AA2|2022-01-01 00:00:00|      10| 10|
+----+-------------------+--------+---+
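One caveat with the frame above: it approximates a year as 365 * 86400 seconds, which drifts slightly across leap years. Since every monthdate in the question is the first of a month, a month-granularity sketch (assuming the same `df` as above, with a hypothetical helper column `monthIdx`) can order by a computed month index and use an exact 13-month-inclusive range:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Month index: months since year 0, so consecutive months differ by exactly 1
val withIdx = df.withColumn("monthIdx", year($"monthdate") * 12 + month($"monthdate"))

// Exactly the current month and the 12 before it, regardless of leap years
val pwMonths = Window.partitionBy($"code").orderBy($"monthIdx").rangeBetween(-12, 0)

withIdx
  .withColumn("sum", sum($"turnover").over(pwMonths))
  .drop("monthIdx")
  .show()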