I have a DataFrame with a Timestamp column and a column named info, defined as follows:
+-------------------+----------+
|     Timestamp     |   info   |
+-------------------+----------+
|2016-01-01 17:54:30|    0     |
|2016-02-01 12:16:18|    0     |
|2016-03-01 12:17:57|    0     |
|2016-04-01 10:05:21|    0     |
|2016-05-11 18:58:25|    1     |
|2016-06-11 11:18:29|    1     |
|2016-07-01 12:05:21|    0     |
|2016-08-11 11:58:25|    0     |
|2016-09-11 15:18:29|    1     |
+-------------------+----------+
I would like to count the consecutive occurrences of 1s and insert 0 otherwise. The final column would be:
+-------------------+----------+----------+
|     Timestamp     |   info   |   res    |
+-------------------+----------+----------+
|2016-01-01 17:54:30|    0     |    0     |
|2016-02-01 12:16:18|    0     |    0     |
|2016-03-01 12:17:57|    0     |    0     |
|2016-04-01 10:05:21|    0     |    0     |
|2016-05-11 18:58:25|    1     |    1     |
|2016-06-11 11:18:29|    1     |    2     |
|2016-07-01 12:05:21|    0     |    0     |
|2016-08-11 11:58:25|    0     |    0     |
|2016-09-11 15:18:29|    1     |    1     |
+-------------------+----------+----------+
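For a reproducible example, the sample data above can be built roughly like this (assuming an active SparkSession named spark):

from pyspark.sql import functions as F

data = [
    ("2016-01-01 17:54:30", 0), ("2016-02-01 12:16:18", 0),
    ("2016-03-01 12:17:57", 0), ("2016-04-01 10:05:21", 0),
    ("2016-05-11 18:58:25", 1), ("2016-06-11 11:18:29", 1),
    ("2016-07-01 12:05:21", 0), ("2016-08-11 11:58:25", 0),
    ("2016-09-11 15:18:29", 1),
]

# column names follow the tables above; Timestamp is parsed from its string form
df_input = spark.createDataFrame(data, ["Timestamp", "info"]) \
    .withColumn("Timestamp", F.to_timestamp("Timestamp"))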
I tried using the following function, but it didn't work.
df_input = df_input.withColumn(
    "res",
    F.when(
        df_input.info == F.lag(df_input.info).over(w1),
        F.sum(F.lit(1)).over(w1)
    ).otherwise(0)
)
CodePudding user response:
From Adding a column counting cumulative previous repeating values, credit to @blackbishop:
from pyspark.sql import functions as F, Window

df = spark.createDataFrame([0, 0, 0, 0, 1, 1, 0, 0, 1], 'int').toDF('info')

df.withColumn("ID", F.monotonically_increasing_id()) \
    .withColumn("group",
                # the difference of the two row numbers is constant within each run of equal values
                F.row_number().over(Window.orderBy("ID"))
                - F.row_number().over(Window.partitionBy("info").orderBy("ID"))
                ) \
    .withColumn("Result",
                # number the rows inside each run of 1s, emit 0 everywhere else;
                # partition by info as well as group so a run of 0s can never share a window with a later run of 1s
                F.when(F.col('info') != 0,
                       F.row_number().over(Window.partitionBy("info", "group").orderBy("ID")))
                 .otherwise(F.lit(0))) \
    .orderBy("ID") \
    .drop("ID", "group") \
    .show()
+----+------+
|info|Result|
+----+------+
|   0|     0|
|   0|     0|
|   0|     0|
|   0|     0|
|   1|     1|
|   1|     2|
|   0|     0|
|   0|     0|
|   1|     1|
+----+------+
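If you keep the original Timestamp column, the same trick works without the synthetic ID by ordering every window on Timestamp instead (a sketch, assuming df has the Timestamp and info columns from the question):

from pyspark.sql import functions as F, Window

w_all = Window.orderBy("Timestamp")
w_info = Window.partitionBy("info").orderBy("Timestamp")

df.withColumn("group", F.row_number().over(w_all) - F.row_number().over(w_info)) \
  .withColumn("res",
              F.when(F.col("info") != 0,
                     F.row_number().over(Window.partitionBy("info", "group").orderBy("Timestamp")))
               .otherwise(F.lit(0))) \
  .drop("group") \
  .show()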
CodePudding user response:
tl;dr -- complicated approach
We had a similar problem and wanted a row-by-row processing approach that could look at the previous row's calculated field. There were multiple calculations to keep track of, so we resorted to an RDD-based approach and shipped our Python functions to all workers. Here is something based on that approach.
Creating dummy data similar to your problem:
from pyspark.sql import functions as func

data_ls = [
    (1, 0,),
    (2, 0,),
    (3, 0,),
    (4, 1,),
    (5, 1,),
    (6, 0,),
    (7, 1,)
]

data_sdf = spark.sparkContext.parallelize(data_ls).toDF(['ts', 'info'])

# +---+----+
# | ts|info|
# +---+----+
# |  1|   0|
# |  2|   0|
# |  3|   0|
# |  4|   1|
# |  5|   1|
# |  6|   0|
# |  7|   1|
# +---+----+
Our approach was to create a Python function that keeps track of the previously calculated value and uses it in the current row's calculation. The function is applied to the DataFrame's rdd with flatMapValues().
def custom_cumcount(groupedRows):
    """
    keep track of the previously calculated result and use it in the current calculation;
    ship this to the workers for distributed processing
    """
    res = []
    prev_sumcol = 0

    for row in groupedRows:
        if row.info == 0:
            sum_col = 0
        else:
            sum_col = prev_sumcol + row.info

        prev_sumcol = sum_col
        res.append([col for col in row] + [sum_col])

    return res
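As a quick local sanity check (the sample Row values here are made up), the function behaves as expected before it is shipped to the workers:

from pyspark.sql import Row

R = Row('ts', 'info')
print(custom_cumcount([R(1, 0), R(2, 1), R(3, 1), R(4, 0)]))
# [[1, 0, 0], [2, 1, 1], [3, 1, 2], [4, 0, 0]]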
# create a schema to be used for the result dataframe
data_sdf_schema_new = data_sdf.withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('sum_col', 'integer')

# StructType(List(StructField(ts,LongType,true),StructField(info,LongType,true),StructField(sum_col,IntegerType,true)))
# run the function on the data
data_rdd = data_sdf.rdd. \
    groupBy(lambda i: 1). \
    flatMapValues(lambda k: custom_cumcount(sorted(k, key=lambda s: s.ts))). \
    values()

# create a dataframe from the resulting rdd
spark.createDataFrame(data_rdd, schema=data_sdf_schema_new). \
    show()

# +---+----+-------+
# | ts|info|sum_col|
# +---+----+-------+
# |  1|   0|      0|
# |  2|   0|      0|
# |  3|   0|      0|
# |  4|   1|      1|
# |  5|   1|      2|
# |  6|   0|      0|
# |  7|   1|      1|
# +---+----+-------+
CodePudding user response:
Here's another way: use a conditional running sum to create groups, then use that grouping column for a cumulative sum:
from pyspark.sql import Window, functions as F

w1 = Window.orderBy("Timestamp")
w2 = Window.partitionBy("grp").orderBy("Timestamp")

df1 = (df.withColumn("grp",  # running count of 0s: constant across each run of 1s
                     F.sum(F.when(F.col("info") == 1, 0).otherwise(1)).over(w1))
         .withColumn("res", F.sum("info").over(w2))  # cumulative sum of info inside each group
         .drop("grp")
       )

df1.show()
# +-------------------+----+---+
# |          Timestamp|info|res|
# +-------------------+----+---+
# |2016-01-01 17:54:30|   0|  0|
# |2016-02-01 12:16:18|   0|  0|
# |2016-03-01 12:17:57|   0|  0|
# |2016-04-01 10:05:21|   0|  0|
# |2016-05-11 18:58:25|   1|  1|
# |2016-06-11 11:18:29|   1|  2|
# |2016-07-01 12:05:21|   0|  0|
# |2016-08-11 11:58:25|   0|  0|
# |2016-09-11 15:18:29|   1|  1|
# +-------------------+----+---+
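To see why this works, keep the helper column for a moment: the running sum in grp increments on every 0 and stays flat across a run of 1s, so each run of consecutive 1s (plus the 0 just before it) lands in its own group, and the cumulative sum of info inside that group counts the 1s. A sketch that shows the intermediate column:

df.withColumn("grp", F.sum(F.when(F.col("info") == 1, 0).otherwise(1)).over(w1)) \
  .orderBy("Timestamp") \
  .show()
# for the sample data, grp comes out as 1, 2, 3, 4, 4, 4, 5, 6, 6:
# rows 5 and 6 (the first run of 1s) share grp 4 with the 0 right before them,
# and the last row gets grp 6, so the per-group cumulative sum of info gives 1, 2 and 1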