I have a crimes dataset and need to plot a monthly time-series line chart of all crimes over the last 3 years (2019, 2020, 2021). My approach is to create a new DataFrame where the count for each month is the total count of incidents from 2019-2021, and then plot that DataFrame.
An example would be
I know I need to use an aggregate sum function, but I'm new to this and unsure how to approach it.
Please respond if you need more information!
I tried chaining multiple filters, as in df.filter(...), together with .agg(), but I still can't get the syntax or the approach right.
Answer:
Because of Spark's distributed architecture, a dataset's rows are split across partitions on different worker nodes. Operations where the computation for one row depends on the output for the previous row are therefore trickier in Spark.
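First, partition the data by a grouping column. Your data has no such group, so introduce a dummy key with the same constant value for every row. Then partition by this key and order by the required fields: first by "year", then by "month". Because every row shares the key, all rows end up in a single partition; that is fine for a few dozen monthly rows, but it would not scale to a large raw table. Finally, sum over this window to get a running total: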
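import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data=[[2021, 12, 50], [2021, 11, 50], [2022, 2, 50], [2022, 1, 50], [2022, 10, 50]], schema=["year", "month", "count"])
# Constant key so that every row lands in the same window partition
df = df.withColumn("dummy_key", F.lit("0"))
# Order chronologically: by year first, then by month
w = Window.partitionBy("dummy_key").orderBy("year", "month")
# Running sum of "count" from the first row up to the current row
df = df.withColumn("running_total", F.sum("count").over(w))
df.show()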
[Out]:
+----+-----+-----+---------+-------------+
|year|month|count|dummy_key|running_total|
+----+-----+-----+---------+-------------+
|2021|   11|   50|        0|           50|
|2021|   12|   50|        0|          100|
|2022|    1|   50|        0|          150|
|2022|    2|   50|        0|          200|
|2022|   10|   50|        0|          250|
+----+-----+-----+---------+-------------+