Using lag function in Spark Scala to bring values from another column

Time:09-02

I have a DataFrame like the one below, except that the real data has many different values in the "person" column.

val df_beginning = Seq(("2022-06-06", "person1", 1),
             ("2022-06-13", "person1", 1),
             ("2022-06-20", "person1", 1),
             ("2022-06-27", "person1", 0),
             ("2022-07-04", "person1", 0),
             ("2022-07-11", "person1", 1),
             ("2022-07-18", "person1", 1),
             ("2022-07-25", "person1", 0),
             ("2022-08-01", "person1", 0),
             ("2022-08-08", "person1", 1),
             ("2022-08-15", "person1", 1),
             ("2022-08-22", "person1", 1),
             ("2022-08-29", "person1", 1))
.toDF("week", "person", "person_active_flag")
.orderBy($"week")

I want to add a new column holding the week in which the current chain of consecutive rows with person_active_flag = 1 started. For rows where the flag is 0, the column should be "0". The result should look like this:

val df_expected = Seq(("2022-06-06", "person1", 1, "2022-06-06"),
             ("2022-06-13", "person1", 1, "2022-06-06"),
             ("2022-06-20", "person1", 1, "2022-06-06"),
             ("2022-06-27", "person1", 0, "0"),
             ("2022-07-04", "person1", 0, "0"),
             ("2022-07-11", "person1", 1, "2022-07-11"),
             ("2022-07-18", "person1", 1, "2022-07-11"),
             ("2022-07-25", "person1", 0, "0"),
             ("2022-08-01", "person1", 0, "0"),
             ("2022-08-08", "person1", 1, "2022-08-08"),
             ("2022-08-15", "person1", 1, "2022-08-08"),
             ("2022-08-22", "person1", 1, "2022-08-08"),
             ("2022-08-29", "person1", 1, "2022-08-08"))
.toDF("week", "person", "person_active_flag", "chain_beginning")
.orderBy($"week")

I have not been able to do it, though. I have tried some variations of the code below, but none of them gives the right answer. Can someone show me how to do this, please?

val w = Window.partitionBy($"person").orderBy($"week".asc)

df_beginning
.withColumn("beginning_chain", 
    when($"person_active_flag" === 1 && (lag($"person_active_flag", 1).over(w) === 0 || lag($"person_active_flag", 1).over(w).isNull), 1).otherwise(0)
)

.withColumn("first_week", when($"beginning_chain" === 1, $"week"))

.withColumn("beginning_chain_week", 
    when($"person_active_flag" === 1 && lag($"person_active_flag", 1).over(w).isNull, $"first_week")
   .when($"person_active_flag" === 1 && lag($"person_active_flag", 1).over(w) === 0, $"first_week")
   .when($"person_active_flag" === 1 && lag($"person_active_flag", 1).over(w) === 1, lag($"first_week", 1).over(w))
//    .when($"person_active_flag" === 1 && lag($"person_active_flag", 1).over(w) === 1, "test")
   .otherwise(0)
)
.d

CodePudding user response:

  • Use the lag function to add a helper column, switch_flag, that shows when the flag changed compared with the previous week
  • Then set week_beginning only on rows where the flag switched from 0 to 1
  • Finally, use last(col, ignoreNulls = true) to extend week_beginning to every row where the person is active

Final query:

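import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val window = Window.partitionBy($"person").orderBy($"week")
df_beginning
  // 1 where the flag switched from 0 (or no previous row) to 1, -1 for 1 to 0, else 0
  .withColumn("switch_flag", $"person_active_flag" - coalesce(lag($"person_active_flag", 1).over(window), lit(0)))
  // Keep the week only on rows where a chain starts
  .withColumn("week_beginning", when($"switch_flag" === 1, $"week"))
  // Carry the latest chain start forward over active rows; inactive rows stay null
  .withColumn("week_beginning", when($"person_active_flag" === 1, last($"week_beginning", true).over(window)))
  .show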
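
Note that the query above leaves inactive weeks as null, whereas the question asks for the string "0" on those rows. A minimal self-contained sketch of the same lag/last approach with that default filled in might look like the following. The object name ChainBeginning, the helper withChainBeginning, the intermediate columns prev_flag and chain_start, and the local SparkSession are assumptions made for illustration; they are not part of the original answer.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object ChainBeginning {
  // Hypothetical helper: expects columns week, person, person_active_flag.
  def withChainBeginning(df: DataFrame): DataFrame = {
    val w = Window.partitionBy(col("person")).orderBy(col("week"))
    df
      // Previous week's flag; the first row of each person is treated as "was inactive"
      .withColumn("prev_flag", coalesce(lag(col("person_active_flag"), 1).over(w), lit(0)))
      // The week is kept only on rows where a chain starts (0 -> 1)
      .withColumn("chain_start",
        when(col("person_active_flag") === 1 && col("prev_flag") === 0, col("week")))
      // Active rows take the most recent non-null chain start; inactive rows get "0"
      .withColumn("chain_beginning",
        when(col("person_active_flag") === 1, last(col("chain_start"), ignoreNulls = true).over(w))
          .otherwise(lit("0")))
      .drop("prev_flag", "chain_start")
  }

  def main(args: Array[String]): Unit = {
    // Local session for the sketch only; in spark-shell or a notebook `spark` already exists
    val spark = SparkSession.builder().master("local[*]").appName("chain-beginning").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("2022-06-06", "person1", 1), ("2022-06-13", "person1", 1),
      ("2022-06-20", "person1", 1), ("2022-06-27", "person1", 0),
      ("2022-07-04", "person1", 0), ("2022-07-11", "person1", 1),
      ("2022-07-18", "person1", 1), ("2022-07-25", "person1", 0)
    ).toDF("week", "person", "person_active_flag")

    withChainBeginning(df).orderBy($"person", $"week").show(false)
    spark.stop()
  }
}
```

Because the window is partitioned by person, each person's chains are tracked independently, which matters once the data contains more than one person.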

CodePudding user response:

Or using cumulative sums (chain these calls onto df_beginning):

.withColumn("sum", expr("sum(person_active_flag) over (partition by person order by week)"))
.withColumn("sum2", when(col("person_active_flag").equalTo(1), 0).otherwise(col("sum")))
.withColumn("sum3", expr("sum(sum2) over (partition by person order by week)"))
.withColumn("sum4", when(col("person_active_flag").equalTo(0), -1).otherwise(col("sum3")))

.withColumn("date", expr("first_value(week) over (partition by person,sum4 order by week)"))
.withColumn("date", when(col("person_active_flag").equalTo(0), 0).otherwise(col("date")))

Note: I left the intermediate columns in place so you can follow each step of the process. The intermediate results:

+----------+-------+------------------+---+----+----+----+----------+
|week      |person |person_active_flag|sum|sum2|sum3|sum4|date      |
+----------+-------+------------------+---+----+----+----+----------+
|2022-06-06|person1|1                 |1  |0   |0   |0   |2022-06-06|
|2022-06-13|person1|1                 |2  |0   |0   |0   |2022-06-06|
|2022-06-20|person1|1                 |3  |0   |0   |0   |2022-06-06|
|2022-06-27|person1|0                 |3  |3   |3   |-1  |0         |
|2022-07-04|person1|0                 |3  |3   |6   |-1  |0         |
|2022-07-11|person1|1                 |4  |0   |6   |6   |2022-07-11|
|2022-07-18|person1|1                 |5  |0   |6   |6   |2022-07-11|
|2022-07-25|person1|0                 |5  |5   |11  |-1  |0         |
|2022-08-01|person1|0                 |5  |5   |16  |-1  |0         |
|2022-08-08|person1|1                 |6  |0   |16  |16  |2022-08-08|
|2022-08-15|person1|1                 |7  |0   |16  |16  |2022-08-08|
|2022-08-22|person1|1                 |8  |0   |16  |16  |2022-08-08|
|2022-08-29|person1|1                 |9  |0   |16  |16  |2022-08-08|
+----------+-------+------------------+---+----+----+----+----------+
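
The same idea of labelling each chain with a running total can probably be expressed with a single intermediate column. The sketch below is a variation, not the answer's exact method: it counts the inactive weeks seen so far, so every active row within one chain shares the same count, and then takes the earliest active week per group. The names ChainBeginningByIsland, withChainBeginning and island_id are hypothetical, and the local SparkSession is assumed only to make the example self-contained.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object ChainBeginningByIsland {
  // Hypothetical helper: expects columns week, person, person_active_flag.
  def withChainBeginning(df: DataFrame): DataFrame = {
    val ordered = Window.partitionBy(col("person")).orderBy(col("week"))
    val island  = Window.partitionBy(col("person"), col("island_id"))
    df
      // Running count of inactive weeks; it stays constant across a run of active weeks
      .withColumn("island_id",
        sum(when(col("person_active_flag") === 0, 1).otherwise(0)).over(ordered))
      // Earliest active week within the group; min skips the nulls from inactive rows
      .withColumn("chain_beginning",
        when(col("person_active_flag") === 1,
          min(when(col("person_active_flag") === 1, col("week"))).over(island))
          .otherwise(lit("0")))
      .drop("island_id")
  }

  def main(args: Array[String]): Unit = {
    // Local session for the sketch only; in spark-shell or a notebook `spark` already exists
    val spark = SparkSession.builder().master("local[*]").appName("chain-by-island").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("2022-06-06", "person1", 1), ("2022-06-13", "person1", 1),
      ("2022-06-27", "person1", 0), ("2022-07-04", "person1", 0),
      ("2022-07-11", "person1", 1), ("2022-07-18", "person1", 1)
    ).toDF("week", "person", "person_active_flag")

    withChainBeginning(df).orderBy($"person", $"week").show(false)
    spark.stop()
  }
}
```

An inactive week shares its island_id with the chain that follows it, which is why the minimum is taken over active rows only.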

Good luck!
