Get first date of occurrence in pyspark


Take the data below as an example:

df = spark.createDataFrame(
    [
        ('james', '2019', 0),
        ('james', '2020', 1),
        ('james', '2021', 1),
        ('nik', '2019', 0),
        ('nik', '2020', 1),
    ],
    ['name', 'year', 'flag'])
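
(If you run this outside the pyspark shell or a notebook, where spark is not already defined, you first need a SparkSession; a minimal sketch, with an illustrative app name:)

# Only needed when `spark` is not already provided by the shell/notebook;
# the app name is illustrative
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('first_flag_year').getOrCreate()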

I want to add one more column, "first_year_of_flag", which is the year in which flag became 1 for the very first time, using PySpark.

So the output should look like this:

name   year  flag  first_year_of_flag
james  2019  0     2020
james  2020  1     2020
james  2021  1     2020
nik    2019  0     2020
nik    2020  1     2020

CodePudding user response:

I think a window function plus a join is suitable for your case; you can first create a reference table:

# You can use group by -> first / min as well
from pyspark.sql import functions as func
from pyspark.sql.window import Window

ref_df = df\
    .filter(func.col('flag') == 1)\
    .select(
        'name',
        func.first('year').over(Window.partitionBy('name').orderBy(func.asc('year'))).alias('first_year')
    ).distinct()

ref_df.show(10, False)
+-----+----------+
|name |first_year|
+-----+----------+
|james|2020      |
|nik  |2020      |
+-----+----------+
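
As the comment in the snippet above suggests, the same reference table can also be built with a plain groupBy instead of a window function; a minimal sketch, assuming the same df and func alias:

# Alternative sketch: groupBy + min instead of a window function
ref_df = df \
    .filter(func.col('flag') == 1) \
    .groupBy('name') \
    .agg(func.min('year').alias('first_year'))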

Then just do a left join with broadcasting:

df = df.join(func.broadcast(ref_df), on='name', how='left')
df.show(100, False)
+-----+----+----+----------+
|name |year|flag|first_year|
+-----+----+----+----------+
|james|2019|0   |2020      |
|james|2020|1   |2020      |
|james|2021|1   |2020      |
|nik  |2019|0   |2020      |
|nik  |2020|1   |2020      |
+-----+----+----+----------+

CodePudding user response:

A when() (no otherwise() needed) inside a min window would work.

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# data_sdf is the input dataframe from the question
data_sdf. \
    withColumn('first_year_of_flag',
               func.min(func.when(func.col('flag') == 1, func.col('year'))).
               over(wd.partitionBy('name'))
               ). \
    show()

# +-----+----+----+------------------+
# | name|year|flag|first_year_of_flag|
# +-----+----+----+------------------+
# |  nik|2019|   0|              2020|
# |  nik|2020|   1|              2020|
# |james|2019|   0|              2020|
# |james|2020|   1|              2020|
# |james|2021|   1|              2020|
# +-----+----+----+------------------+

Essentially, this takes the minimum of the years for which the flag equals 1 within each name partition; for all other rows the when() yields null, which min() ignores.
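
For reference, the same conditional minimum can be expressed in Spark SQL; a sketch that assumes the input dataframe is registered as a temp view (the view name data is illustrative):

# Spark SQL sketch of the same conditional minimum;
# the view name `data` is illustrative
data_sdf.createOrReplaceTempView('data')
spark.sql("""
    SELECT name, year, flag,
           MIN(CASE WHEN flag = 1 THEN year END)
               OVER (PARTITION BY name) AS first_year_of_flag
    FROM data
""").show()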

CodePudding user response:

You can use window functions to compute a running sum of the flag value, then pick out the year where that running sum first reaches 1, partitioning by name and ordering by year.

Example:

from pyspark.sql import Window
from pyspark.sql.functions import *

df = spark.createDataFrame(
    [
        ('james', '2019', 0),
        ('james', '2020', 1),
        ('james', '2021', 1),
        ('nik', '2019', 0),
        ('nik', '2020', 1),
    ],
    ['name', 'year', 'flag'])

# running window: ordered by year within each name
aggregation_window = Window.partitionBy(col("name")).orderBy(col("year"))

# full-partition window, used to spread the first year onto every row
aggregation_max = Window.partitionBy(col("name")).orderBy(col("year")) \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df = df \
    .withColumn('aggregation_sum', sum('flag').over(aggregation_window)) \
    .withColumn("first_flag", when(col("aggregation_sum") == 1, lit(1))) \
    .withColumn("first_flag", expr("if(first_flag = 1, year, 0)")) \
    .withColumn("first_flag", max('first_flag').over(aggregation_max)) \
    .drop("aggregation_sum")
df.show()

+-----+----+----+----------+
| name|year|flag|first_flag|
+-----+----+----+----------+
|james|2019|   0|      2020|
|james|2020|   1|      2020|
|james|2021|   1|      2020|
|  nik|2019|   0|      2020|
|  nik|2020|   1|      2020|
+-----+----+----+----------+