PySpark drop Duplicates and Keep Rows with highest value in a column

Time:12-28

I have the following Spark dataset:

id    col1    col2    col3    col4
1      1        5       2      3
1      1        0       2      3
2      3        1       7      7
3      6        1       3      3
3      6        5       3      3

I would like to drop the duplicates in the column subset ['id', 'col1', 'col3', 'col4'] and, for each set of duplicates, keep the row with the highest value in col2. The result should look like this:

id    col1    col2    col3    col4
1      1        5       2      3
2      3        1       7      7
3      6        5       3      3

How can I do that in PySpark?

Answer:

Group by the subset columns and take the maximum of col2:

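from pyspark.sql import functions as F
df = df.groupBy('id', 'col1', 'col3', 'col4').agg(F.max('col2').alias('col2'))  # alias keeps the name 'col2' instead of 'max(col2)'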

Answer:

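Another way: compute the maximum of col2 over a window partitioned by the subset columns, then keep only the rows where col2 equals that maximum. Unlike the group-by, this keeps every row that ties for the maximum, and it preserves any other columns the DataFrame may have.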

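from pyspark.sql import Window, functions as F  # F.max avoids shadowing Python's built-in max
df.withColumn('max_col2', F.max('col2').over(Window.partitionBy('id', 'col1', 'col3', 'col4'))).where(F.col('col2') == F.col('max_col2')).drop('max_col2').show()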

Answer:

If you are more comfortable with SQL syntax than with the PySpark DataFrame API, you can use this approach:

Create the DataFrame (optional, since you already have the data):

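from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

# Reuse the active SparkSession (or start one) so that `spark` is defined
spark = SparkSession.builder.getOrCreate()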

data = [
  (1,      1,        5,       2,      3),
  (1,      1,        0,       2,      3),
  (2,      3,        1,       7,      7),
  (3,      6,        1,       3,      3),
  (3,      6,        5,       3,      3),
]

schema = StructType([
    StructField("id", IntegerType()),
    StructField("col1", IntegerType()),
    StructField("col2", IntegerType()),
    StructField("col3", IntegerType()),
    StructField("col4", IntegerType()),
])

df = spark.createDataFrame(data=data,schema=schema)
df.show()

Next, register the DataFrame as a temporary view so you can run SQL queries against it. The following creates (or replaces) a temporary view named "tbl".

# create view from df called "tbl"
df.createOrReplaceTempView("tbl")

Finally, query the view. Here we group by id, col1, col3, and col4 and take the maximum value of col2 in each group.

# query to group by id,col1,col3,col4 and select max col2
my_query = """
select 
  id, col1, max(col2) as col2, col3, col4
from tbl
group by id, col1, col3, col4
"""

new_df = spark.sql(my_query)
new_df.show()

Final output:

+---+----+----+----+----+
| id|col1|col2|col3|col4|
+---+----+----+----+----+
|  1|   1|   5|   2|   3|
|  2|   3|   1|   7|   7|
|  3|   6|   5|   3|   3|
+---+----+----+----+----+
