How can I find the first occurrence in a window/group and fetch all unbounded rows prior to that?


I have the following DataFrame that contains sessions along with the customer ID and a boolean indicating a purchase.

+---------+----------+-----------+-----------+
|accountId|customerId|sessionDate|didPurchase|
+---------+----------+-----------+-----------+
|walgreens|       001| 2021-09-30|      false|
|walgreens|       001| 2021-07-30|       true|
|walgreens|       001| 2021-05-30|       true|
|walgreens|       001| 2021-04-30|      false|
|walgreens|       002| 2021-08-04|       true|
|walgreens|       002| 2021-07-06|      false|
|walgreens|       002| 2021-07-05|      false|
|walgreens|       002| 2021-07-01|      false|
|    tesco|       001| 2021-09-21|      false|
|    tesco|       001| 2021-09-20|      false|
|    tesco|       001| 2021-01-30|      false|
|    tesco|       001| 2021-01-01|       true|
+---------+----------+-----------+-----------+

For each window of accountId and customerId, I need to get all rows up to and including the oldest purchase (i.e. the earliest session where didPurchase is true), with rows sorted by sessionDate descending.

My expected result set would look like this:

+---------+----------+-----------+-----------+
|accountId|customerId|sessionDate|didPurchase|
+---------+----------+-----------+-----------+
|walgreens|       001| 2021-05-30|       true|
|walgreens|       001| 2021-04-30|      false|
|walgreens|       002| 2021-08-04|       true|
|walgreens|       002| 2021-07-06|      false|
|walgreens|       002| 2021-07-05|      false|
|walgreens|       002| 2021-07-01|      false|
|    tesco|       001| 2021-01-01|       true|
+---------+----------+-----------+-----------+

The query below is my attempt so far, and the Scala snippet further down is what I've used in Zeppelin to create the DataFrame. I'm familiar with SQL window functions, but this one has me baffled.

   SELECT accountId
        , customerId
        , sessionDate
        , ROW_NUMBER() OVER (PARTITION BY accountId, customerId ORDER BY sessionDate DESC) AS rn
        , COUNT(*) OVER (PARTITION BY accountId, customerId) AS visits_per_customer
     FROM myvisits

Is this where I need to use UNBOUNDED PRECEDING? If so, a small SQL snippet would help. Thank you.

import java.sql.Date
import spark.implicits._   // needed for .toDF if the implicits aren't already in scope (Zeppelin/spark-shell import them automatically)

val df = sc.parallelize(Seq(
  ("walgreens", "001", Date.valueOf("2021-09-30"),  false),
  ("walgreens", "001", Date.valueOf("2021-07-30"),  true), 
  ("walgreens", "001", Date.valueOf("2021-05-30"),  true), 
  ("walgreens", "001", Date.valueOf("2021-04-30"),  false),
  ("walgreens", "002", Date.valueOf("2021-08-04"),  true),
  ("walgreens", "002", Date.valueOf("2021-07-06"),  false), 
  ("walgreens", "002", Date.valueOf("2021-07-05"),  false), 
  ("walgreens", "002", Date.valueOf("2021-07-01"),  false),
  ("tesco",     "001", Date.valueOf("2021-09-21"),  false), 
  ("tesco",     "001", Date.valueOf("2021-09-20"),  false),
  ("tesco",     "001", Date.valueOf("2021-01-30"),  false),
  ("tesco",     "001", Date.valueOf("2021-01-01"),  true)))
  .toDF("accountId", "customerId", "sessionDate", "didPurchase")
  
df.createOrReplaceTempView("myvisits")   // registerTempTable is deprecated
df.show()

CodePudding user response:

You may use the MIN window function with a CASE expression to identify the sessionDate of the oldest purchase in each group, then filter to records that occurred on or before that date to achieve your desired results.

The SQL to achieve this could look like:

SELECT
    accountId,
    customerId,
    sessionDate,
    didPurchase
FROM (
    SELECT 
        *,
        MIN(CASE
            WHEN didPurchase THEN sessionDate
        END) OVER (
            PARTITION BY accountId, customerId
        ) as prior_oldest_purchase
     FROM
        myvisits
) v
WHERE
     v.sessionDate <= v.prior_oldest_purchase
ORDER BY
    accountId DESC, customerId, sessionDate DESC
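
If you want to run this directly from your Zeppelin note against the temp table you already registered, a minimal sketch using the Spark SQL API could be the following (assuming a Spark 2+ SparkSession named spark is available, as it is in recent Zeppelin and spark-shell sessions):

// Run the window-function query against the registered temp table;
// the result contains only the sessions on or before the oldest purchase per group.
val resultDf = spark.sql("""
  SELECT accountId, customerId, sessionDate, didPurchase
  FROM (
    SELECT *,
           MIN(CASE WHEN didPurchase THEN sessionDate END)
             OVER (PARTITION BY accountId, customerId) AS prior_oldest_purchase
    FROM myvisits
  ) v
  WHERE v.sessionDate <= v.prior_oldest_purchase
  ORDER BY accountId DESC, customerId, sessionDate DESC
""")

resultDf.show()   // should print the expected result set from the question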


Using the Scala API, this could look like:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Window over all sessions for a given account/customer pair
val customerAccountWindow = Window.partitionBy("accountId","customerId")

val outputDf = df.withColumn(
                    "prior_oldest_purchase",
                    // earliest sessionDate where didPurchase is true within the window
                    min(
                        when(col("didPurchase"),col("sessionDate"))
                    ).over(customerAccountWindow)
               )
               // keep only sessions on or before that oldest purchase
               .where(col("sessionDate") <= col("prior_oldest_purchase"))
               .select(
                   col("accountId"),
                   col("customerId"),
                   col("sessionDate"),
                   col("didPurchase")
               )
               .orderBy(
                   col("accountId").desc(),
                   col("customerId"),
                   col("sessionDate").desc()
               )

Or using the PySpark API, it could look like:


from pyspark.sql import functions as F
from pyspark.sql import Window

# Window over all sessions for a given account/customer pair
customerAccountWindow = Window.partitionBy("accountId","customerId")

outputDf = (
    df.withColumn(
        "prior_oldest_purchase",
        # earliest sessionDate where didPurchase is true within the window
        F.min(
            F.when(F.col("didPurchase"),F.col("sessionDate"))
        ).over(customerAccountWindow)
    )
    # keep only sessions on or before that oldest purchase
    .where(F.col("sessionDate") <= F.col("prior_oldest_purchase"))
    .select(
        F.col("accountId"),
        F.col("customerId"),
        F.col("sessionDate"),
        F.col("didPurchase")
    )
    .orderBy(
        F.col("accountId").desc(),
        F.col("customerId"),
        F.col("sessionDate").desc()
    )
)
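
As for the UNBOUNDED PRECEDING part of your question: an explicit frame isn't needed here, because a window with no ORDER BY defaults to a frame spanning the entire partition. If you prefer to spell that out, an equivalent window definition (a sketch using the Scala API; it yields the same result as customerAccountWindow above) would be:

// Explicit frame covering the whole partition; equivalent to the default
// frame used above, since the window specification has no ORDER BY.
val explicitFrameWindow = Window
  .partitionBy("accountId", "customerId")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)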

Let me know if this works for you.
