Creating a new column with values from looping through a SparkSQL Dataframe

I have a SparkSQL dataframe that contains a unique code, a monthdate, and a turnover amount. For each monthdate I want to compute the sum of turnover over the trailing 12 months; for example, if the monthdate is January 2022, the sum covers January 2021 through January 2022. This is a sample of the dataframe:

+------+------------+----------+
| code | monthdate  | turnover |
+------+------------+----------+
| AA1  | 2021-01-01 | 10       |
| AA1  | 2021-02-01 | 20       |
| AA1  | 2021-03-01 | 30       |
| AA1  | 2021-04-01 | 40       |
| AA1  | 2021-05-01 | 50       |
| AA1  | 2021-06-01 | 60       |
| AA1  | 2021-07-01 | 70       |
| AA1  | 2021-08-01 | 80       |
| AA1  | 2021-09-01 | 90       |
| AA1  | 2021-10-01 | 100      |
| AA1  | 2021-11-01 | 101      |
| AA1  | 2021-12-01 | 102      |
| AA1  | 2022-01-01 | 103      |
| AA1  | 2022-02-01 | 104      |
| AA1  | 2022-03-01 | 105      |
| AA1  | 2022-04-01 | 106      |
| AA1  | 2022-05-01 | 107      |
| AA1  | 2022-06-01 | 108      |
| AA1  | 2022-07-01 | 109      |
| AA1  | 2022-08-01 | 110      |
| AA1  | 2022-09-01 | 111      |
| AA1  | 2022-10-01 | 112      |
| AA1  | 2022-11-01 | 113      |
| AA1  | 2022-12-01 | 114      |
+------+------------+----------+

I'm very new to Spark and Scala, and it's confusing for me to solve this the Spark/Scala way. I have worked out the logic but am having difficulties translating it to Spark Scala. I'm running in cluster mode. Here's my logic:

val listkey = df.select("code").distinct.map(r => r(0)).collect()
listkey.foreach(key =>
     df.select("*").filter(s"code == '${key}'").orderBy("monthdate").foreach(
        row =>
        var monthdate = row.monthdate
        var turnover = row.turnover
        var sum = 0
        sum = sum + turnover
        var n = 1
        var i = 1
        while (n < 12) {
               var monthdate_temp = monthdate - i
               var turnover_temp =
               df.select("turnover").filter(s"monthdate = '${monthdate_temp}' and code = '${key}'").collect()
               sum = sum + turnover_temp
               n = n + 1
               i = i + 1
         }
         row = row.withColumn("turnover_per_year", sum)
    )
)

Any help would be appreciated, thanks in advance.

CodePudding user response:

Each row of the original dataframe can be expanded into 13 rows (the current month plus the 12 preceding months) with the "explode" function; the result is then joined back to the original dataframe and grouped:

import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("AA1", "2021-01-01", 25),
  ("AA1", "2022-01-01", 103)
)
  .toDF("code", "monthdate", "turnover")
  .withColumn("monthdate", to_date($"monthdate", "yyyy-MM-dd"))

// Offsets 0..-12: the current month plus the 12 preceding months
val oneYearBackMonths = (0 to 12).map(n => lit(-n))

// One row per (original row, month offset)
val explodedWithBackMonths = df
  .withColumn("shift", explode(array(oneYearBackMonths: _*)))
  .withColumn("rangeMonth", expr("add_months(monthdate, shift)"))

val joinCondition = $"exploded.code" === $"original.code" &&
  $"exploded.rangeMonth" === $"original.monthdate"

// Join each back-dated month to the actual turnover of that month, then sum
explodedWithBackMonths.alias("exploded")
  .join(df.alias("original"), joinCondition)
  .groupBy($"exploded.code", $"exploded.monthdate")
  .agg(sum($"original.turnover").alias("oneYearTurnover"))

Result:

+----+----------+---------------+
|code|monthdate |oneYearTurnover|
+----+----------+---------------+
|AA1 |2021-01-01|25             |
|AA1 |2022-01-01|128            |
+----+----------+---------------+

CodePudding user response:

You can use Spark's window functions:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._

val raw = Seq(
  ("AA1", "2019-01-01", 25),
  ("AA1", "2021-01-01", 25),
  ("AA1", "2021-08-01", 80),
  ("AA1", "2021-09-01", 90),
  ("AA1", "2022-01-01", 103),
  ("AA2", "2022-01-01", 10)
).toDF("code", "monthdate", "turnover")

// Note the upper-case MM: lower-case mm would be parsed as minutes, not months
val df = raw.withColumn("monthdate", to_timestamp($"monthdate", "yyyy-MM-dd"))

// Frame: rows of the same code whose monthdate (in epoch seconds) lies within
// the previous 365 days, up to and including the current row
val pw = Window
  .partitionBy($"code")
  .orderBy($"monthdate".cast("long"))
  .rangeBetween(-(86400L * 365), 0)

df.withColumn("sum", sum($"turnover").over(pw)).show()

+----+-------------------+--------+---+
|code|          monthdate|turnover|sum|
+----+-------------------+--------+---+
| AA1|2019-01-01 00:00:00|      25| 25|
| AA1|2021-01-01 00:00:00|      25| 25|
| AA1|2021-08-01 00:00:00|      80|105|
| AA1|2021-09-01 00:00:00|      90|195|
| AA1|2022-01-01 00:00:00|     103|298|
| AA2|2022-01-01 00:00:00|      10| 10|
+----+-------------------+--------+---+
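
Note that `86400 * 365` is only a day-based approximation of a year, so around leap years the frame can just miss the month exactly 12 months back. Since the data is monthly, one possible variant (a sketch not from the answer above, assuming month-granularity dates) is to order the window by a month index so the frame is exactly the current month plus the 12 preceding calendar months:

// Month index: e.g. 2022-01 -> 2022*12 + 1
val monthIndex = year($"monthdate") * 12 + month($"monthdate")

val byMonth = Window
  .partitionBy($"code")
  .orderBy(monthIndex)
  .rangeBetween(-12, 0)

df.withColumn("sum", sum($"turnover").over(byMonth)).show()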