Grouping and summing columns while eliminating duplicates in PySpark

I have a DataFrame like the one below in PySpark:

df = spark.createDataFrame(
    [
        ("14_100_00", "A", 25, 0),
        ("14_100_00", "A", 0, 24),
        ("15_100_00", "A", 20, 1),
        ("150_100", "C", 21, 0),
        ("16", "A", 0, 20),
        ("16", "A", 20, 0),
    ],
    ["rust", "name", "value_1", "value_2"],
)

df.show()
+---------+----+-------+-------+
|     rust|name|value_1|value_2|
+---------+----+-------+-------+
|14_100_00|   A|     25|      0|
|14_100_00|   A|      0|     24|
|15_100_00|   A|     20|      1|
|  150_100|   C|     21|      0|
|       16|   A|      0|     20|
|       16|   A|     20|      0|
+---------+----+-------+-------+

I am trying to update the value_1 and value_2 columns based on the conditions below:

  1. when the rust and name columns are the same, value_1 should become the sum of value_1 for that group
  2. when the rust and name columns are the same, value_2 should become the sum of value_2 for that group

Expected result:

+---------+----+-------+-------+
|     rust|name|value_1|value_2|
+---------+----+-------+-------+
|14_100_00|   A|     25|     24|
|15_100_00|   A|     20|      1|
|  150_100|   C|     21|      0|
|       16|   A|     20|     20|
+---------+----+-------+-------+

I have tried this:

from pyspark.sql import Window
from pyspark.sql import functions as f

w = Window.partitionBy("rust", "name")
df1 = (
    df.withColumn("value_1", f.sum("value_1").over(w))
      .withColumn("value_2", f.sum("value_2").over(w))
)
df1.show()
+---------+----+-------+-------+
|     rust|name|value_1|value_2|
+---------+----+-------+-------+
|  150_100|   C|     21|      0|
|       16|   A|     20|     20|
|       16|   A|     20|     20|
|14_100_00|   A|     25|     24|
|14_100_00|   A|     25|     24|
|15_100_00|   A|     20|      1|
+---------+----+-------+-------+

Is there a better way to achieve this without having duplicates?

CodePudding user response:

Use groupBy instead of window functions. A window aggregate computes the sum per row and keeps every input row, which is why your result contains duplicates; groupBy collapses each (rust, name) group to a single row:

from pyspark.sql import functions as F

df1 = df.groupBy("rust", "name").agg(
    F.sum("value_1").alias("value_1"),
    F.sum("value_2").alias("value_2"),
)
df1.show()
#+---------+----+-------+-------+
#|     rust|name|value_1|value_2|
#+---------+----+-------+-------+
#|14_100_00|   A|     25|     24|
#|15_100_00|   A|     20|      1|
#|  150_100|   C|     21|      0|
#|       16|   A|     20|     20|
#+---------+----+-------+-------+