I have this dataframe:
+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
+---+----------+------+
I need to create new records for each id, one per date, up to current_date() - 2, and the amount populated on the new rows must be the most recent one known for that id.
For example, if date_sub(current_date(), 2) == "2022-11-16", then I need the following dataframe:
+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|123|2022-11-15|200.00|
|123|2022-11-16|200.00|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
|789|2022-11-13|500.00|
|789|2022-11-14|500.00|
|789|2022-11-15|500.00|
|789|2022-11-16|500.00|
+---+----------+------+
import findspark
findspark.init()

from datetime import datetime
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, DecimalType

spark = SparkSession.builder.master("local[4]").appName("Complete Rows").getOrCreate()

# Sample data: per-id daily amounts, with some ids missing the most recent days
vdata = [
    (123, datetime.strptime('2022-11-11', '%Y-%m-%d'), Decimal(100)),
    (123, datetime.strptime('2022-11-12', '%Y-%m-%d'), Decimal(100)),
    (123, datetime.strptime('2022-11-13', '%Y-%m-%d'), Decimal(100)),
    (123, datetime.strptime('2022-11-14', '%Y-%m-%d'), Decimal(200)),
    (456, datetime.strptime('2022-11-14', '%Y-%m-%d'), Decimal(300)),
    (456, datetime.strptime('2022-11-15', '%Y-%m-%d'), Decimal(300)),
    (456, datetime.strptime('2022-11-16', '%Y-%m-%d'), Decimal(300)),
    (789, datetime.strptime('2022-11-11', '%Y-%m-%d'), Decimal(400)),
    (789, datetime.strptime('2022-11-12', '%Y-%m-%d'), Decimal(500))]

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("date", DateType(), False),
    StructField("amount", DecimalType(10, 2), False)])

df = spark.createDataFrame(vdata, schema)
df.show()
I tried to identify the maximum date for each id, take the last amount at that maximum date, build a list of dates with F.expr('sequence(...)') and then explode it to create the missing rows, but it isn't working very well. Thanks for any help you can give!
CodePudding user response:
I managed to find the following solution.
For clarity I divided it into three steps; of course you can write fewer lines of code by chaining them, as sketched at the end.
1) Lookup
Create a lookup table with all the necessary dates (both present and missing) for each id.
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# For each id: one row per day, from its earliest date up to current_date() - 2
lookup = (df
    .groupby('id')
    .agg(
        F.min('date').alias('start_date'),
        F.date_sub(F.current_date(), 2).alias('end_date')
    )
    .select('id', F.explode(F.expr('sequence(start_date, end_date, interval 1 day)')).alias('date'))
)
lookup.show()
+---+----------+
| id|      date|
+---+----------+
|123|2022-11-11|
|123|2022-11-12|
|123|2022-11-13|
|123|2022-11-14|
|123|2022-11-15|
|123|2022-11-16|
|456|2022-11-14|
|456|2022-11-15|
|456|2022-11-16|
|789|2022-11-11|
|789|2022-11-12|
|789|2022-11-13|
|789|2022-11-14|
|789|2022-11-15|
|789|2022-11-16|
+---+----------+
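As a side note, sequence with an interval 1 day step includes both endpoints, so every id gets exactly one row per day between its first date and current_date() - 2. A quick standalone check of that behaviour (reusing the same Spark session):
# sequence() with a 1-day interval is inclusive of both bounds
spark.sql(
    "SELECT sequence(to_date('2022-11-14'), to_date('2022-11-16'), interval 1 day) AS dates"
).show(truncate=False)
# dates: [2022-11-14, 2022-11-15, 2022-11-16]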
2) Join
Afterwards, we join the lookup table with our original dataframe: this adds the missing rows, with amount set to null.
df = df.join(lookup, on=['id', 'date'], how='outer')
df.show()
+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|123|2022-11-15|  null|
|123|2022-11-16|  null|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
|789|2022-11-13|  null|
|789|2022-11-14|  null|
|789|2022-11-15|  null|
|789|2022-11-16|  null|
+---+----------+------+
3) last function
We use the last function with ignorenulls=True to retrieve the last non-null value within a window partitioned by id and ordered by date.
# Running window per id: every row from the first date up to the current row
w = Window.partitionBy('id').orderBy('date').rowsBetween(Window.unboundedPreceding, 0)
df = df.withColumn('amount', F.last('amount', ignorenulls=True).over(w))
df.show()
+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|123|2022-11-15|200.00|
|123|2022-11-16|200.00|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
|789|2022-11-13|500.00|
|789|2022-11-14|500.00|
|789|2022-11-15|500.00|
|789|2022-11-16|500.00|
+---+----------+------+
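For reference, here is the more compact form mentioned above: the same three steps (lookup, outer join, forward fill with last) chained into a single expression. This is just a sketch and assumes df still refers to the original dataframe (re-create it if you ran the steps above, since they overwrite df):
w = Window.partitionBy('id').orderBy('date').rowsBetween(Window.unboundedPreceding, 0)

result = (df
    .groupby('id')
    .agg(
        F.min('date').alias('start_date'),
        F.date_sub(F.current_date(), 2).alias('end_date')
    )
    # one row per id per day between its first date and current_date() - 2
    .select('id', F.explode(F.expr('sequence(start_date, end_date, interval 1 day)')).alias('date'))
    # bring back the known amounts; days without data get null
    .join(df, on=['id', 'date'], how='outer')
    # forward-fill the last non-null amount within each id
    .withColumn('amount', F.last('amount', ignorenulls=True).over(w))
)
result.show()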