I need to find all occurrences of duplicate records in a PySpark DataFrame. Here is a sample dataset:
# Prepare Data
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("A", "A", 1),
        ("A", "A", 2),
        ("A", "A", 3),
        ("A", "B", 4),
        ("A", "B", 5),
        ("A", "C", 6),
        ("A", "D", 7),
        ("A", "E", 8)]

# Create DataFrame
columns = ["col_1", "col_2", "col_3"]
df = spark.createDataFrame(data=data, schema=columns)
df.show(truncate=False)
When I try the following code:
primary_key = ['col_1', 'col_2']
duplicate_records = df.exceptAll(df.dropDuplicates(primary_key))
duplicate_records.show()
The output will be:
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    2|
|    A|    A|    3|
|    A|    B|    5|
+-----+-----+-----+
As you can see, I don't get all occurrences of the duplicate records for each primary key, because one row from each duplicate group is kept by "df.dropDuplicates(primary_key)" and is therefore removed by exceptAll. The 1st and the 4th records of the dataset should also appear in the output.
Any idea to solve this issue?
CodePudding user response:
The reason you can't see the 1st and the 4th records is that dropDuplicates keeps one row from each group of duplicates. See the code below:
primary_key = ['col_1', 'col_2']
df.dropDuplicates(primary_key).show()
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    1|
|    A|    B|    4|
|    A|    C|    6|
|    A|    D|    7|
|    A|    E|    8|
+-----+-----+-----+
For your task, you can extract the duplicated keys and join them back to your main DataFrame:
from pyspark.sql import functions as F

# Keys (col_1, col_2 combinations) that appear more than once
duplicated_keys = (
    df
    .groupBy(primary_key)
    .count()
    .filter(F.col('count') > 1)
    .drop('count')
)

# Keep every row whose primary key is in the duplicated set
(
    df
    .join(F.broadcast(duplicated_keys), primary_key)
).show()
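The `F.broadcast` hint is optional here; it just asks Spark to broadcast the (usually small) set of duplicated keys to the executors instead of shuffling the full DataFrame for the join. The result contains every occurrence of the duplicated keys: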
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    1|
|    A|    A|    2|
|    A|    A|    3|
|    A|    B|    4|
|    A|    B|    5|
+-----+-----+-----+
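If you'd rather skip the separate aggregation and join, a window function gives the same result in one pass: count the rows within each primary-key group and keep the groups with more than one row. This is just an alternative sketch (not part of the original answer), reusing the same `df` and `primary_key` as above:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Count rows within each primary-key group and keep groups with more than one row
key_window = Window.partitionBy(*primary_key)

duplicate_records = (
    df
    .withColumn('key_count', F.count('*').over(key_window))
    .filter(F.col('key_count') > 1)
    .drop('key_count')
)
duplicate_records.show()
This produces the same five rows as the join-based approach, without building the intermediate duplicated_keys DataFrame.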