Shortened example:
vals1 = [(1, "a"),
         (2, "b"),
         ]
columns1 = ["id", "name"]
df1 = spark.createDataFrame(data=vals1, schema=columns1)

vals2 = [(1, "k"),
         ]
columns2 = ["id", "name"]
df2 = spark.createDataFrame(data=vals2, schema=columns2)

df1 = df1.alias('df1').join(df2.alias('df2'), 'id', 'full')
df1.show()
The result has one column named id and two columns named name. How do I rename the columns with duplicate names, assuming that the real dataframes have tens of such columns?
CodePudding user response:
You can rename the duplicate columns before the join, leaving out the columns required for the join:
import pyspark.sql.functions as F

def add_prefix(df, prefix, columns=None):
    if not columns:
        columns = df.columns
    return df.select(*[F.col(c).alias(prefix + c if c in columns else c) for c in df.columns])

def add_suffix(df, suffix, columns=None):
    if not columns:
        columns = df.columns
    return df.select(*[F.col(c).alias(c + suffix if c in columns else c) for c in df.columns])
join_cols = ['id']
columns_to_rename = [c for c in df1.columns if c in df2.columns and c not in join_cols]
df2 = add_suffix(df2, '_y', columns_to_rename)
df3 = df1.join(df2, join_cols, 'full')
+---+----+------+
| id|name|name_y|
+---+----+------+
|  1|   a|     k|
|  2|   b|  null|
+---+----+------+
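This scales to tens of duplicate columns because the rename mapping is computed from the column lists, not written by hand. As a sketch, here is the same mapping that add_suffix applies, computed with plain Python lists so it can be checked without a Spark session (the column names below are illustrative, not from the question):

```python
def suffix_mapping(left_cols, right_cols, join_cols, suffix="_y"):
    """Return {old_name: new_name} for right-side columns that collide
    with a left-side column and are not used as join keys."""
    to_rename = [c for c in left_cols if c in right_cols and c not in join_cols]
    return {c: c + suffix for c in to_rename}

# Hypothetical column lists with several duplicates:
mapping = suffix_mapping(["id", "name", "age"], ["id", "name", "age"], ["id"])
print(mapping)  # {'name': 'name_y', 'age': 'age_y'}
```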
CodePudding user response:
@quaziqarta proposed a method to rename the columns before the join; note that you can also rename them after the join:
import pyspark.sql.functions as F

join_column = 'id'
df1.alias('df1').join(df2.alias('df2'), join_column, 'full') \
    .select(
        [join_column]
        + [F.col('df1.' + c).alias(c + "_1") for c in df1.columns if c != join_column]
        + [F.col('df2.' + c).alias(c + "_2") for c in df2.columns if c != join_column]
    ) \
    .show()
+---+------+------+
| id|name_1|name_2|
+---+------+------+
|  1|     a|     k|
|  2|     b|  null|
+---+------+------+
You only need to alias the dataframes (as you did in your example) to be able to specify which column you are referring to when you ask Spark for the column "name".
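The select list above is just ordinary list concatenation, so the resulting column names can be previewed with plain Python, without Spark. A small sketch (column lists are assumptions for illustration):

```python
def joined_columns(df1_cols, df2_cols, join_column):
    """Column names produced by the post-join select: the join key,
    then each side's remaining columns with _1 / _2 suffixes."""
    return ([join_column]
            + [c + "_1" for c in df1_cols if c != join_column]
            + [c + "_2" for c in df2_cols if c != join_column])

print(joined_columns(["id", "name"], ["id", "name"], "id"))
# ['id', 'name_1', 'name_2']
```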
CodePudding user response:
Another method is to rename only the intersecting columns:
from typing import List
from pyspark.sql import DataFrame

def join_intersect(df_left: DataFrame, df_right: DataFrame, join_cols: List[str], how: str = 'inner'):
    intersected_cols = set(df_left.columns).intersection(set(df_right.columns))
    cols_to_rename = [c for c in intersected_cols if c not in join_cols]
    for c in cols_to_rename:
        df_left = df_left.withColumnRenamed(c, f"{c}__1")
        df_right = df_right.withColumnRenamed(c, f"{c}__2")
    return df_left.join(df_right, on=join_cols, how=how)
vals1 = [(1, "a"), (2, "b")]
columns1 = ["id", "name"]
df1 = spark.createDataFrame(data=vals1, schema=columns1)
vals2 = [(1, "k")]
columns2 = ["id", "name"]
df2 = spark.createDataFrame(data=vals2, schema=columns2)
df_joined = join_intersect(df1, df2, ['name'])
df_joined.show()
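The rename decision in join_intersect depends only on the two column lists and the join keys, so it can be sanity-checked without Spark. A pure-Python sketch of that logic (the __1/__2 suffixes match the function above):

```python
def intersect_renames(left_cols, right_cols, join_cols):
    """Return {shared_column: (left_new_name, right_new_name)} for columns
    present on both sides that are not join keys."""
    intersected = set(left_cols).intersection(right_cols)
    to_rename = sorted(c for c in intersected if c not in join_cols)
    return {c: (f"{c}__1", f"{c}__2") for c in to_rename}

# With the example dataframes, joining on 'name' renames the shared 'id':
print(intersect_renames(["id", "name"], ["id", "name"], ["name"]))
# {'id': ('id__1', 'id__2')}
```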