Take the data below as an example:
df = spark.createDataFrame(
    [
        ('james', '2019', 0),
        ('james', '2020', 1),
        ('james', '2021', 1),
        ('nik', '2019', 0),
        ('nik', '2020', 1),
    ],
    ['name', 'year', 'flag'])
Using PySpark, I want to add one more column, "first_year_of_flag", which holds the year when the flag became 1 for the very first time for each name.
So the output should look like this:
| name  | year | flag | first_year_of_flag |
|-------|------|------|--------------------|
| james | 2019 | 0    | 2020               |
| james | 2020 | 1    | 2020               |
| james | 2021 | 1    | 2020               |
| nik   | 2019 | 0    | 2020               |
| nik   | 2020 | 1    | 2020               |
CodePudding user response:
I think Window + join is suitable for your case. You can first create a reference table:
from pyspark.sql import functions as func
from pyspark.sql import Window

# You can use group by -> first / min as well
ref_df = df \
    .filter(func.col('flag') == 1) \
    .select(
        'name',
        func.first('year')
            .over(Window.partitionBy('name').orderBy(func.asc('year')))
            .alias('first_year')
    ).distinct()
ref_df.show(10, False)
+-----+----------+
|name |first_year|
+-----+----------+
|james|2020      |
|nik  |2020      |
+-----+----------+
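As the comment in the snippet notes, a plain group by gives the same reference table without a window; a minimal sketch of that alternative (min works here because the 4-digit year strings sort chronologically):
# Same reference table via groupBy; min('year') picks the earliest flagged year
ref_df = df \
    .filter(func.col('flag') == 1) \
    .groupBy('name') \
    .agg(func.min('year').alias('first_year'))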
Then just do a left join with broadcasting:
df = df.join(func.broadcast(ref_df), on='name', how='left')
df.show(100, False)
+-----+----+----+----------+
|name |year|flag|first_year|
+-----+----+----+----------+
|james|2019|0   |2020      |
|james|2020|1   |2020      |
|james|2021|1   |2020      |
|nik  |2019|0   |2020      |
|nik  |2020|1   |2020      |
+-----+----+----+----------+
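Broadcasting is a reasonable choice here because ref_df has at most one row per name, so it stays small. If you want to confirm Spark actually chose a broadcast hash join, inspect the physical plan:
# The physical plan should contain a BroadcastHashJoin node
df.explain()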
CodePudding user response:
A when().otherwise() within a min window would work (the otherwise() can even be omitted, since the resulting nulls are ignored by min).
from pyspark.sql import functions as func
from pyspark.sql import Window as wd

# data_sdf is the input dataframe from the question
data_sdf. \
    withColumn('first_year_of_flag',
               func.min(func.when(func.col('flag') == 1, func.col('year')))
                   .over(wd.partitionBy('name'))). \
    show()
# +-----+----+----+------------------+
# | name|year|flag|first_year_of_flag|
# +-----+----+----+------------------+
# |  nik|2019|   0|              2020|
# |  nik|2020|   1|              2020|
# |james|2019|   0|              2020|
# |james|2020|   1|              2020|
# |james|2021|   1|              2020|
# +-----+----+----+------------------+
Essentially, this takes the minimum of the years for which the flag equals 1 within each name partition; rows with flag 0 map to null, which min skips.
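If you prefer the null handling to be explicit, the same expression can be written with the otherwise() branch spelled out; a minimal equivalent sketch, reusing the func and wd aliases from above:
# Non-flagged rows explicitly become null; min skips nulls either way
data_sdf.withColumn(
    'first_year_of_flag',
    func.min(
        func.when(func.col('flag') == 1, func.col('year')).otherwise(func.lit(None))
    ).over(wd.partitionBy('name'))
).show()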
CodePudding user response:
You can use window functions to compute a running sum of the flag value, partitioned by name and ordered by year, and then pick out the row where that running sum first becomes 1.
Example:
from pyspark.sql import Window
from pyspark.sql.functions import *

df = spark.createDataFrame(
    [
        ('james', '2019', 0),
        ('james', '2020', 1),
        ('james', '2021', 1),
        ('nik', '2019', 0),
        ('nik', '2020', 1),
    ],
    ['name', 'year', 'flag'])

aggregation_window = Window.partitionBy(col("name")).orderBy(col("year"))
aggregation_max = Window.partitionBy(col("name")) \
    .orderBy(col("year")) \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Running sum of the flag; the first flagged row is where flag == 1 and the
# running sum == 1 (checking flag too guards against later rows whose sum is also 1)
df = df.withColumn('aggregation_sum', sum('flag').over(aggregation_window)). \
    withColumn("first_flag",
               when((col("flag") == 1) & (col("aggregation_sum") == 1), lit(1))). \
    withColumn("first_flag", expr("if(first_flag = 1, year, 0)")). \
    withColumn("first_flag", max('first_flag').over(aggregation_max)). \
    drop("aggregation_sum")
df.show()
+-----+----+----+----------+
| name|year|flag|first_flag|
+-----+----+----+----------+
|james|2019|   0|      2020|
|james|2020|   1|      2020|
|james|2021|   1|      2020|
|  nik|2019|   0|      2020|
|  nik|2020|   1|      2020|
+-----+----+----+----------+
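The three first_flag steps can also be collapsed into two; a minimal sketch of the same running-sum idea, assuming the imports and window definitions above (names that never get a flag would end up null rather than 0 here):
# One pass for the running sum, one for spreading the first flagged year
df = df \
    .withColumn('run_sum', sum('flag').over(aggregation_window)) \
    .withColumn('first_flag',
                max(when((col('flag') == 1) & (col('run_sum') == 1),
                         col('year'))).over(aggregation_max)) \
    .drop('run_sum')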