I am trying to replicate pandas' merge_asof in Spark.
Here is a small example:
from datetime import datetime

# `spark` is assumed to be an existing SparkSession.
df1 = spark.createDataFrame(
    [
        (datetime(2019, 2, 3, 13, 30, 0, 23), "GOOG", 720.5, 720.93),
        (datetime(2019, 2, 3, 13, 30, 0, 23), "MSFT", 51.95, 51.96),
        (datetime(2019, 2, 3, 13, 30, 0, 20), "MSFT", 51.97, 51.98),
        (datetime(2019, 2, 3, 13, 30, 0, 41), "MSFT", 51.99, 52.0),
        (datetime(2019, 2, 3, 13, 30, 0, 48), "GOOG", 720.5, 720.93),
        (datetime(2019, 2, 3, 13, 30, 0, 49), "AAPL", 97.99, 98.01),
        (datetime(2019, 2, 3, 13, 30, 0, 72), "GOOG", 720.5, 720.88),
        (datetime(2019, 2, 3, 13, 30, 0, 75), "MSFT", 52.1, 52.03),
    ],
    ("time", "ticker", "bid", "ask"),
)
df2 = spark.createDataFrame(
    [
        (datetime(2019, 2, 3, 13, 30, 0, 23), "MSFT", 51.95, 75),
        (datetime(2019, 2, 3, 13, 30, 0, 38), "MSFT", 51.95, 155),
        (datetime(2019, 2, 3, 13, 30, 0, 48), "GOOG", 720.77, 100),
        (datetime(2019, 2, 3, 13, 30, 0, 48), "GOOG", 720.92, 100),
        (datetime(2019, 2, 3, 13, 30, 0, 48), "AAPL", 98.0, 100),
    ],
    ("time", "ticker", "price", "quantity"),
)
The pandas version:

import pandas as pd

d1 = df1.toPandas().sort_values("time", ascending=True)
d2 = df2.toPandas().sort_values("time", ascending=True)
pd.merge_asof(d2, d1, on='time', by='ticker')
Output:

                        time ticker   price  quantity     bid     ask
0 2019-02-03 13:30:00.000023   MSFT   51.95        75   51.95   51.96
1 2019-02-03 13:30:00.000038   MSFT   51.95       155   51.95   51.96
2 2019-02-03 13:30:00.000048   GOOG  720.77       100  720.50  720.93
3 2019-02-03 13:30:00.000048   GOOG  720.92       100  720.50  720.93
4 2019-02-03 13:30:00.000048   AAPL   98.00       100     NaN     NaN
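For reference, merge_asof does a backward search by default: each row of d2 takes the most recent d1 row with the same ticker and an earlier-or-equal time, which is why the AAPL trade at ...000048 gets NaN (its only quote arrives at ...000049). A minimal sketch with an explicit tolerance, in case quotes older than some cutoff should be treated as missing (the 10 ms value is purely illustrative, not part of the original example):

pd.merge_asof(
    d2, d1,
    on="time", by="ticker",
    direction="backward",            # default: most recent earlier-or-equal quote
    tolerance=pd.Timedelta("10ms"),  # illustrative cutoff; older quotes become NaN
)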
Using a UDF in Spark:
def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="ticker")

df2.sort("time").groupby("ticker").cogroup(
    df1.sort("time").groupby("ticker")
).applyInPandas(
    asof_join,
    schema="time timestamp, ticker string, price float, quantity int, bid float, ask float",
).show(10, False)
Output:

+--------------------------+------+------+--------+-----+------+
|time                      |ticker|price |quantity|bid  |ask   |
+--------------------------+------+------+--------+-----+------+
|2019-02-03 13:30:00.000048|AAPL  |98.0  |100     |null |null  |
|2019-02-03 13:30:00.000048|GOOG  |720.77|100     |720.5|720.93|
|2019-02-03 13:30:00.000048|GOOG  |720.92|100     |720.5|720.93|
|2019-02-03 13:30:00.000023|MSFT  |51.95 |75      |51.95|51.96 |
|2019-02-03 13:30:00.000038|MSFT  |51.95 |155     |51.95|51.96 |
+--------------------------+------+------+--------+-----+------+
NOTE: The UDF works and gives me the right results, but I wanted to know whether there is a more efficient way to do this in PySpark using window functions. I am processing large data and the UDF is the bottleneck.
Answer:
You can do it by first joining and then forward-filling with last over a window:
from pyspark.sql import functions as F, Window as W

# Exact-match join on (time, ticker), then carry the most recent
# non-null bid/ask forward within each ticker, ordered by time.
df = df2.join(df1, ['time', 'ticker'], 'left')
w = W.partitionBy('ticker').orderBy('time')
df = df.withColumn('bid', F.coalesce('bid', F.last('bid', True).over(w)))
df = df.withColumn('ask', F.coalesce('ask', F.last('ask', True).over(w)))
df.show(truncate=0)
# +--------------------------+------+------+--------+-----+------+
# |time                      |ticker|price |quantity|bid  |ask   |
# +--------------------------+------+------+--------+-----+------+
# |2019-02-03 13:30:00.000048|AAPL  |98.0  |100     |null |null  |
# |2019-02-03 13:30:00.000048|GOOG  |720.77|100     |720.5|720.93|
# |2019-02-03 13:30:00.000048|GOOG  |720.92|100     |720.5|720.93|
# |2019-02-03 13:30:00.000023|MSFT  |51.95 |75      |51.95|51.96 |
# |2019-02-03 13:30:00.000038|MSFT  |51.95 |155     |51.95|51.96 |
# +--------------------------+------+------+--------+-----+------+
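One caveat, assuming the goal is full merge_asof semantics: the left join only brings in quotes whose timestamps match a trade exactly, so a quote that falls strictly between two trade times is never seen by last. If that matters for your data, a window-only variant is to union quotes and trades, forward-fill, and keep the trade rows. A minimal sketch (the is_trade flag and the column casts are illustrative, not part of the original answer):

from pyspark.sql import functions as F, Window as W

# Hypothetical is_trade flag: 0 = quote row (df1), 1 = trade row (df2).
quotes = df1.select(
    'time', 'ticker', 'bid', 'ask',
    F.lit(None).cast('double').alias('price'),
    F.lit(None).cast('long').alias('quantity'),
    F.lit(0).alias('is_trade'),
)
trades = df2.select(
    'time', 'ticker',
    F.lit(None).cast('double').alias('bid'),
    F.lit(None).cast('double').alias('ask'),
    'price', 'quantity',
    F.lit(1).alias('is_trade'),
)

# Quotes sort before trades at equal timestamps, so an exact-time quote is used.
w = (W.partitionBy('ticker')
      .orderBy('time', 'is_trade')
      .rowsBetween(W.unboundedPreceding, W.currentRow))

result = (quotes.unionByName(trades)
          .withColumn('bid', F.last('bid', ignorenulls=True).over(w))
          .withColumn('ask', F.last('ask', ignorenulls=True).over(w))
          .where('is_trade = 1')
          .drop('is_trade'))
result.show(truncate=False)

Ordering by (time, is_trade) lets a quote with exactly the same timestamp as a trade fill that trade, mirroring merge_asof's default allow_exact_matches=True.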