I have a dataframe consisting of 2 columns, start and end, both of timestamp type, like this:
start | end |
---|---|
2019-07-01 10:01:19.000 | 2019-07-01 10:11:00.000 |
2019-07-01 10:10:05.000 | 2019-07-01 10:40:00.000 |
2019-07-01 10:35:00.000 | 2019-07-01 12:30:00.000 |
2019-07-01 15:20:00.000 | 2019-07-01 15:50:00.000 |
2019-07-01 16:10:00.000 | 2019-07-01 16:35:00.000 |
2019-07-01 16:30:00.000 | 2019-07-01 17:00:00.000 |
I want to add a new column called group such that if 2 date ranges intersect, they should be in the same group.
So result should be:
start | end | group |
---|---|---|
2019-07-01 10:01:19.000 | 2019-07-01 10:11:00.000 | 1 |
2019-07-01 10:10:05.000 | 2019-07-01 10:40:00.000 | 1 |
2019-07-01 10:35:00.000 | 2019-07-01 12:30:00.000 | 1 |
2019-07-01 15:20:00.000 | 2019-07-01 15:50:00.000 | 2 |
2019-07-01 16:10:00.000 | 2019-07-01 16:35:00.000 | 3 |
2019-07-01 16:30:00.000 | 2019-07-01 17:00:00.000 | 3 |
I wasn't able to determine whether 2 date ranges intersect, and the rows are not in any particular order.
Any help or hint would be appreciated.
CodePudding user response:
You can achieve this using a window function in Spark. It lets you sort the data and look back at earlier rows to check for an intersection. I've added comments in the code to explain each step:
// imports needed for the window and column functions below
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// create the dataframe from your example:
val df = Seq(
  ("2019-07-01 10:01:19.000", "2019-07-01 10:11:00.000"),
  ("2019-07-01 10:10:05.000", "2019-07-01 10:40:00.000"),
  ("2019-07-01 10:35:00.000", "2019-07-01 12:30:00.000"),
  ("2019-07-01 15:20:00.000", "2019-07-01 15:50:00.000"),
  ("2019-07-01 16:10:00.000", "2019-07-01 16:35:00.000"),
  ("2019-07-01 16:30:00.000", "2019-07-01 17:00:00.000")
).toDF("start", "end")
// a window that sorts rows by start ascending, then end ascending, so each row
// can look back at the intervals that came before it and check for an intersection
val w = Window.orderBy("start", "end")
// cast the columns from string type to timestamp type
df.select(to_timestamp(col("start")).as("start"), to_timestamp(col("end")).as("end"))
  // max_prev_end holds the latest end seen on any earlier row; using a running
  // max (rather than only the previous row's end) also handles intervals that
  // are fully contained inside an earlier, longer one
  .withColumn("max_prev_end", max("end").over(w.rowsBetween(Window.unboundedPreceding, -1)))
  // new_group is 1 when the current interval starts after every end seen so
  // far (no intersection), 0 otherwise
  .withColumn("new_group", when(col("max_prev_end").isNull.or(col("start").gt(col("max_prev_end"))), 1).otherwise(0))
  // the key element of this solution: a prefix sum of new_group over the
  // window produces the 1-based group id
  .withColumn("group", sum("new_group").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)))
  .drop("max_prev_end", "new_group")
  .show(false)
+-------------------+-------------------+-----+
|start              |end                |group|
+-------------------+-------------------+-----+
|2019-07-01 10:01:19|2019-07-01 10:11:00|1    |
|2019-07-01 10:10:05|2019-07-01 10:40:00|1    |
|2019-07-01 10:35:00|2019-07-01 12:30:00|1    |
|2019-07-01 15:20:00|2019-07-01 15:50:00|2    |
|2019-07-01 16:10:00|2019-07-01 16:35:00|3    |
|2019-07-01 16:30:00|2019-07-01 17:00:00|3    |
+-------------------+-------------------+-----+
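If it helps to see the idea outside Spark, here is a minimal plain-Python sketch of the same algorithm (the `assign_groups` helper is just an illustrative name, not part of any library): sort by start, keep a running maximum of the ends seen so far, and open a new group whenever an interval starts after that maximum.

```python
def assign_groups(intervals):
    """Assign a 1-based group id to each (start, end) interval; intervals
    that intersect (directly or through a chain) share a group."""
    result = []
    group = 0
    max_end = None  # latest end seen so far
    for start, end in sorted(intervals):
        # open a new group when this interval starts after every end seen so far
        if max_end is None or start > max_end:
            group += 1
        max_end = end if max_end is None else max(max_end, end)
        result.append((start, end, group))
    return result

# ISO-8601 timestamp strings compare correctly as plain strings
rows = [
    ("2019-07-01 10:01:19", "2019-07-01 10:11:00"),
    ("2019-07-01 10:10:05", "2019-07-01 10:40:00"),
    ("2019-07-01 10:35:00", "2019-07-01 12:30:00"),
    ("2019-07-01 15:20:00", "2019-07-01 15:50:00"),
    ("2019-07-01 16:10:00", "2019-07-01 16:35:00"),
    ("2019-07-01 16:30:00", "2019-07-01 17:00:00"),
]
for start, end, g in assign_groups(rows):
    print(start, end, g)
```

This is exactly what the window version does: the running `max_end` plays the role of the `max(...).over(...)` column, and the `group += 1` increments are the prefix sum.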
I hope this helps.