PySpark - window function results in new column


I have the following data for PySpark:

+---------+---------+---------+-------------------+
|event_id |user_id  |   status|         created_at|
+---------+---------+---------+-------------------+
|        1|        2|        a|2017-05-26 15:12:54|
|        1|        2|        b|2017-05-26 15:12:53|
|        2|        1|        a|2017-05-26 15:12:56|
|        1|        2|        b|2017-05-26 16:12:57|
|        2|        1|        c|2017-05-26 16:12:58|
|        2|        1|        b|2017-05-26 16:12:58|
|        3|        1|        b|2017-05-26 14:17:58|
+---------+---------+---------+-------------------+

For each (event_id, user_id) pair (this is the primary key; the data is pulled from a database) I want to create one new column per status holding the highest created_at for that status, with null for pairs that have no rows for that status. For the data above:

+---------+---------+-------------------+-------------------+-------------------+
|event_id |user_id  |                  a|                  b|                  c|
+---------+---------+-------------------+-------------------+-------------------+
|        1|        2|2017-05-26 15:12:54|2017-05-26 16:12:57|               null|
|        2|        1|2017-05-26 15:12:56|               null|2017-05-26 16:12:58|
|        3|        1|               null|2017-05-26 14:17:58|               null|
+---------+---------+-------------------+-------------------+-------------------+
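
For anyone who wants to reproduce this, here is a minimal sketch that builds the sample data above (the SparkSession setup is assumed; created_at is kept as a plain string here for brevity, whereas in practice it would be a timestamp column):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample rows from the first table above.
df = spark.createDataFrame(
    [
        (1, 2, "a", "2017-05-26 15:12:54"),
        (1, 2, "b", "2017-05-26 15:12:53"),
        (2, 1, "a", "2017-05-26 15:12:56"),
        (1, 2, "b", "2017-05-26 16:12:57"),
        (2, 1, "c", "2017-05-26 16:12:58"),
        (2, 1, "b", "2017-05-26 16:12:58"),
        (3, 1, "b", "2017-05-26 14:17:58"),
    ],
    ["event_id", "user_id", "status", "created_at"],
)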

My solution is quite complicated and slow, and I'm pretty sure it can be optimized:

from pyspark.sql import functions as F

# For each status value, compute the max created_at per (event_id, user_id)
# and left-join it back onto the original frame as a new column.
for status in ["a", "b", "c"]:
    df2 = (
        df
        .filter(F.col("status") == status)
        .groupBy(["event_id", "user_id"])
        .agg(F.max("created_at").alias(status))
    )
    df = (
        df
        .join(
            df2,
            on=(
                (df["event_id"] == df2["event_id"]) &
                (df["user_id"] == df2["user_id"]) &
                (df["status"] == status)
            ),
            how="left_outer"
        )
        .select(df["*"], status)
    )

# Collapse back to one row per (event_id, user_id), keeping the
# non-null value of each per-status column.
df2 = (
    df
    .drop("status", "created_at")
    .groupBy(["event_id", "user_id"])
    .agg(F.max("a").alias("a"), F.max("b").alias("b"), F.max("c").alias("c"))
)

# df2 has the result

Can I avoid the joins in the loop here, or at least combine the join, groupBy, and max into a single step? As it stands, I process the statuses sequentially, and that does not scale at all.

CodePudding user response:

Try this:

df.groupBy("event_id", "user_id").pivot("status").agg(F.max("created_at")).show()
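
Written out in full, a sketch of the same idea (assuming the sample df built above and F as pyspark.sql.functions): pivot turns each distinct status value into its own column, and max picks the latest created_at per group, so the whole thing is a single aggregation with no joins. Listing the pivot values explicitly is optional, but it spares Spark an extra pass over the data to discover them:

from pyspark.sql import functions as F

# One aggregation, no joins: pivot each status into its own column
# and take the latest created_at per (event_id, user_id).
result = (
    df.groupBy("event_id", "user_id")
      .pivot("status", ["a", "b", "c"])  # explicit values avoid a distinct() pass
      .agg(F.max("created_at"))
)
result.show()

Cells with no matching (event_id, user_id, status) rows come out as null, which matches the expected output above.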