Create a column in Spark df based on several values in different column


I have a spark df that roughly looks like this:

company ID  quarter     metric
  12       31-12-2019    54.3     
  12       30-09-2019    48.2
  12       30-06-2019    32.3
  12       30-03-2019    54.3
  23       31-12-2018    54.3
  23       30-09-2018    48.2
  23       30-06-2018    32.3
  23       30-03-2018    54.3
  45       31-12-2021    54.3
  45       30-09-2021    48.2
  45       30-06-2021    32.3
  45       30-03-2021    54.3
  45       31-12-2021    54.3
  45       30-09-2020    48.2
  45       30-06-2020    32.3
  45       30-03-2020    54.3
  ..           ..         ..

For each quarter row of each company ID I need to compute an annual value by summing the metric over that quarter and the three that follow it, e.g. for company ID = 45 and quarter = 30-06-2020 the annual value would be:

30-03-2021    54.3
31-12-2020    54.3
30-09-2020    48.2
30-06-2020    32.3
            --------
              189.1

Result:

   company ID  quarter     metric   annual
      12       31-12-2019    54.3     
      12       30-09-2019    48.2
      12       30-06-2019    32.3
      12       30-03-2019    54.3
      23       31-12-2018    54.3
      23       30-09-2018    48.2
      23       30-06-2018    32.3
      23       30-03-2018    54.3
      45       31-12-2021    54.3
      45       30-09-2021    48.2
      45       30-06-2021    32.3
      45       30-03-2021    54.3
      45       31-12-2021    54.3
      45       30-09-2020    48.2
      45       30-06-2020    32.3   **189.1**
      45       30-03-2020    54.3
      ..           ..         ..

In pandas I would probably group by company ID and then try to compute a column based on indices or something like that. What would be the most effective way to do this in Spark/Python?

CodePudding user response:

The date can be converted to days since 1970, and a window function with a 366-day trailing range and "sum" can be used, in Scala:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import java.util.concurrent.TimeUnit

import spark.implicits._ // assumes a SparkSession named "spark" is in scope (e.g. in spark-shell)

val df = Seq(
  (12, "31-12-2019", 54.3),
  (12, "30-09-2019", 48.2),
  (12, "30-06-2019", 32.3),
  (12, "30-03-2019", 54.3),
  (23, "31-12-2018", 54.3),
  (23, "30-09-2018", 48.2),
  (23, "30-06-2018", 32.3),
  (23, "30-03-2018", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2021", 48.2),
  (45, "30-06-2021", 32.3),
  (45, "30-03-2021", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2020", 48.2),
  (45, "30-06-2020", 32.3),
  (45, "30-03-2020", 54.3)
).toDF("company ID", "quarter", "metric")

// trailing window per company: the current row plus all rows within the previous 366 days
val companyIdWindow = Window
  .partitionBy("company ID")
  .orderBy("days")
  .rangeBetween(-366, Window.currentRow)

val secondsInDay = TimeUnit.DAYS.toSeconds(1)

df
  // convert the quarter date to days since 1970-01-01 so a range frame can be applied
  .withColumn("days", unix_timestamp($"quarter", "dd-MM-yyyy") / secondsInDay)
  .withColumn("annual", sum("metric").over(companyIdWindow))
  .drop("days")
  .show(false) // display the result without truncating columns

Result:

+----------+----------+------+------------------+
|company ID|quarter   |metric|annual            |
+----------+----------+------+------------------+
|23        |30-03-2018|54.3  |54.3              |
|23        |30-06-2018|32.3  |86.6              |
|23        |30-09-2018|48.2  |134.8             |
|23        |31-12-2018|54.3  |189.10000000000002|
|45        |30-03-2020|54.3  |54.3              |
|45        |30-06-2020|32.3  |86.6              |
|45        |30-09-2020|48.2  |134.8             |
|45        |30-03-2021|54.3  |189.10000000000002|
|45        |30-06-2021|32.3  |167.10000000000002|
|45        |30-09-2021|48.2  |183.0             |
|45        |31-12-2021|54.3  |243.40000000000003|
|45        |31-12-2021|54.3  |243.40000000000003|
|12        |30-03-2019|54.3  |54.3              |
|12        |30-06-2019|32.3  |86.6              |
|12        |30-09-2019|48.2  |134.8             |
|12        |31-12-2019|54.3  |189.10000000000002|
+----------+----------+------+------------------+
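
Since the question asks for Spark/Python, here is a minimal PySpark sketch of the same window approach (not part of the original answer). It assumes a local SparkSession and, for brevity, recreates only the company 45 rows of the sample data:

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# subset of the sample data (company 45 only); the full dataset works the same way
data = [
    (45, "31-12-2021", 54.3),
    (45, "30-09-2021", 48.2),
    (45, "30-06-2021", 32.3),
    (45, "30-03-2021", 54.3),
    (45, "31-12-2021", 54.3),
    (45, "30-09-2020", 48.2),
    (45, "30-06-2020", 32.3),
    (45, "30-03-2020", 54.3),
]
df = spark.createDataFrame(data, ["company ID", "quarter", "metric"])

seconds_in_day = 24 * 60 * 60

# trailing window per company: the current row plus all rows within the previous 366 days
company_window = (
    Window.partitionBy("company ID")
    .orderBy("days")
    .rangeBetween(-366, Window.currentRow)
)

annual = (
    df
    # convert the quarter date to days since 1970-01-01 so a range frame can be applied
    .withColumn("days", F.unix_timestamp("quarter", "dd-MM-yyyy") / seconds_in_day)
    .withColumn("annual", F.sum("metric").over(company_window))
    .drop("days")
)

annual.show(truncate=False)

Note that this, like the Scala answer, uses a trailing window, so the 189.1 total lands on the last quarter of the year (30-03-2021 in the result above). If the annual value should instead be attached to the earliest quarter of the span and sum the three quarters that follow it, as in the question's worked example, the frame can be flipped to rangeBetween(Window.currentRow, 366).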