Cumulative count of a specified status by ID


I have a dataframe in PySpark, similar to this:

+---+------+-----------+
|id |status|date       |
+---+------+-----------+
|1  |1     |01-01-2022 |
|1  |0     |02-01-2022 |
|1  |0     |03-01-2022 |
|1  |0     |04-01-2022 |
|1  |1     |05-01-2022 |
|1  |0     |06-01-2022 |
|2  |1     |01-01-2022 |
|2  |0     |02-01-2022 |
|2  |0     |03-01-2022 |
|2  |1     |04-01-2022 |
|2  |0     |05-01-2022 |
+---+------+-----------+

It contains customer IDs and their daily status.

I would like to count how many consecutive days each customer has been in status 0.

Expected output:

+---+------+-----------+------------+
|id |status|date       |count_status|
+---+------+-----------+------------+
|1  |1     |01-01-2022 | 0          |
|1  |0     |02-01-2022 | 1          |
|1  |0     |03-01-2022 | 2          |
|1  |0     |04-01-2022 | 3          |
|1  |1     |05-01-2022 | 0          |
|1  |0     |06-01-2022 | 1          |
|2  |1     |01-01-2022 | 0          |
|2  |0     |02-01-2022 | 1          |
|2  |0     |03-01-2022 | 2          |
|2  |1     |04-01-2022 | 0          |
|2  |0     |05-01-2022 | 1          |
+---+------+-----------+------------+

In Python (pandas), I wrote this code:

df['count_status'] = np.where(df['status'] == 0,
                              df.groupby(['id',
                                          (df['status'] != df['status'].shift(1)).cumsum()]).cumcount() + 1,
                              0)
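
For reference, a self-contained, runnable version of that pandas snippet on the sample data above (the run_id variable is just an illustrative name for the status-change groups):

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'id':     [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
    'status': [1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0],
    'date':   ['01-01-2022', '02-01-2022', '03-01-2022', '04-01-2022',
               '05-01-2022', '06-01-2022', '01-01-2022', '02-01-2022',
               '03-01-2022', '04-01-2022', '05-01-2022'],
})

# A new run starts whenever the status changes; cumcount() numbers rows within
# each (id, run) group starting at 0, so adding 1 gives the day count.
run_id = (df['status'] != df['status'].shift(1)).cumsum()
df['count_status'] = np.where(df['status'] == 0,
                              df.groupby(['id', run_id]).cumcount() + 1,
                              0)
print(df)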

I recently started learning PySpark and haven't been able to rewrite the code above. I tried joining separate tables, but without success.

I saw some solutions using window functions, but I couldn't figure out how to apply a window with a lag from the previous day.

CodePudding user response:

You can do it using two windows. The first assigns a group id to each run (a status=1 row together with the status=0 rows that follow it); the second numbers the rows within each group.

Input:

from pyspark.sql import functions as F, Window as W

df = spark.createDataFrame(
    [(1, 1, '01-01-2022'),
     (1, 0, '02-01-2022'),
     (1, 0, '03-01-2022'),
     (1, 0, '04-01-2022'),
     (1, 1, '05-01-2022'),
     (1, 0, '06-01-2022'),
     (2, 1, '01-01-2022'),
     (2, 0, '02-01-2022'),
     (2, 0, '03-01-2022'),
     (2, 1, '04-01-2022'),
     (2, 0, '05-01-2022')],
    ['id', 'status', 'date'])

Script:

# Parse the dd-MM-yyyy string so the windows order by a real date.
true_date = F.to_date('date', 'dd-MM-yyyy')

# w1: per-customer window, used to build the group ids.
w1 = W.partitionBy('id').orderBy(true_date)
# w2: numbering window within each (id, group) run.
w2 = W.partitionBy('id', 'group').orderBy(true_date)

# Running sum of status: it increments on every status=1 row, so a status=1 row
# and the status=0 rows that follow it share the same group value.
df = df.withColumn('group', F.sum('status').over(w1))
# row_number() starts at 1 on the status=1 row, so subtracting 1 gives 0 there
# and 1, 2, 3, ... on the consecutive status=0 rows.
df = df.withColumn('count_status', F.row_number().over(w2) - 1)
df = df.drop('group')

df.show()
# +---+------+----------+------------+
# | id|status|      date|count_status|
# +---+------+----------+------------+
# |  1|     1|01-01-2022|           0|
# |  1|     0|02-01-2022|           1|
# |  1|     0|03-01-2022|           2|
# |  1|     0|04-01-2022|           3|
# |  1|     1|05-01-2022|           0|
# |  1|     0|06-01-2022|           1|
# |  2|     1|01-01-2022|           0|
# |  2|     0|02-01-2022|           1|
# |  2|     0|03-01-2022|           2|
# |  2|     1|04-01-2022|           0|
# |  2|     0|05-01-2022|           1|
# +---+------+----------+------------+
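
If you want to see how the groups are formed before the group column is dropped, a quick sketch (assuming the input DataFrame and the true_date / w1 definitions above, run before the transformations that overwrite df):

# The running sum of status per id forms the groups: each status=1 row starts
# a new group that the following status=0 rows share.
# For id=1 (status 1,0,0,0,1,0) group comes out as 1,1,1,1,2,2;
# for id=2 (status 1,0,0,1,0) it comes out as 1,1,1,2,2.
df.withColumn('group', F.sum('status').over(w1)).orderBy('id', true_date).show()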

Your date column is a string rather than a Spark date type, so I created a separate column expression, true_date, to use inside the window functions for correct ordering.
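
If you prefer to materialize the parsed date as a real column instead of only using it for ordering, a minimal sketch (the parsed_date name is just for illustration):

# With dd-MM-yyyy strings, lexicographic ordering breaks across months
# (e.g. '02-01-2022', 2 January, sorts after '01-02-2022', 1 February),
# so order by an actual date type instead.
df = df.withColumn('parsed_date', F.to_date('date', 'dd-MM-yyyy'))
df.orderBy('id', 'parsed_date').show()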
