I have the table below, where streak increases by 1 when activity_date is the day after the previous row's date; otherwise streak resets to 1.
Now I need the min and max activity_date of each streak, using Spark with Scala or Spark SQL.
Input
floor activity_date streak
--------------------------------
floor1 2018-11-08 1
floor1 2019-01-24 1
floor1 2019-04-05 1
floor1 2019-04-08 1
floor1 2019-04-09 2
floor1 2019-04-14 1
floor1 2019-04-17 1
floor1 2019-04-20 1
floor2 2019-05-04 1
floor2 2019-05-05 2
floor2 2019-06-04 1
floor2 2019-07-28 1
floor2 2019-08-14 1
floor2 2019-08-22 1
Output
floor activity_date end_activity_date
----------------------------------------
floor1 2018-11-08 2018-11-08
floor1 2019-01-24 2019-01-24
floor1 2019-04-05 2019-04-05
floor1 2019-04-08 2019-04-09
floor1 2019-04-14 2019-04-14
floor1 2019-04-17 2019-04-17
floor1 2019-04-20 2019-04-20
floor2 2019-05-04 2019-05-05
floor2 2019-06-04 2019-06-04
floor2 2019-07-28 2019-07-28
floor2 2019-08-14 2019-08-14
floor2 2019-08-22 2019-08-22
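For reference, the streak rule described above can be reproduced with a small plain-Python sketch. It uses only the standard library (not Spark) and assumes the dates for one floor are already sorted and unique. `compute_streaks` is a hypothetical helper written for illustration only.

```python
from datetime import date, timedelta

def compute_streaks(dates):
    """Return the streak for each date: previous streak + 1 when the date is
    exactly one day after the previous date, otherwise reset to 1.
    Assumes `dates` is sorted ascending with no duplicates."""
    streaks = []
    for i, d in enumerate(dates):
        if i > 0 and d - dates[i - 1] == timedelta(days=1):
            streaks.append(streaks[-1] + 1)
        else:
            streaks.append(1)
    return streaks

# The floor1 dates from the input table
floor1 = [date(2018, 11, 8), date(2019, 1, 24), date(2019, 4, 5), date(2019, 4, 8),
          date(2019, 4, 9), date(2019, 4, 14), date(2019, 4, 17), date(2019, 4, 20)]
print(compute_streaks(floor1))  # [1, 1, 1, 1, 2, 1, 1, 1]
```

This matches the streak column shown for floor1 in the input table.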
Answer:
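You can treat this as a gaps-and-islands problem. First, flag each row where a new streak starts, meaning its streak does not increase over the previous row's. Next, take a running sum of that flag per floor, so that all rows of one streak share the same group number (gn). Finally, take the min and max activity_date per floor and gn.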
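Before the Spark versions, here is a plain-Python sketch (standard library only, not Spark) of the same three steps: the LAG-based start flag, the running sum, and the min/max per group. Rows are assumed to be `(floor, activity_date, streak)` tuples with ISO date strings, which sort and compare correctly as text. `streak_ranges` is a hypothetical name used only for this illustration.

```python
from itertools import groupby

def streak_ranges(rows):
    """rows: (floor, activity_date, streak) tuples, dates as ISO strings.
    Returns one (floor, start_date, end_date) tuple per streak."""
    result = []
    # Mirrors PARTITION BY floor ORDER BY activity_date
    rows = sorted(rows, key=lambda r: (r[0], r[1]))
    for floor, floor_rows in groupby(rows, key=lambda r: r[0]):
        gn = 0            # running sum of the start flag, restarted per floor
        groups = {}       # gn -> dates belonging to that streak
        prev_streak = None
        for _, activity_date, streak in floor_rows:
            # Same default as LAG(streak, 1, streak - 1) for the first row
            lag = prev_streak if prev_streak is not None else streak - 1
            flag = 0 if streak > lag else 1
            gn += flag
            groups.setdefault(gn, []).append(activity_date)
            prev_streak = streak
        # dicts keep insertion order, so groups come out in date order
        for dates in groups.values():
            result.append((floor, min(dates), max(dates)))
    return result

rows = [
    ("floor1", "2018-11-08", 1), ("floor1", "2019-01-24", 1), ("floor1", "2019-04-05", 1),
    ("floor1", "2019-04-08", 1), ("floor1", "2019-04-09", 2), ("floor1", "2019-04-14", 1),
    ("floor1", "2019-04-17", 1), ("floor1", "2019-04-20", 1),
    ("floor2", "2019-05-04", 1), ("floor2", "2019-05-05", 2), ("floor2", "2019-06-04", 1),
    ("floor2", "2019-07-28", 1), ("floor2", "2019-08-14", 1), ("floor2", "2019-08-22", 1),
]
for r in streak_ranges(rows):
    print(*r)
```

Running it prints the 12 rows of the expected output above. The two-day streaks (2019-04-08 to 2019-04-09 and 2019-05-04 to 2019-05-05) each collapse into a single row.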
Using Spark SQL
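-- assumes the input is registered as a temp view named df, e.g. df.createOrReplaceTempView("df")
SELECT
    floor,
    MIN(activity_date) AS activity_date,
    MAX(activity_date) AS end_activity_date
FROM (
    SELECT
        *,
        -- running sum of the start flag: all rows of one streak share the same gn
        SUM(new_streak) OVER (
            PARTITION BY floor ORDER BY activity_date
        ) AS gn
    FROM (
        SELECT
            *,
            -- 1 when streak does not increase over the previous row (a new streak starts);
            -- the first row of each floor gets 0, which is harmless because only distinct gn values matter
            CASE
                WHEN streak > LAG(streak, 1, streak - 1) OVER (
                    PARTITION BY floor
                    ORDER BY activity_date
                ) THEN 0
                ELSE 1
            END AS new_streak
        FROM df
    ) t1
) t2
-- one row per streak: its first and last activity_date
GROUP BY floor, gn
ORDER BY floor, activity_date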
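Using the Scala API
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// df is the input DataFrame with columns floor, activity_date, streak
val floorWindow = Window.partitionBy("floor").orderBy("activity_date")

val output = df
  // 1 when streak does not increase over the previous row (a new streak starts);
  // the first row per floor gets 0, which is fine because only distinct gn values matter
  .withColumn(
    "new_streak",
    when(col("streak") > lag(col("streak"), 1, col("streak") - 1).over(floorWindow), 0)
      .otherwise(1)
  )
  // running sum of the start flag: all rows of one streak share the same gn
  .withColumn("gn", sum(col("new_streak")).over(floorWindow))
  // one row per streak: its first and last activity_date
  .groupBy("floor", "gn")
  .agg(
    min(col("activity_date")).alias("activity_date"),
    max(col("activity_date")).alias("end_activity_date")
  )
  .drop("gn")
  .orderBy("floor", "activity_date")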
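Using the PySpark API
from pyspark.sql import functions as F
from pyspark.sql import Window

# df is the input DataFrame with columns floor, activity_date, streak
floorWindow = Window.partitionBy("floor").orderBy("activity_date")

output = (
    df
    # 1 when streak does not increase over the previous row (a new streak starts);
    # coalesce supplies streak - 1 for the first row of each floor (F.lag's default
    # expects a literal, not a Column), so that row gets 0, which is harmless
    .withColumn(
        "new_streak",
        F.when(
            F.col("streak") > F.coalesce(F.lag("streak", 1).over(floorWindow), F.col("streak") - 1),
            0,
        ).otherwise(1),
    )
    # running sum of the start flag: all rows of one streak share the same gn
    .withColumn("gn", F.sum("new_streak").over(floorWindow))
    # one row per streak: its first and last activity_date
    .groupBy("floor", "gn")
    .agg(
        F.min("activity_date").alias("activity_date"),
        F.max("activity_date").alias("end_activity_date"),
    )
    .drop("gn")
    .orderBy("floor", "activity_date")
)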
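As a side note, not part of the approach above: if the streak column is always well formed (it resets to exactly 1 whenever a run of consecutive dates breaks), the start flag could arguably be reduced to "streak = 1". The plain-Python check below (standard library only; both helper names are hypothetical) compares the two flags on the floor1 streaks. It shows that their running sums differ by the same constant on every row, so both split the rows into the same streaks.

```python
from itertools import accumulate

def lag_flags(streaks):
    """Start flag as in the queries above: 0 if streak > previous streak, else 1.
    The first row is compared with streak - 1, so it always gets 0."""
    flags = []
    for i, s in enumerate(streaks):
        prev = streaks[i - 1] if i > 0 else s - 1
        flags.append(0 if s > prev else 1)
    return flags

def reset_flags(streaks):
    """Hypothetical simplified flag: 1 wherever streak == 1."""
    return [1 if s == 1 else 0 for s in streaks]

floor1_streaks = [1, 1, 1, 1, 2, 1, 1, 1]
gn_lag = list(accumulate(lag_flags(floor1_streaks)))      # [0, 1, 2, 3, 3, 4, 5, 6]
gn_reset = list(accumulate(reset_flags(floor1_streaks)))  # [1, 2, 3, 4, 4, 5, 6, 7]

# A constant offset on every row means identical group boundaries
print({r - l for l, r in zip(gn_lag, gn_reset)})  # {1}
```

In PySpark, that variant would be `F.when(F.col("streak") == 1, 1).otherwise(0)`. It relies entirely on the streak column being correct, whereas the LAG comparison only needs streak to increase within a run.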
Let me know if this works for you.