I have a Spark df that roughly looks like this:
company ID quarter metric
12 31-12-2019 54.3
12 30-09-2019 48.2
12 30-06-2019 32.3
12 30-03-2019 54.3
23 31-12-2018 54.3
23 30-09-2018 48.2
23 30-06-2018 32.3
23 30-03-2018 54.3
45 31-12-2021 54.3
45 30-09-2021 48.2
45 30-06-2021 32.3
45 30-03-2021 54.3
45 31-12-2021 54.3
45 30-09-2020 48.2
45 30-06-2020 32.3
45 30-03-2020 54.3
.. .. ..
For each quarterly row of each company ID I need to compute an annual value by summing that quarter and the three following quarters, e.g. for company ID = 45 and quarter = 30-06-2020 the annual value would be:
30-03-2021 54.3
31-12-2020 54.3
30-09-2020 48.2
30-06-2020 32.3
--------
189.1
Result:
company ID quarter metric annual
12 31-12-2019 54.3
12 30-09-2019 48.2
12 30-06-2019 32.3
12 30-03-2019 54.3
23 31-12-2018 54.3
23 30-09-2018 48.2
23 30-06-2018 32.3
23 30-03-2018 54.3
45 31-12-2021 54.3
45 30-09-2021 48.2
45 30-06-2021 32.3
45 30-03-2021 54.3
45 31-12-2021 54.3
45 30-09-2020 48.2
45 30-06-2020 32.3 **189.1**
45 30-03-2020 54.3
.. .. ..
In pandas I would probably group by the company ID and then try to compute the column based on indices or something like that. What would be the most efficient way to do this in Spark/Python?
CodePudding user response:
The date can be converted to days since 1970, and a window function with a range of 366 days back combined with "sum" can be used, in Scala:
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// assumes a SparkSession named `spark` (as in spark-shell)
import spark.implicits._

val df = Seq(
  (12, "31-12-2019", 54.3),
  (12, "30-09-2019", 48.2),
  (12, "30-06-2019", 32.3),
  (12, "30-03-2019", 54.3),
  (23, "31-12-2018", 54.3),
  (23, "30-09-2018", 48.2),
  (23, "30-06-2018", 32.3),
  (23, "30-03-2018", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2021", 48.2),
  (45, "30-06-2021", 32.3),
  (45, "30-03-2021", 54.3),
  (45, "31-12-2021", 54.3),
  (45, "30-09-2020", 48.2),
  (45, "30-06-2020", 32.3),
  (45, "30-03-2020", 54.3)
).toDF("company ID", "quarter", "metric")

// per-company window over the current row and the previous 366 days,
// ordered by days since 1970
val companyIdWindow = Window
  .partitionBy("company ID")
  .orderBy("days")
  .rangeBetween(-366, Window.currentRow)

val secondsInDay = TimeUnit.DAYS.toSeconds(1)

df
  // convert the quarter date to days since 1970 so a range frame can be used
  .withColumn("days", unix_timestamp($"quarter", "dd-MM-yyyy") / secondsInDay)
  .withColumn("annual", sum("metric").over(companyIdWindow))
  .drop("days")
Result:
+----------+----------+------+------------------+
|company ID|quarter   |metric|annual            |
+----------+----------+------+------------------+
|23        |30-03-2018|54.3  |54.3              |
|23        |30-06-2018|32.3  |86.6              |
|23        |30-09-2018|48.2  |134.8             |
|23        |31-12-2018|54.3  |189.10000000000002|
|45        |30-03-2020|54.3  |54.3              |
|45        |30-06-2020|32.3  |86.6              |
|45        |30-09-2020|48.2  |134.8             |
|45        |30-03-2021|54.3  |189.10000000000002|
|45        |30-06-2021|32.3  |167.10000000000002|
|45        |30-09-2021|48.2  |183.0             |
|45        |31-12-2021|54.3  |243.40000000000003|
|45        |31-12-2021|54.3  |243.40000000000003|
|12        |30-03-2019|54.3  |54.3              |
|12        |30-06-2019|32.3  |86.6              |
|12        |30-09-2019|48.2  |134.8             |
|12        |31-12-2019|54.3  |189.10000000000002|
+----------+----------+------+------------------+
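Since the question asks about Spark/Python, here is a minimal, untested PySpark sketch of the same idea. It assumes an existing SparkSession named `spark`, uses only a subset of the sample rows from the question, and shows both frame directions: the trailing frame used in the Scala answer, and a forward-looking frame that attributes the sum of the current and next three quarters to the current row, as in the question's example.

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        (12, "31-12-2019", 54.3),
        (12, "30-09-2019", 48.2),
        (12, "30-06-2019", 32.3),
        (12, "30-03-2019", 54.3),
    ],
    ["company ID", "quarter", "metric"],
)

seconds_in_day = 24 * 60 * 60

# days since 1970 as an integer, so it can drive a range window frame
df = df.withColumn(
    "days", (F.unix_timestamp("quarter", "dd-MM-yyyy") / seconds_in_day).cast("long")
)

# trailing-year frame, as in the Scala answer: previous 366 days up to the current row
trailing = (
    Window.partitionBy("company ID").orderBy("days").rangeBetween(-366, Window.currentRow)
)

# forward-looking frame: current row plus the following 366 days,
# which puts the annual value on the first quarter of the year as in the question
following = (
    Window.partitionBy("company ID").orderBy("days").rangeBetween(Window.currentRow, 366)
)

result = (
    df.withColumn("annual_trailing", F.sum("metric").over(trailing))
      .withColumn("annual_following", F.sum("metric").over(following))
      .drop("days")
)
result.show(truncate=False)

With the forward-looking frame, the row for 30-03-2019 picks up 30-06-2019, 30-09-2019 and 31-12-2019 and ends up with the 189.1 from the question, provided those rows exist for the company.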