I'm trying to figure out an efficient way to merge two PySpark DataFrames like this:
from pyspark.sql import Row, SparkSession
spark = SparkSession.builder.getOrCreate()  # needed when running outside the pyspark shell
data = [Row(id=index, value=val, calc=val*2) for index, val in enumerate(range(10))]
df = spark.createDataFrame(data=data, schema=["id", "value", "calc"])
data2 = [Row(id=index, value=val, calc=val**2) for index, val in [(9, 9), (10, 10)]]
df2 = spark.createDataFrame(data=data2, schema=["id", "value", "calc"])
df.head(10)
# Outputs: [Row(id=0, value=0, calc=0), Row(id=1, value=1, calc=2), Row(id=2, value=2, calc=4), Row(id=3, value=3, calc=6), Row(id=4, value=4, calc=8), Row(id=5, value=5, calc=10), Row(id=6, value=6, calc=12), Row(id=7, value=7, calc=14), Row(id=8, value=8, calc=16), Row(id=9, value=9, calc=18)]
df2.head(2)
# Outputs: [Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]
df3 = SOME_MERGE_FUNCTION(df, df2)
df3.head(20)
# Outputs: [Row(id=0, value=0, calc=0), Row(id=1, value=1, calc=2), Row(id=2, value=2, calc=4), Row(id=3, value=3, calc=6), Row(id=4, value=4, calc=8), Row(id=5, value=5, calc=10), Row(id=6, value=6, calc=12), Row(id=7, value=7, calc=14), Row(id=8, value=8, calc=16), Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]
So how do I write SOME_MERGE_FUNCTION()?
CodePudding user response:
You can anti join df and df2, and then union the result with df2. That way you get everything from df2 and only the rows from df which are not in df2.
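# Anti join keeps only the rows of df whose id does not appear in df2; unionAll then appends all of df2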
df3 = df.join(df2, on=['id'], how='anti').unionAll(df2)
print(df3.head(20))
[Row(id=0, value=0, calc=0), Row(id=7, value=7, calc=14), Row(id=6, value=6, calc=12), Row(id=5, value=5, calc=10), Row(id=1, value=1, calc=2), Row(id=3, value=3, calc=6), Row(id=8, value=8, calc=16), Row(id=2, value=2, calc=4), Row(id=4, value=4, calc=8), Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]
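Note that the ids come back shuffled: Spark gives no ordering guarantee after a join and a union. If you want the output sorted by id, as in the question's expected result, add an explicit orderBy before collecting. A minimal sketch, reusing the df3 built above:
# Row order is not deterministic after shuffles; sort explicitly when it matters
df3_sorted = df3.orderBy('id')
print(df3_sorted.head(20))
# Rows now come back in id order, from id=0 through id=10
Also, unionAll is just a deprecated alias for union since Spark 2.0, and both match columns by position; if the two DataFrames might list their columns in a different order, unionByName is the safer choice.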