Pyspark removing duplicate columns after broadcast join


I have two DataFrames that I want to join and then save as a Parquet table. After the join, the resulting DataFrame has duplicate columns, which prevents me from saving the dataset.

Here is my code for the join:

from pyspark.sql import functions as F

join_conditions = [
    df1.colX == df2.colY,
    df1.col1 == df2.col1,
    df1.col2 == df2.col2,
    df1.col3 == df2.col3,
]

dfj = (
    df1.alias("1")
    .join(F.broadcast(df2.alias("2")), join_conditions, "inner")
    .drop("1.col1", "1.col2", "1.col3")
)

dfj.write.format("parquet").mode("overwrite").saveAsTable("table")

I expected the drop to remove the duplicate columns, but when I try to save the table an exception is thrown saying they are still there. drop() doesn't raise an exception if the named columns don't exist, which makes me suspect the alias is wrong or not working the way I expect.
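
Here is a sketch of what I think is happening (I haven't verified the internals): drop() with a string argument matches plain column names from the schema and does not resolve alias-qualified names, so "1.col1" matches nothing and the call is silently ignored.

# Sketch: drop() with strings compares against bare column names, so an
# alias-qualified name like "1.col1" matches nothing and is a silent no-op.
df1.alias("1").drop("1.col1").printSchema()   # col1 is still present
df1.alias("1").drop("col1").printSchema()     # col1 is actually dropped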

I cannot pass the join conditions as a list of strings, because mixing strings with Column expressions seems to cause an error when not all of the join columns have the same name on both DataFrames:

join_conditions = [
    df1.colX == df2.colY,
    "col1",
    "col2",
    "col3",
]

for example, does not work (a name-only join that does deduplicate the keys is sketched below, but it cannot express colX == colY).
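
For reference, a join on a plain list of column names does collapse the duplicates, but only when every key has the same name on both sides, so it doesn't apply directly here (sketch only):

# Sketch: when every join key shares a name on both DataFrames, passing the
# names as strings makes Spark keep a single copy of each key column.
deduped = df1.join(F.broadcast(df2), ["col1", "col2", "col3"], "inner")
deduped.printSchema()  # col1/col2/col3 appear once; colX and colY both remain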

This join runs, but it still results in the duplicate columns:

join_conditions = [
    df1.colX == df2.colY,
    F.col("1.col1") == F.col("2.col1"),
    F.col("1.col2") == F.col("2.col2"),
    F.col("1.col3") == F.col("2.col3"),
]

so that didn't solve it either. All of these approaches leave the joined DataFrame with the duplicate columns col1, col2 and col3. What am I doing wrong or not understanding correctly? Answers with PySpark sample code would be appreciated.

CodePudding user response:

I'm not sure why your version doesn't work; it's really weird. This isn't so pretty, but it works:


from pyspark.sql import functions as F

# Small reproducible example (assumes an active SparkSession named `spark`).
data = [{'colX': "hello", 'col1': 1, 'col2': 2, 'col3': 3}]
data2 = [{'colY': "hello", 'col1': 1, 'col2': 2, 'col3': 3}]
df1 = spark.createDataFrame(data)
df2 = spark.createDataFrame(data2)

join_cond = [
    df1.colX == df2.colY,
    df1.col1 == df2.col1,
    df1.col2 == df2.col2,
    df1.col3 == df2.col3,
]

# Drop by Column reference instead of by string: df1.col1 etc. resolve to the
# left-hand copies, so only df2's versions of col1/col2/col3 survive.
(
    df1.join(F.broadcast(df2), join_cond, 'inner')
    .drop(df1.col1)
    .drop(df1.col2)
    .drop(df1.col3)
    .printSchema()
)
root
 |-- colX: string (nullable = true)
 |-- col1: long (nullable = true)
 |-- col2: long (nullable = true)
 |-- col3: long (nullable = true)
 |-- colY: string (nullable = true)
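
If you'd rather avoid the drops altogether, another option (a sketch, assuming renaming colY is acceptable in your pipeline) is to align the key names first and then join on a list of name strings, which keeps a single copy of each key column:

# Sketch: rename df2's colY so all join keys share names, then join on the
# names themselves; Spark deduplicates the key columns automatically.
df2_renamed = df2.withColumnRenamed("colY", "colX")
df1.join(F.broadcast(df2_renamed), ["colX", "col1", "col2", "col3"], "inner").printSchema()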