Apache Spark Scala: how to detect date intersections and assign groups


I have a dataframe that consists of two columns, start and end, both of date type, like this:

start                   end
2019-07-01 10:01:19.000 2019-07-01 10:11:00.000
2019-07-01 10:10:05.000 2019-07-01 10:40:00.000
2019-07-01 10:35:00.000 2019-07-01 12:30:00.000
2019-07-01 15:20:00.000 2019-07-01 15:50:00.000
2019-07-01 16:10:00.000 2019-07-01 16:35:00.000
2019-07-01 16:30:00.000 2019-07-01 17:00:00.000

I want to add a new column called group such that if two date ranges intersect, they are in the same group.

So the result should be:

start                   end                     group
2019-07-01 10:01:19.000 2019-07-01 10:11:00.000 1
2019-07-01 10:10:05.000 2019-07-01 10:40:00.000 1
2019-07-01 10:35:00.000 2019-07-01 12:30:00.000 1
2019-07-01 15:20:00.000 2019-07-01 15:50:00.000 2
2019-07-01 16:10:00.000 2019-07-01 16:35:00.000 3
2019-07-01 16:30:00.000 2019-07-01 17:00:00.000 3

I wasn't able to determine whether two date ranges intersect, and the rows are in no particular order.

Any help or hint would be appreciated.

CodePudding user response:

You can achieve this with a window function in Spark: it lets you sort the rows and read the previous row's end value, which does the trick. Once the rows are ordered by start, two consecutive ranges intersect exactly when the previous end is greater than or equal to the current start. I added comments to the code to explain each step:

// create the dataframe from your example:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("2019-07-01 10:01:19.000", "2019-07-01 10:11:00.000"),
  ("2019-07-01 10:10:05.000", "2019-07-01 10:40:00.000"),
  ("2019-07-01 10:35:00.000", "2019-07-01 12:30:00.000"),
  ("2019-07-01 15:20:00.000", "2019-07-01 15:50:00.000"),
  ("2019-07-01 16:10:00.000", "2019-07-01 16:35:00.000"),
  ("2019-07-01 16:30:00.000", "2019-07-01 17:00:00.000")
).toDF("start", "end")

// A window that orders the rows by start then end ascending, so each row can
// read the end of the previous row and check for an intersection
val w = Window.orderBy("start", "end")

// convert the columns from string type to timestamp type
df.select(to_timestamp(col("start")).as("start"), to_timestamp(col("end")).as("end"))
  // prev_end contains the value of the end column of the previous row
  .withColumn("prev_end", lag("end", 1, null).over(w))
  // intersection is 0 when the row intersects the previous one and 1 otherwise;
  // since rows are sorted by start, prev_end >= start is enough to detect an
  // overlap (also requiring prev_end <= end would wrongly split off a range
  // that is fully contained in the previous one)
  .withColumn("intersection", when(col("prev_end").isNull.or(col("prev_end").geq(col("start"))), 0).otherwise(1))
  // the key element of this solution: a prefix sum over the window, so each
  // row carries the number of group boundaries seen so far
  .withColumn("group", sum("intersection").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)))
  .drop("prev_end", "intersection")
  .show(false)

+-------------------+-------------------+-----+
|start              |end                |group|
+-------------------+-------------------+-----+
|2019-07-01 10:01:19|2019-07-01 10:11:00|0    |
|2019-07-01 10:10:05|2019-07-01 10:40:00|0    |
|2019-07-01 10:35:00|2019-07-01 12:30:00|0    |
|2019-07-01 15:20:00|2019-07-01 15:50:00|1    |
|2019-07-01 16:10:00|2019-07-01 16:35:00|2    |
|2019-07-01 16:30:00|2019-07-01 17:00:00|2    |
+-------------------+-------------------+-----+
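
Note that the group numbers produced here start at 0. If you want them to start at 1, as in the expected output in your question, a small tweak is to shift the running sum by one (this reuses the same window w and intersection column from the snippet above):

// replace the "group" line above with this one to get groups numbered 1, 2, 3, ...
.withColumn("group", sum("intersection").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)) + lit(1))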

I hope this helps.
