PySpark - merge two DataFrames, overwriting one with the other


I'm trying to figure out an efficient way to merge two PySpark DataFrames like this:

from pyspark.sql import Row
data = [Row(id=index, value=val, calc=val*2) for index, val in enumerate(range(10))]
df = spark.createDataFrame(data=data, schema=["id", "value", "calc"])
data2 = [Row(id=index, value=val, calc=val**2) for index, val in [(9, 9), (10, 10)]]
df2 = spark.createDataFrame(data=data2, schema=["id", "value", "calc"])


df.head(10)
# Outputs: [Row(id=0, value=0, calc=0), Row(id=1, value=1, calc=2), Row(id=2, value=2, calc=4), Row(id=3, value=3, calc=6), Row(id=4, value=4, calc=8), Row(id=5, value=5, calc=10), Row(id=6, value=6, calc=12), Row(id=7, value=7, calc=14), Row(id=8, value=8, calc=16), Row(id=9, value=9, calc=18)]

df2.head(2)
# Outputs: [Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]

df3 = SOME_MERGE_FUNCTION(df, df2)
df3.head(20)

# Outputs: [Row(id=0, value=0, calc=0), Row(id=1, value=1, calc=2), Row(id=2, value=2, calc=4), Row(id=3, value=3, calc=6), Row(id=4, value=4, calc=8), Row(id=5, value=5, calc=10), Row(id=6, value=6, calc=12), Row(id=7, value=7, calc=14), Row(id=8, value=8, calc=16), Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]

So how do I write SOME_MERGE_FUNCTION()?

CodePudding user response:

You can anti-join df with df2 and union the result with df2. The anti join (how='anti' is shorthand for 'left_anti') keeps only the rows of df whose id does not appear in df2, so the union gives you everything from df2 plus the non-overlapping rows from df.

df3 = df.join(df2, on=['id'], how='anti').union(df2)
print(df3.head(20))

[Row(id=0, value=0, calc=0), Row(id=7, value=7, calc=14), Row(id=6, value=6, calc=12), Row(id=5, value=5, calc=10), Row(id=1, value=1, calc=2), Row(id=3, value=3, calc=6), Row(id=8, value=8, calc=16), Row(id=2, value=2, calc=4), Row(id=4, value=4, calc=8), Row(id=9, value=9, calc=81), Row(id=10, value=10, calc=100)]
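To see what this merge does without spinning up a Spark session, here is a minimal plain-Python sketch of the same "second table wins" semantics (merge_overwrite is a hypothetical helper name, not a Spark API):

```python
# Sketch of the merge semantics: rows from the second table overwrite
# rows from the first table that share the same id.
def merge_overwrite(rows1, rows2, key="id"):
    merged = {r[key]: r for r in rows1}        # start with the first table
    merged.update({r[key]: r for r in rows2})  # second table wins on id conflicts
    return sorted(merged.values(), key=lambda r: r[key])

rows1 = [{"id": i, "value": i, "calc": i * 2} for i in range(10)]
rows2 = [{"id": 9, "value": 9, "calc": 81}, {"id": 10, "value": 10, "calc": 100}]

result = merge_overwrite(rows1, rows2)
# result has 11 rows; id=9 now carries calc=81, and id=10 is appended
```

Note that, unlike this dict-based sketch, the Spark version does not guarantee row order after the join and union; add an `orderBy('id')` if you need the output sorted.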