Is there a Spark function to achieve the groupby then filter and then aggregate

Time:12-08

I have a DataFrame with a list of states and the salaries in each state. I need to group by state, count how many entries fall into each salary range (there are 3 salary ranges in total), build a DataFrame from the result, and sort it by state name. Is there a function in Spark that does this?

Sample input 

State  salary
------ ------
NY      6
WI      15
NY      11
WI      2
MI      20
NY      15 
 
Expected result:

State  group1  group2  group3
-----  ------  ------  ------
MI     0       0       1
NY     0       1       2
WI     1       0       1

Where

  • Group1 is count of salary > 0 and <= 5
  • Group2 is count of salary > 5 and <=10
  • Group3 is count of salary >10 and <=20

Basically, I am looking for something like:

df.groupBy('STATE').agg(count('*') as group1).where('SALARY' >0 and 'SALARY' <=5)
.agg(count('*') as group2).where('SALARY' >5 and 'SALARY' <=10)
.agg(count('*') as group3).where('SALARY' >10 and 'SALARY' <=20)

CodePudding user response:

You can put the condition directly inside the aggregate, so that each count/sum only covers the rows that match it.

Example:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
data = [
    {"State": "NY", "Salary": 6},
    {"State": "WI", "Salary": 15},
    {"State": "NY", "Salary": 11},
    {"State": "WI", "Salary": 2},
    {"State": "MI", "Salary": 20},
    {"State": "NY", "Salary": 15},
]
df = spark.createDataFrame(data=data)
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
df = df.groupBy("State").agg(
    cnt_cond((F.col("Salary") > 0) & (F.col("Salary") <= 5)).alias("group_1"),
    cnt_cond((F.col("Salary") > 5) & (F.col("Salary") <= 10)).alias("group_2"),
    cnt_cond((F.col("Salary") > 10) & (F.col("Salary") <= 20)).alias("group_3"),
)
df.show(truncate=False)

Here the sum acts as a count: when/otherwise returns 1 for each row that meets the condition and 0 for every other row, so adding those values up gives the number of matching rows.
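To see why summing a 0/1 flag gives the same number as counting matching rows, here is a minimal plain-Python sketch of the same logic, with no Spark involved. The helper name count_by_range and the RANGES table are made up for illustration; the rows are the sample data from the question.

```python
from collections import defaultdict

# Sample rows from the question: (state, salary)
rows = [("NY", 6), ("WI", 15), ("NY", 11), ("WI", 2), ("MI", 20), ("NY", 15)]

# Hypothetical bucket table: name -> (exclusive lower bound, inclusive upper bound)
RANGES = {"group_1": (0, 5), "group_2": (5, 10), "group_3": (10, 20)}


def count_by_range(rows):
    """Per state, add 1 for each row inside a range and 0 otherwise."""
    totals = defaultdict(lambda: {name: 0 for name in RANGES})
    for state, salary in rows:
        for name, (low, high) in RANGES.items():
            # int(True) == 1 and int(False) == 0, so this sum is a conditional count
            totals[state][name] += int(low < salary <= high)
    # Return the states in alphabetical order, as the question asks
    return dict(sorted(totals.items()))


result = count_by_range(rows)
for state, counts in result.items():
    print(state, counts["group_1"], counts["group_2"], counts["group_3"])
```

Note that the sketch sorts by state, which the question asks for but the Spark snippet above does not do. In Spark, chaining .orderBy("State") after the aggregation should give the same ordering.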

Result:

+-----+-------+-------+-------+
|State|group_1|group_2|group_3|
+-----+-------+-------+-------+
|NY   |0      |1      |2      |
|WI   |1      |0      |1      |
|MI   |0      |0      |1      |
+-----+-------+-------+-------+

CodePudding user response:

You can use a SQL expression that combines sum with case when:

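from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Session and functions module used by the snippet below
spark = SparkSession.builder.getOrCreate()
data = [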
    ('NY', 6),
    ('WI', 15),
    ('NY', 11),
    ('WI', 2),
    ('MI', 20),
    ('NY', 15)
]
df = spark.createDataFrame(data, ['State', 'salary'])
df = df.groupBy('State').agg(F.expr('sum(case when salary>0 and salary<=5 then 1 else 0 end)').alias('group1'),
                             F.expr('sum(case when salary>5 and salary<=10 then 1 else 0 end)').alias('group2'),
                             F.expr('sum(case when salary>10 and salary<=20 then 1 else 0 end)').alias('group3'))
df.show(truncate=False)
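The sum(case when ...) pattern is ordinary SQL, so it can be tried out without a Spark cluster. The sketch below runs the same expressions against Python's built-in sqlite3 module, used here purely as a stand-in for Spark SQL; the table name salaries is an assumption for illustration, and the ORDER BY clause is added to get the sorting by state that the question asks for.

```python
import sqlite3

# In-memory SQLite database, used here only as a stand-in for Spark SQL
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE salaries (state TEXT, salary INTEGER)")
conn.executemany(
    "INSERT INTO salaries VALUES (?, ?)",
    [("NY", 6), ("WI", 15), ("NY", 11), ("WI", 2), ("MI", 20), ("NY", 15)],
)

# Each SUM adds 1 for rows inside its salary range and 0 for all others,
# which makes it a conditional count per state
query = """
SELECT state,
       SUM(CASE WHEN salary > 0  AND salary <= 5  THEN 1 ELSE 0 END) AS group1,
       SUM(CASE WHEN salary > 5  AND salary <= 10 THEN 1 ELSE 0 END) AS group2,
       SUM(CASE WHEN salary > 10 AND salary <= 20 THEN 1 ELSE 0 END) AS group3
FROM salaries
GROUP BY state
ORDER BY state
"""
result_rows = conn.execute(query).fetchall()
conn.close()

for row in result_rows:
    print(row)
```

With the sample data this yields one row per state in alphabetical order, matching the expected result in the question.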