Collect range of dates as list in Spark

Time:09-07

I have the following two DataFrames.

DF1:

+--------------+---+----+
|Date          |Id |Cond|
+--------------+---+----+
|    2022-01-08|  1|   0|
|    2022-01-10|  1|   0|
|    2022-01-11|  1|   0|
|    2022-01-12|  1|   0|
|    2022-01-13|  1|   0|
|    2022-01-15|  1|   0|
|    2022-01-18|  1|   0|
|    2022-01-19|  1|   0|
|    2022-01-08|  2|   0|
|    2022-01-11|  2|   0|
|    2022-01-12|  2|   0|
|    2022-01-15|  2|   0|
|    2022-01-16|  2|   0|
|    2022-01-17|  2|   0|
|    2022-01-19|  2|   0|
|    2022-01-20|  2|   0|
+--------------+---+----+

DF2:

+--------------+---+----+
|Date          |Id |Cond|
+--------------+---+----+
|    2022-01-09|  1|   1|
|    2022-01-14|  1|   1|
|    2022-01-16|  1|   1|
|    2022-01-17|  1|   1|
|    2022-01-20|  1|   1|
|    2022-01-09|  2|   1|
|    2022-01-10|  2|   1|
|    2022-01-13|  2|   1|
|    2022-01-14|  2|   1|
|    2022-01-18|  2|   1|
+--------------+---+----+

For each row in DF1, I want to collect the two most recent dates from DF2 (with the same Id) that are earlier than that row's date.

Example:

For Date "2022-01-15" and Id = 1 in DF1, I need to collect the dates "2022-01-14" and "2022-01-09" from DF2, because they are the two most recent Id = 1 dates in DF2 that are earlier than 2022-01-15.
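The rule in this example can be pinned down in a few lines of plain Python (no Spark), purely to make the intended semantics explicit: keep the DF2 dates for the same Id that are strictly earlier than the DF1 date, then take the two most recent. The helper name `two_latest_before` is hypothetical and not part of any library.

```python
def two_latest_before(target, candidates, n=2):
    """Return up to n dates from candidates strictly earlier than target, newest first.

    Dates are zero-padded 'yyyy-MM-dd' strings, so string comparison matches date order.
    """
    earlier = [d for d in candidates if d < target]
    return sorted(earlier, reverse=True)[:n]


# DF2 dates for Id = 1, taken from the table above
df2_id1 = ["2022-01-09", "2022-01-14", "2022-01-16", "2022-01-17", "2022-01-20"]

print(two_latest_before("2022-01-15", df2_id1))  # ['2022-01-14', '2022-01-09']
print(two_latest_before("2022-01-08", df2_id1))  # []
```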

My expected output:

+--------------+---+------------------------------+
|Date          |Id |List                          |
+--------------+---+------------------------------+
|    2022-01-08|  1|  []                          |
|    2022-01-10|  1|  ['2022-01-09']              |
|    2022-01-11|  1|  ['2022-01-09']              |
|    2022-01-12|  1|  ['2022-01-09']              |
|    2022-01-13|  1|  ['2022-01-09']              |
|    2022-01-15|  1|  ['2022-01-14', '2022-01-09']|
|    2022-01-18|  1|  ['2022-01-17', '2022-01-16']|
|    2022-01-19|  1|  ['2022-01-17', '2022-01-16']|
|    2022-01-08|  2|  []                          |
|    2022-01-11|  2|  ['2022-01-10', '2022-01-09']|
|    2022-01-12|  2|  ['2022-01-10', '2022-01-09']|
|    2022-01-15|  2|  ['2022-01-14', '2022-01-13']|
|    2022-01-16|  2|  ['2022-01-14', '2022-01-13']|
|    2022-01-17|  2|  ['2022-01-14', '2022-01-13']|
|    2022-01-19|  2|  ['2022-01-18', '2022-01-14']|
|    2022-01-20|  2|  ['2022-01-18', '2022-01-14']|
+--------------+---+------------------------------+

I know that I can use collect_list to get the dates as a list, but how can I collect by range?

MVCE:

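from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()

data_1 = [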
    ("2022-01-08", 1, 0),
    ("2022-01-10", 1, 0),
    ("2022-01-11", 1, 0),
    ("2022-01-12", 1, 0),
    ("2022-01-13", 1, 0),
    ("2022-01-15", 1, 0),
    ("2022-01-18", 1, 0),
    ("2022-01-19", 1, 0), 
    ("2022-01-08", 2, 0),
    ("2022-01-11", 2, 0), 
    ("2022-01-12", 2, 0),
    ("2022-01-15", 2, 0), 
    ("2022-01-16", 2, 0),
    ("2022-01-17", 2, 0), 
    ("2022-01-19", 2, 0),
    ("2022-01-20", 2, 0) 
]
schema_1 = StructType([
    StructField("Date", StringType(), True),
    StructField("Id", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
  ])
df_1 = spark.createDataFrame(data=data_1, schema=schema_1)

data_2 = [
    ("2022-01-09", 1, 1),
    ("2022-01-14", 1, 1),
    ("2022-01-16", 1, 1),
    ("2022-01-17", 1, 1),
    ("2022-01-20", 1, 1),
    ("2022-01-09", 2, 1),
    ("2022-01-10", 2, 1),
    ("2022-01-13", 2, 1), 
    ("2022-01-14", 2, 1),
    ("2022-01-18", 2, 1)
]
schema_2 = StructType([
    StructField("Date", StringType(), True),
    StructField("Id", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
  ])
df_2 = spark.createDataFrame(data=data_2, schema=schema_2)

Answer:

You can accomplish this by:

  1. joining the two tables on Id;
  2. conditionally collecting dates from df_2 when they are earlier than the target date from df_1 (collect_list skips null values); and
  3. using a combination of slice and sort_array to keep only the two most recent dates.
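
import pyspark.sql.functions as F

df_out = (
    df_1
    # how="left" keeps df_1 rows whose Id never appears in df_2; an inner join would drop them
    .join(df_2.select(F.col("Date").alias("Date_RHS"), "Id"), on="Id", how="left")
    .groupBy("Date", "Id")
    # when() is null unless Date_RHS is earlier (yyyy-MM-dd strings compare chronologically); collect_list skips nulls
    .agg(F.collect_list(F.when(F.col("Date_RHS") < F.col("Date"), F.col("Date_RHS"))).alias("List"))
    # sort newest first, then keep at most the first two elements
    .select("Date", "Id", F.slice(F.sort_array(F.col("List"), asc=False), start=1, length=2).alias("List"))
)
df_out.orderBy("Id", "Date").show(truncate=False)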

# +----------+---+------------------------+
# |Date      |Id |List                    |
# +----------+---+------------------------+
# |2022-01-08|1  |[]                      |
# |2022-01-10|1  |[2022-01-09]            |
# |2022-01-11|1  |[2022-01-09]            |
# |2022-01-12|1  |[2022-01-09]            |
# |2022-01-13|1  |[2022-01-09]            |
# |2022-01-15|1  |[2022-01-14, 2022-01-09]|
# |2022-01-18|1  |[2022-01-17, 2022-01-16]|
# |2022-01-19|1  |[2022-01-17, 2022-01-16]|
# |2022-01-08|2  |[]                      |
# |2022-01-11|2  |[2022-01-10, 2022-01-09]|
# |2022-01-12|2  |[2022-01-10, 2022-01-09]|
# |2022-01-15|2  |[2022-01-14, 2022-01-13]|
# |2022-01-16|2  |[2022-01-14, 2022-01-13]|
# |2022-01-17|2  |[2022-01-14, 2022-01-13]|
# |2022-01-19|2  |[2022-01-18, 2022-01-14]|
# |2022-01-20|2  |[2022-01-18, 2022-01-14]|
# +----------+---+------------------------+
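To make the three steps concrete without a cluster, here is a plain-Python sketch that mirrors them on the MVCE rows: pairing each df_1 row with every df_2 row of the same Id (the join), keeping only right-hand dates that are earlier (standing in for `when` producing null and `collect_list` skipping it), and a descending sort plus a two-element slice. It only illustrates the logic, not how Spark executes it; the variable names are hypothetical, and the constant Cond column is left out.

```python
# (Date, Id) pairs from the MVCE
df_1_rows = [
    ("2022-01-08", 1), ("2022-01-10", 1), ("2022-01-11", 1), ("2022-01-12", 1),
    ("2022-01-13", 1), ("2022-01-15", 1), ("2022-01-18", 1), ("2022-01-19", 1),
    ("2022-01-08", 2), ("2022-01-11", 2), ("2022-01-12", 2), ("2022-01-15", 2),
    ("2022-01-16", 2), ("2022-01-17", 2), ("2022-01-19", 2), ("2022-01-20", 2),
]
df_2_rows = [
    ("2022-01-09", 1), ("2022-01-14", 1), ("2022-01-16", 1), ("2022-01-17", 1),
    ("2022-01-20", 1), ("2022-01-09", 2), ("2022-01-10", 2), ("2022-01-13", 2),
    ("2022-01-14", 2), ("2022-01-18", 2),
]

# Steps 1 and 2: pair rows by Id and collect only earlier right-hand dates,
# grouped by (Date, Id). A df_1 row with no match keeps an empty list here,
# which corresponds to a left join; an inner join would drop such a row.
collected = {}
for date, id_ in df_1_rows:
    bucket = collected.setdefault((date, id_), [])
    for date_rhs, id_rhs in df_2_rows:
        if id_rhs == id_ and date_rhs < date:
            bucket.append(date_rhs)

# Step 3: sort newest first and keep at most two dates
result = {key: sorted(dates, reverse=True)[:2] for key, dates in collected.items()}

for (date, id_), dates in sorted(result.items(), key=lambda kv: (kv[0][1], kv[0][0])):
    print(date, id_, dates)
```

The nested loop is the analogue of the join: every df_1 row meets every df_2 row with the same Id before the condition is applied, so the intermediate result grows with the product of the two row counts per Id.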

Answer:

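The following approach first aggregates df_2 into one set of dates per Id (collect_set), then left-joins that to df_1. It then uses the higher-order function filter to keep only the dates earlier than the "Date" column, sort_array to order them newest first, and slice to keep the two most recent. The Python lambda form of F.filter requires Spark 3.1 or later.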

from pyspark.sql import functions as F

df = df_1.join(df_2.groupBy('Id').agg(F.collect_set('Date').alias('d2')), 'Id', 'left')
df = df.select(
    'Date', 'Id',
    F.slice(F.sort_array(F.filter('d2', lambda x: x < F.col('Date')), False), 1, 2).alias('List')
)

df.orderBy('Id', 'Date').show(truncate=False)
# +----------+---+------------------------+
# |Date      |Id |List                    |
# +----------+---+------------------------+
# |2022-01-08|1  |[]                      |
# |2022-01-10|1  |[2022-01-09]            |
# |2022-01-11|1  |[2022-01-09]            |
# |2022-01-12|1  |[2022-01-09]            |
# |2022-01-13|1  |[2022-01-09]            |
# |2022-01-15|1  |[2022-01-14, 2022-01-09]|
# |2022-01-18|1  |[2022-01-17, 2022-01-16]|
# |2022-01-19|1  |[2022-01-17, 2022-01-16]|
# |2022-01-08|2  |[]                      |
# |2022-01-11|2  |[2022-01-10, 2022-01-09]|
# |2022-01-12|2  |[2022-01-10, 2022-01-09]|
# |2022-01-15|2  |[2022-01-14, 2022-01-13]|
# |2022-01-16|2  |[2022-01-14, 2022-01-13]|
# |2022-01-17|2  |[2022-01-14, 2022-01-13]|
# |2022-01-19|2  |[2022-01-18, 2022-01-14]|
# |2022-01-20|2  |[2022-01-18, 2022-01-14]|
# +----------+---+------------------------+
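The aggregate-first approach above reverses the order of operations: df_2 is collapsed to one set of dates per Id before the join, so each df_1 row receives a single array instead of being multiplied by the matching df_2 rows. A plain-Python sketch of the same idea, with a dict standing in for the `groupBy('Id').agg(F.collect_set('Date'))` result and `dict.get` standing in for the left join. The names are hypothetical, and Id 3 is an invented row added only to show an unmatched Id.

```python
# (Date, Id) pairs; the Cond column is omitted. Id 3 is a hypothetical unmatched row.
df_1_rows = [("2022-01-08", 1), ("2022-01-15", 1), ("2022-01-19", 2), ("2022-01-10", 3)]
df_2_rows = [
    ("2022-01-09", 1), ("2022-01-14", 1), ("2022-01-16", 1), ("2022-01-17", 1),
    ("2022-01-20", 1), ("2022-01-09", 2), ("2022-01-10", 2), ("2022-01-13", 2),
    ("2022-01-14", 2), ("2022-01-18", 2),
]

# Aggregate first: one set of dates per Id (the collect_set analogue)
d2_by_id = {}
for date, id_ in df_2_rows:
    d2_by_id.setdefault(id_, set()).add(date)

# Left join, then filter, sort descending, and keep the first two elements
rows = []
for date, id_ in df_1_rows:
    d2 = d2_by_id.get(id_, set())  # no df_2 rows for this Id: use an empty set
    earlier = [d for d in d2 if d < date]
    rows.append((date, id_, sorted(earlier, reverse=True)[:2]))

for row in rows:
    print(row)
```

One difference from the Spark version: there, d2 is null for an Id with no df_2 rows after the left join, so that row's List comes back as null rather than the empty list this sketch produces.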

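For older Spark versions (2.4 and 3.0), where F.filter is not available in the Python API, call the SQL higher-order function filter through F.expr instead: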

from pyspark.sql import functions as F

df = df_1.join(df_2.groupBy('Id').agg(F.collect_set('Date').alias('d2')), 'Id', 'left')
df = df.select(
    'Date', 'Id',
    F.slice(F.sort_array(F.expr("filter(d2, x -> x < Date)"), False), 1, 2).alias('List')
)
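Both answers compare Date values as strings, since the MVCE declares the column as StringType. That works because zero-padded yyyy-MM-dd strings sort lexicographically in the same order as the dates they represent. It would not hold for a day-first format such as dd/MM/yyyy, in which case converting the column with F.to_date first would be the safer choice. A quick plain-Python check of that claim, using a few sample dates chosen for illustration:

```python
from datetime import date
from itertools import combinations

iso = ["2022-01-08", "2022-01-09", "2022-01-20", "2021-12-31", "2022-02-01"]

# For every pair, string order must agree with calendar order
for a, b in combinations(iso, 2):
    assert (a < b) == (date.fromisoformat(a) < date.fromisoformat(b))

# A day-first format breaks this: as strings, 09/01/2022 sorts before 20/12/2021,
# even though 9 Jan 2022 is the later date
print("09/01/2022" < "20/12/2021")  # True
```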