Counting number of changes in categorical column in PySpark Dataframe


I have a PySpark dataframe that looks like this:

data = [(2010, 3, 12, 0, 'p1', 'state1'), 
         (2010, 3, 12, 0, 'p2', 'state2'), 
         (2010, 3, 12, 0, 'p3', 'state1'), 
         (2010, 3, 12, 0, 'p4', 'state2'), 
         (2010, 3, 12, 2, 'p1', 'state3'), 
         (2010, 3, 12, 2, 'p2', 'state1'), 
         (2010, 3, 12, 2, 'p3', 'state3'), 
         (2010, 3, 12, 4, 'p1', 'state1'), 
         (2010, 3, 12, 6, 'p1', 'state1')]

columns = ['year', 'month', 'day', 'hour', 'process_id','state']

df = spark.createDataFrame(data=data, schema=columns)
df.show()
+----+-----+---+----+----------+------+
|year|month|day|hour|process_id| state|
+----+-----+---+----+----------+------+
|2010|    3| 12|   0|        p1|state1|
|2010|    3| 12|   0|        p2|state2|
|2010|    3| 12|   0|        p3|state1|
|2010|    3| 12|   0|        p4|state2|
|2010|    3| 12|   2|        p1|state3|
|2010|    3| 12|   2|        p2|state1|
|2010|    3| 12|   2|        p3|state3|
|2010|    3| 12|   4|        p1|state1|
|2010|    3| 12|   6|        p1|state1|
+----+-----+---+----+----------+------+

The dataframe is already sorted in increasing order by the four columns year, month, day and hour, as shown above. The rows come at 2-hour intervals.

I would like to find out, for each process_id, how many times its state changes within each day. For that, I intend to use groupby, something like this:

chg_count_df = df.groupby('process_id', 'year', 'month', 'day') \
               .agg(...)

For this example, the expected output is:

+----+-----+---+----------+----------+
|year|month|day|process_id| chg_count|
+----+-----+---+----------+----------+
|2010|    3| 12|        p1|         2|
|2010|    3| 12|        p2|         1|
|2010|    3| 12|        p3|         1|
|2010|    3| 12|        p4|         0|
+----+-----+---+----------+----------+

What should go into the agg(...) function? Or is there a better way to do this?

CodePudding user response:

chg_count_df = df.groupby('process_id', 'year', 'month', 'day').count()

CodePudding user response:

You could use the lag window function to flag whether the state changed from the adjacent row for each process, then groupBy and sum the flags.

from pyspark.sql import functions as F, Window as W

# Within each day and process, flag rows whose state differs from the adjacent row
w = W.partitionBy('year', 'month', 'day', 'process_id').orderBy(F.desc('hour'))
df = df.withColumn('change', F.coalesce((F.lag('state').over(w) != F.col('state')).cast('int'), F.lit(0)))
# Sum the 0/1 flags to get the number of changes per process and day
df = df.groupBy('year', 'month', 'day', 'process_id').agg(F.sum('change').alias('chg_count'))

df.show()
# +----+-----+---+----------+---------+
# |year|month|day|process_id|chg_count|
# +----+-----+---+----------+---------+
# |2010|    3| 12|        p1|        2|
# |2010|    3| 12|        p2|        1|
# |2010|    3| 12|        p3|        1|
# |2010|    3| 12|        p4|        0|
# +----+-----+---+----------+---------+
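
Ordering the window by hour ascending instead should give the same counts, since each pair of consecutive rows within a day is compared exactly once either way. A minimal sketch of that variant, assuming df is the original dataframe built in the question:

from pyspark.sql import functions as F, Window as W

# Compare each row's state with the previous 2-hour slot of the same process and day
w = W.partitionBy('year', 'month', 'day', 'process_id').orderBy('hour')
changed = (F.lag('state').over(w) != F.col('state')).cast('int')

chg_count_df = (
    df.withColumn('change', F.coalesce(changed, F.lit(0)))  # first row per group has no predecessor
      .groupBy('year', 'month', 'day', 'process_id')
      .agg(F.sum('change').alias('chg_count'))
)
chg_count_df.show()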