I have a PySpark dataframe that looks like this:
data = [(2010, 3, 12, 0, 'p1', 'state1'),
(2010, 3, 12, 0, 'p2', 'state2'),
(2010, 3, 12, 0, 'p3', 'state1'),
(2010, 3, 12, 0, 'p4', 'state2'),
(2010, 3, 12, 2, 'p1', 'state3'),
(2010, 3, 12, 2, 'p2', 'state1'),
(2010, 3, 12, 2, 'p3', 'state3'),
(2010, 3, 12, 4, 'p1', 'state1'),
(2010, 3, 12, 6, 'p1', 'state1')]
columns = ['year', 'month', 'day', 'hour', 'process_id','state']
df = spark.createDataFrame(data=data, schema=columns)
df.show()
+----+-----+---+----+----------+------+
|year|month|day|hour|process_id| state|
+----+-----+---+----+----------+------+
|2010|    3| 12|   0|        p1|state1|
|2010|    3| 12|   0|        p2|state2|
|2010|    3| 12|   0|        p3|state1|
|2010|    3| 12|   0|        p4|state2|
|2010|    3| 12|   2|        p1|state3|
|2010|    3| 12|   2|        p2|state1|
|2010|    3| 12|   2|        p3|state3|
|2010|    3| 12|   4|        p1|state1|
|2010|    3| 12|   6|        p1|state1|
+----+-----+---+----+----------+------+
The dataframe is already sorted in increasing order by the four columns year, month, day and hour, as shown above. The readings come at 2-hour intervals. I would like to find out, for each process_id, how many times its state changes within each day. For that, I intend to use groupby, something like this:
chg_count_df = df.groupby('process_id', 'year', 'month', 'day').agg(...)
For this example, the expected output is:
+----+-----+---+----------+----------+
|year|month|day|process_id| chg_count|
+----+-----+---+----------+----------+
|2010|    3| 12|        p1|         2|
|2010|    3| 12|        p2|         1|
|2010|    3| 12|        p3|         1|
|2010|    3| 12|        p4|         0|
+----+-----+---+----------+----------+
What should go into the agg(...) function? Or is there a better way to do this?
CodePudding user response:
chg_count_df = df.groupby('process_id', 'year', 'month', 'day').count()
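Note that count() returns the number of rows in each group, not the number of state changes, so on the sample data it would give (row order may vary):
chg_count_df.show()
# +----------+----+-----+---+-----+
# |process_id|year|month|day|count|
# +----------+----+-----+---+-----+
# |        p1|2010|    3| 12|    4|
# |        p2|2010|    3| 12|    2|
# |        p3|2010|    3| 12|    2|
# |        p4|2010|    3| 12|    1|
# +----------+----+-----+---+-----+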
CodePudding user response:
You could employ the lag window function to check whether the state changed from the previous reading, then groupBy and sum the change flags.
from pyspark.sql import functions as F, Window as W

# One window per process per day, ordered by hour (ascending and descending
# give the same change count, since each adjacent pair is compared once).
w = W.partitionBy('year', 'month', 'day', 'process_id').orderBy('hour')
# Flag rows whose state differs from the previous reading; lag() is null on
# the first row of each partition, so coalesce the flag to 0 there.
df = df.withColumn('change', F.coalesce((F.lag('state').over(w) != F.col('state')).cast('int'), F.lit(0)))
df = df.groupBy('year', 'month', 'day', 'process_id').agg(F.sum('change').alias('chg_count'))
df.show()
# +----+-----+---+----------+---------+
# |year|month|day|process_id|chg_count|
# +----+-----+---+----------+---------+
# |2010|    3| 12|        p1|        2|
# |2010|    3| 12|        p2|        1|
# |2010|    3| 12|        p3|        1|
# |2010|    3| 12|        p4|        0|
# +----+-----+---+----------+---------+
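If you prefer not to overwrite df with an intermediate column, the same logic folds into a single pass. A minimal sketch, starting again from the original df and the imports above:
w = W.partitionBy('year', 'month', 'day', 'process_id').orderBy('hour')

# sum() skips the null that lag() produces on the first row of each partition;
# the outer coalesce covers single-row groups, where the sum itself is null.
chg_count_df = (
    df.withColumn('change', (F.lag('state').over(w) != F.col('state')).cast('int'))
      .groupBy('year', 'month', 'day', 'process_id')
      .agg(F.coalesce(F.sum('change'), F.lit(0)).alias('chg_count'))
)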