How to remove duplicate records from PySpark DataFrame based on a condition?


Assume that I have a PySpark DataFrame like below:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Prepare Data
data = [('Italy', 'ITA'),
        ('China', 'CHN'),
        ('China', None),
        ('France', 'FRA'),
        ('Spain', None),
        ('Taiwan', 'TWN'),
        ('Taiwan', None)]

# Create DataFrame
columns = ['Name', 'Code']
df = spark.createDataFrame(data=data, schema=columns)
df.show(truncate=False)

+------+----+
|Name  |Code|
+------+----+
|Italy |ITA |
|China |CHN |
|China |null|
|France|FRA |
|Spain |null|
|Taiwan|TWN |
|Taiwan|null|
+------+----+

As you can see, a few countries are repeated twice (China & Taiwan in the above example). I want to delete records that satisfy the following conditions:

  1. The value in column 'Name' appears more than once

AND

  2. The column 'Code' is Null.

Note that column 'Code' can be Null for countries which are not repeated, like Spain. I want to keep those records.
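
To make the rule precise, this is roughly what I mean, written as a filter (a sketch only; I am not sure it is the cleanest way):

from pyspark.sql import Window
import pyspark.sql.functions as F

# Sketch of the rule: remove a row when its Name occurs more than once
# AND its Code is null (the window count is materialised with withColumn
# first, since window functions cannot be used directly inside filter).
w = Window.partitionBy('Name')
df_flagged = df.withColumn('name_count', F.count(F.lit(1)).over(w))
df_clean = df_flagged.filter(~((F.col('name_count') > 1) & F.col('Code').isNull())).drop('name_count')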

The expected output will be like:

Name     Code
Italy    ITA
China    CHN
France   FRA
Spain    null
Taiwan   TWN

In fact, I want to have one record for every country. Any idea how to do that?

CodePudding user response:

You can use Window.partitionBy to achieve your desired result:

from pyspark.sql import Window
import pyspark.sql.functions as f

# max() ignores nulls, so each Name keeps its non-null Code when one exists
df1 = df.select('Name', f.max('Code').over(Window.partitionBy('Name')).alias('Code')).distinct()
df1.show()

Output:

 ------ ---- 
|  Name|Code|
 ------ ---- 
| China| CHN|
| Spain|null|
|France| FRA|
|Taiwan| TWN|
| Italy| ITA|
 ------ ---- 
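
Because max() skips nulls, the same result can also be obtained with a plain groupBy aggregation instead of a window plus distinct (an alternative sketch, not part of the original answer):

import pyspark.sql.functions as f

# One row per Name; the non-null Code survives whenever one exists
df1 = df.groupBy('Name').agg(f.max('Code').alias('Code'))
df1.show()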

CodePudding user response:

Here is one approach:

# Note: unlike pandas, PySpark's dropDuplicates has no 'keep' argument;
# it keeps an arbitrary row per Name
df = df.dropDuplicates(subset=["Name"])
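
Since dropDuplicates keeps an arbitrary row, it is not guaranteed to keep the non-null Code. A hedged sketch of what the pandas-style keep='first' was presumably meant to achieve uses first() with ignorenulls:

import pyspark.sql.functions as F

# Keep one row per Name, preferring a non-null Code when one exists
df1 = df.groupBy('Name').agg(F.first('Code', ignorenulls=True).alias('Code'))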

CodePudding user response:

There will almost certainly be a cleverer way to do this, but for the sake of a lesson, what if you:

  1. made a new dataframe with just 'Name'
  2. dropped duplicates on that
  3. deleted records where Code is null from the initial table
  4. did a left join between the new table and the old table to bring back 'Code'

I've added Australia with no country code, just so you can see it works for that case as well.


import pandas as pd

data = [('Italy', 'ITA'),
    ('China', 'CHN'),
    ('China', None),
    ('France', 'FRA'),
    ('Spain', None),
    ('Taiwan', 'TWN'),
    ('Taiwan', None),
    ('Australia', None)
  ]

# Create DataFrame
columns = ['Name', 'Code']
df = pd.DataFrame(data = data, columns = columns)
print(df)

# get unique country names
uq_countries = df['Name'].drop_duplicates().to_frame()
print(uq_countries)

# remove None
non_na_codes = df.dropna()
print(non_na_codes)

# combine
final = pd.merge(left=uq_countries, right=non_na_codes, on='Name', how='left')
print(final)
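
Since the question is about PySpark, the same four steps can be translated back into a PySpark sketch (assuming the original df from the question; this translation is not part of the original answer):

from pyspark.sql.functions import col

# 1. & 2. unique country names
uq_countries = df.select('Name').dropDuplicates()

# 3. rows with a non-null Code
non_na_codes = df.filter(col('Code').isNotNull())

# 4. left join to bring Code back, keeping countries that have no code
final = uq_countries.join(non_na_codes, on='Name', how='left')
final.show()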

CodePudding user response:

To obtain the non-null rows first, use the row_number window function: partition by the Name column and sort by the Code column. Since null is considered the smallest value in Spark's ORDER BY, descending order is used so that the non-null Code comes first. Then take the first row of each group.

from pyspark.sql import functions as F

df = df.withColumn('rn', F.expr('row_number() over (partition by Name order by Code desc)')) \
    .filter('rn = 1').drop('rn')
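
For reference, the same idea written with the Window API instead of a SQL expression might look like the following sketch, where desc_nulls_last makes the null handling explicit:

from pyspark.sql import Window
from pyspark.sql import functions as F

# Put the non-null Code first within each Name group, then keep row 1
w = Window.partitionBy('Name').orderBy(F.col('Code').desc_nulls_last())
df1 = df.withColumn('rn', F.row_number().over(w)).filter('rn = 1').drop('rn')
df1.show()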