I have a DataFrame containing billions of records, and I want to combine identical rows into a single row based on their effective start and end dates:
key1 | key2 | start | end |
---|---|---|---|
k11 | k2 | 2000-01-01 | 2000-02-01 |
k11 | k2 | 2000-02-01 | 2000-03-01 |
k11 | k2 | 2000-03-01 | 2000-04-01 |
k11 | k2 | 2000-04-01 | 2000-05-01 |
k11 | k2 | 2000-05-01 | 2000-06-01 |
k11 | k2 | 2000-08-01 | 2000-09-01 |
k11 | k2 | 2000-09-01 | 2000-10-01 |
k22 | k2 | 2000-01-01 | 2000-02-01 |
k22 | k2 | 2000-02-01 | 2000-03-01 |
k22 | k3 | 2000-03-01 | 2000-04-01 |
k22 | k3 | 2000-04-01 | 2000-05-01 |
k22 | k3 | 2000-05-01 | 2000-06-01 |
If you group by key1/key2 and sort by start, you can see there are three groups:
- k11/k2,
- k22/k2,
- k22/k3.
Within a group, if the previous row's end equals the next row's start, the two rows can be combined; otherwise they cannot.
The expected output is
key1 | key2 | start | end |
---|---|---|---|
k11 | k2 | 2000-01-01 | 2000-06-01 |
k11 | k2 | 2000-08-01 | 2000-10-01 |
k22 | k2 | 2000-01-01 | 2000-03-01 |
k22 | k3 | 2000-03-01 | 2000-06-01 |
How do I achieve this? Thanks in advance.
CodePudding user response:
The logic is:
- Append the "end" column shifted by one record. Since Spark has a distributed architecture, there is no notion of a record's position or index, so this is done with a Window function partitioned by the keys and ordered by "start".
- Compute the difference in days between "start" and the previous record's "end".
- Consecutive records belong to the same run as long as this difference is zero. To identify the runs, flag every record that starts a new one (the first record of a partition, or any record after a gap) and take the cumulative sum of that flag as a group id.
- Finally, group by the keys and that group id: "start" is the min of the group and "end" is the max.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

df = spark.createDataFrame(data=[["k11","k2","2000-01-01","2000-02-01"],["k11","k2","2000-02-01","2000-03-01"],["k11","k2","2000-03-01","2000-04-01"],["k11","k2","2000-04-01","2000-05-01"],["k11","k2","2000-05-01","2000-06-01"],["k11","k2","2000-08-01","2000-09-01"],["k11","k2","2000-09-01","2000-10-01"],["k22","k2","2000-01-01","2000-02-01"],["k22","k2","2000-02-01","2000-03-01"],["k22","k3","2000-03-01","2000-04-01"],["k22","k3","2000-04-01","2000-05-01"],["k22","k3","2000-05-01","2000-06-01"]], schema=["key1","key2","start","end"])
w = Window.partitionBy("key1", "key2").orderBy("start")

# Cast the string columns to dates so datediff works on real date values.
df = df.withColumn("start", F.to_date("start", format="yyyy-MM-dd")).withColumn("end", F.to_date("end", format="yyyy-MM-dd"))
# Previous record's "end" within each key1/key2 partition, ordered by "start".
df = df.withColumn("prev_end", F.lag("end", offset=1).over(w))
# Gap (in days) between this record's "start" and the previous record's "end".
df = df.withColumn("date_diff", F.datediff(F.col("start"), F.col("prev_end")))
# 1 marks the start of a new run: first record in the partition (null diff) or a gap (diff > 0).
df = df.withColumn("is_new_run", F.when(F.col("date_diff").isNull() | (F.col("date_diff") > 0), F.lit(1)).otherwise(F.lit(0)))
# The cumulative sum of that flag assigns a group id to each contiguous run.
df = df.withColumn("cumsum", F.sum(F.col("is_new_run")).over(w))
# Collapse each run: earliest "start", latest "end".
df = df.groupBy("key1", "key2", "cumsum").agg(F.min("start").alias("start"), F.max("end").alias("end")).drop("cumsum")
df.show(truncate=False)
[Out]:
+----+----+----------+----------+
|key1|key2|start     |end       |
+----+----+----------+----------+
|k11 |k2  |2000-01-01|2000-06-01|
|k11 |k2  |2000-08-01|2000-10-01|
|k22 |k2  |2000-01-01|2000-03-01|
|k22 |k3  |2000-03-01|2000-06-01|
+----+----+----------+----------+
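If SQL reads more naturally to you, the same gaps-and-islands grouping can be sketched with Spark SQL window functions. This is only a sketch under a couple of assumptions: df_raw stands for the original, unmerged input DataFrame (a name I am introducing here), and the view name "intervals" and the alias "grp" are likewise mine, not from the question.

# Register the raw input rows (before the merging steps above) as a SQL view.
df_raw.createOrReplaceTempView("intervals")
spark.sql("""
    SELECT key1, key2, MIN(`start`) AS `start`, MAX(`end`) AS `end`
    FROM (
        SELECT *,
               -- running count of run starts = group id per contiguous run
               SUM(CASE WHEN prev_end IS NULL OR prev_end <> `start` THEN 1 ELSE 0 END)
                   OVER (PARTITION BY key1, key2 ORDER BY `start`) AS grp
        FROM (
            SELECT *,
                   LAG(`end`) OVER (PARTITION BY key1, key2 ORDER BY `start`) AS prev_end
            FROM intervals
        ) lagged
    ) flagged
    GROUP BY key1, key2, grp
""").show(truncate=False)

Because the dates are in yyyy-MM-dd format, the string comparison prev_end <> `start` is equivalent to comparing the dates themselves, so no cast is needed in this variant.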