import pandas as pd
temp = pd.DataFrame(data=[['a',0],['a',0],['a',0],['b',0],['b',1],['b',1],['c',1],['c',0],['c',0]], columns=['ID','X'])
temp['transformed'] = temp.groupby('ID').apply(lambda x: (x["X"].shift() != x["X"]).cumsum()).reset_index()['X']
print(temp)
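For reference, this numbers the runs of consecutive equal X values within each ID (shift() is NaN on a group's first row, so every group starts at 1), and the printed result should look roughly like this:

  ID  X  transformed
0  a  0            1
1  a  0            1
2  a  0            1
3  b  0            1
4  b  1            2
5  b  1            2
6  c  1            1
7  c  0            2
8  c  0            2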
My question is how to achieve this in PySpark.
CodePudding user response:
PySpark handles this kind of query with Window functions; you can read their documentation here.
Your PySpark code would be something like this:
from pyspark.sql import functions as F, Window as W

# an explicit ordering column ('time' here) is needed, since Spark rows have no inherent order
window = W.partitionBy('ID').orderBy('time')
new_df = (
    df
    # previous X within each ID group (null on a group's first row)
    .withColumn('shifted', F.lag('X').over(window))
    # 1 where X differs from the previous row; a group's first row counts as a change
    .withColumn('changed', F.coalesce((F.col('shifted') != F.col('X')).cast('int'), F.lit(1)))
    # running sum of the change flags, i.e. the pandas cumsum
    .withColumn('transformed', F.sum('changed').over(window))
)
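To try it out, df can be built from the same data as the pandas example; the original data has no timestamp, so a row index is added here (as the assumed 'time' column) purely to give orderBy something to sort on:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# same rows as the pandas example; enumerate() supplies the ordering column
rows = [('a', 0), ('a', 0), ('a', 0), ('b', 0), ('b', 1), ('b', 1), ('c', 1), ('c', 0), ('c', 0)]
df = spark.createDataFrame([(i,) + r for i, r in enumerate(rows)], ['time', 'ID', 'X'])

Running the snippet above against this df and calling new_df.orderBy('time').show() should give transformed = 1, 1, 1, 1, 2, 2, 1, 2, 2, matching the pandas column.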