operating on a row in a pyspark column based on previous row


I have a PySpark dataframe where I want to update the drift_MS column based on the value of another column, _MS. However, the math I apply differs depending on the value of _MS.

Dataframe:

|SEQ_ID |TIME_STAMP             |_MS               |
|-------|-----------------------|------------------|
|3879826|2021-07-29 11:24:20.525|NaN               |
|3879826|2021-07-29 11:25:56.934|21.262409581399556|
|3879826|2021-07-29 11:27:43.264|27.247600203353613|
|3879826|2021-07-29 11:29:27.613|18.13528511851038 |
|3879826|2021-07-29 11:31:10.512|2.520896614376871 |
|3879826|2021-07-29 11:32:54.252|2.7081931585605541|
|3879826|2021-07-29 11:34:36.995|2.9832290627235505|
|3879826|2021-07-29 11:36:19.128|13.011968111650264|
|3879826|2021-07-29 11:38:10.919|17.762006254598797|
|3879826|2021-07-29 11:40:01.929|1.9661930950977457|

When _MS is >= 3 and the previous _MS is less than the current _MS, I want to increment drift_MS by 100. But if _MS is < 3 and the previous _MS is less than the current _MS, I want to increment drift_MS by 1. Otherwise, keep the previous drift_MS value.

Expected output:

|SEQ_ID |TIME_STAMP             |_MS               |drift_MS|
|-------|-----------------------|------------------|--------|
|3879826|2021-07-29 11:24:20.525|NaN               |0       |
|3879826|2021-07-29 11:25:56.934|21.262409581399556|0       |
|3879826|2021-07-29 11:27:43.264|27.247600203353613|100     |
|3879826|2021-07-29 11:29:27.613|18.13528511851038 |100     |
|3879826|2021-07-29 11:31:10.512|2.520896614376871 |100     |
|3879826|2021-07-29 11:32:54.252|2.7081931585605541|101     |
|3879826|2021-07-29 11:34:36.995|2.9832290627235505|102     |
|3879826|2021-07-29 11:36:19.128|13.011968111650264|202     |
|3879826|2021-07-29 11:38:10.919|17.762006254598797|302     |
|3879826|2021-07-29 11:40:01.929|1.9661930950977457|302     |
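To make the rule concrete, here is a plain-Python sketch of the same logic applied to the sample _MS values (illustrative only, not Spark code); it reproduces the drift_MS column above:

```python
# _MS values from the sample dataframe
ms = [float("nan"), 21.262409581399556, 27.247600203353613,
      18.13528511851038, 2.520896614376871, 2.7081931585605541,
      2.9832290627235505, 13.011968111650264, 17.762006254598797,
      1.9661930950977457]

drift = []
total = 0
prev = None
for cur in ms:
    # comparisons against NaN are False, so the NaN row and the row after it add 0
    if prev is not None and prev < cur:
        total += 100 if cur >= 3 else 1
    drift.append(total)
    prev = cur

print(drift)  # [0, 0, 100, 100, 100, 101, 102, 202, 302, 302]
```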

I tried the following code:

from pyspark.sql import functions as f
from pyspark.sql.functions import col, when
from pyspark.sql.window import Window

w1 = Window.partitionBy('SEQ_ID').orderBy(col('TIME_STAMP').asc())

prev_MS = f.lag(col('_MS'), 1).over(w1)
prev_drift_MS = f.lag(col('drift_MS'), 1).over(w1)

df2 = df.withColumn('drift_MS',
          when((col('_MS') < 3) & (prev_MS < col('_MS')), prev_drift_MS + 1)
          .when((col('_MS') >= 3) & (prev_MS < col('_MS')), prev_drift_MS + 100)
          .otherwise(prev_drift_MS + 0))

But the drift_MS column is either 100 or 1. What am I doing wrong?

CodePudding user response:

Compute the per-row increment first, then take a running sum over the window. Your version lags drift_MS, the very column being defined, so each row only sees one previous increment instead of an accumulated total; window functions are not evaluated row by row recursively. Spark also does not allow one window function (the lag) inside another window function's aggregate, so put the increment in its own column first:

df2 = df.withColumn('inc',
          when((col('_MS') < 3) & (prev_MS < col('_MS')), 1)
          .when((col('_MS') >= 3) & (prev_MS < col('_MS')), 100)
          .otherwise(0)) \
        .withColumn('drift_MS', f.sum('inc').over(w1)) \
        .drop('inc')
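The difference between the two approaches can be shown in plain Python (illustrative sketch, outside Spark, assuming drift_MS started out as all zeros before the withColumn ran, so that lag('drift_MS') sees only zeros):

```python
# Per-row increments implied by the rule for the sample data
incs = [0, 0, 100, 0, 0, 1, 1, 100, 100, 0]

# lag-style: lag('drift_MS') reads the column's values as they were *before*
# withColumn ran (assumed all 0 here), not the freshly computed ones,
# so each row is just "old previous value + this row's increment"
prev_drift = [0] * len(incs)
lag_style = [p + i for p, i in zip(prev_drift, incs)]

# cumulative sum -- what f.sum(...).over(w1) computes -- accumulates correctly
drift, total = [], 0
for i in incs:
    total += i
    drift.append(total)

print(lag_style)  # [0, 0, 100, 0, 0, 1, 1, 100, 100, 0] -> stuck at 0/1/100
print(drift)      # [0, 0, 100, 100, 100, 101, 102, 202, 302, 302]
```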