I have a dataset such as:
date | is_business_day |
---|---|
2023-01-01 | 0 |
2023-01-02 | 1 |
2023-01-03 | 1 |
2023-01-04 | 1 |
2023-01-05 | 1 |
2023-01-06 | 1 |
2023-01-07 | 0 |
2023-01-08 | 0 |
2023-01-09 | 1 |
2023-04-06 | 1 |
2023-04-07 | 0 |
2023-04-08 | 0 |
2023-04-09 | 0 |
2023-04-10 | 1 |
I would like to get the next value from the `date` column where the condition `is_business_day == 1` is met.
The desired output would be something like:
date | is_business_day | next_business_day |
---|---|---|
2023-01-01 | 0 | 2023-01-02 |
2023-01-02 | 1 | 2023-01-03 |
2023-01-03 | 1 | 2023-01-04 |
2023-01-04 | 1 | 2023-01-05 |
2023-01-05 | 1 | 2023-01-06 |
2023-01-06 | 1 | 2023-01-09 |
2023-01-07 | 0 | 2023-01-09 |
2023-01-08 | 0 | 2023-01-09 |
2023-01-09 | 1 | 2023-01-10 |
2023-01-10 | 1 | 2023-01-14 |
2023-01-11 | 0 | 2023-01-14 |
2023-01-12 | 0 | 2023-01-14 |
2023-01-13 | 0 | 2023-01-14 |
2023-01-14 | 1 | ... |
I have created a function like the one below:

```python
from datetime import datetime, timedelta

from pyspark.sql.functions import col


def next_business_day(df_calendar, date):
    date_f = datetime.strptime(date, '%Y-%m-%d')
    next_day = (date_f + timedelta(days=1)).strftime('%Y-%m-%d')
    # Filter the DataFrame to only the dates AFTER the one being checked.
    df_calendar_next_days = df_calendar.filter(col('date') >= next_day)
    # Collect the rows into a local list so we can iterate over them,
    # returning the first date where is_business_day == 1.
    for row in df_calendar_next_days.collect():
        if row['is_business_day'] == 1:
            return row['date']
```
The function works, but I can't use it with `.withColumn()` because I can't pass the DataFrame as a parameter. If I try code like this:
```python
df_calendar = (
    df_calendar
    .withColumn('next_business_day', next_business_day(df_calendar, col('date')))
)
```
I receive the error:
```
TypeError: Invalid argument, not a string or column: DataFrame[date: date] of type <class 'pyspark.sql.dataframe.DataFrame'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.
```
CodePudding user response:
I didn't debug what is wrong with the current code, but what you want to achieve can be done with PySpark's built-in functions. With a conditional `F.min` over a window, it looks for the minimum date where `is_business_day == 1`, considering only the rows after the current one (current row + 1).
```python
from pyspark.sql import Window
from pyspark.sql import functions as F

# Window over all rows strictly after the current one, ordered by date.
w = Window.orderBy('date').rowsBetween(Window.currentRow + 1, Window.unboundedFollowing)

# Earliest following date where is_business_day == 1.
df = df.withColumn('next_business_day', F.min(F.when(F.col('is_business_day') == 1, F.col('date'))).over(w))
```
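For reference, here is a minimal, self-contained sketch of that answer applied to the first nine rows of the sample data (the `spark` session setup and the string-to-date cast are assumptions, not part of the original question):

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# Rebuild the first nine rows of the sample calendar from the question.
df = spark.createDataFrame(
    [('2023-01-01', 0), ('2023-01-02', 1), ('2023-01-03', 1),
     ('2023-01-04', 1), ('2023-01-05', 1), ('2023-01-06', 1),
     ('2023-01-07', 0), ('2023-01-08', 0), ('2023-01-09', 1)],
    ['date', 'is_business_day'],
).withColumn('date', F.to_date('date'))

# Only rows strictly after the current one, ordered by date.
w = Window.orderBy('date').rowsBetween(Window.currentRow + 1, Window.unboundedFollowing)

df = df.withColumn(
    'next_business_day',
    F.min(F.when(F.col('is_business_day') == 1, F.col('date'))).over(w),
)
df.show()
```

Note that the last row gets `null` for `next_business_day`, since no later business day exists in the data. Because the window is ordered by date, `F.first(..., ignorenulls=True)` over the same frame would produce the same result.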