I have a DataFrame with a list of states and salaries. I need to group by state, count how many entries fall into each salary range (there are 3 salary ranges in total), produce a DataFrame with those counts, and sort the result by state name. Is there a function in Spark that will achieve this?
Sample input:

State  Salary
-----  ------
NY     6
WI     15
NY     11
WI     2
MI     20
NY     15
Expected result:

State  group1  group2  group3
MI     0       0       1
NY     0       1       2
WI     1       0       1
Where:
- Group1 is the count of rows with salary > 0 and <= 5
- Group2 is the count of rows with salary > 5 and <= 10
- Group3 is the count of rows with salary > 10 and <= 20
Basically I am looking for something like the following in Scala Spark (pseudocode, this does not actually work):

df.groupBy('STATE').agg(count('*') as group1).where('SALARY' > 0 and 'SALARY' <= 5)
  .agg(count('*') as group2).where('SALARY' > 5 and 'SALARY' <= 10)
  .agg(count('*') as group3).where('SALARY' > 10 and 'SALARY' <= 20)
CodePudding user response:
You can specify the condition on which you want to count/sum inside the aggregate.
Example:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

data = [
    {"State": "NY", "Salary": 6},
    {"State": "WI", "Salary": 15},
    {"State": "NY", "Salary": 11},
    {"State": "WI", "Salary": 2},
    {"State": "MI", "Salary": 20},
    {"State": "NY", "Salary": 15},
]
df = spark.createDataFrame(data=data)

# Conditional count: emit 1 when the condition holds, 0 otherwise, then sum.
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))

df = df.groupBy("State").agg(
    cnt_cond((F.col("Salary") > 0) & (F.col("Salary") <= 5)).alias("group_1"),
    cnt_cond((F.col("Salary") > 5) & (F.col("Salary") <= 10)).alias("group_2"),
    cnt_cond((F.col("Salary") > 10) & (F.col("Salary") <= 20)).alias("group_3"),
)
Here sum is the same as count, since the expression checks the condition and returns 1 if it is met, otherwise 0.
Result:

+-----+-------+-------+-------+
|State|group_1|group_2|group_3|
+-----+-------+-------+-------+
|NY   |0      |1      |2      |
|WI   |1      |0      |1      |
|MI   |0      |0      |1      |
+-----+-------+-------+-------+
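If you prefer count over sum, the same counts can be produced by counting a when expression that has no otherwise branch, because count skips nulls and therefore only counts the rows matching the condition. A minimal alternative sketch using the same sample data as above (cnt_cond_alt is just an illustrative name):

# Start again from the original, un-aggregated data, since df was reassigned above.
src = spark.createDataFrame(data=data)

# when(cond, 1) is null when cond is false, and count ignores nulls.
cnt_cond_alt = lambda cond: F.count(F.when(cond, 1))

src.groupBy("State").agg(
    cnt_cond_alt((F.col("Salary") > 0) & (F.col("Salary") <= 5)).alias("group_1"),
    cnt_cond_alt((F.col("Salary") > 5) & (F.col("Salary") <= 10)).alias("group_2"),
    cnt_cond_alt((F.col("Salary") > 10) & (F.col("Salary") <= 20)).alias("group_3"),
).show(truncate=False)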
CodePudding user response:
You can use an expression composed of the sum and case functions.
from pyspark.sql import functions as F

data = [
    ('NY', 6),
    ('WI', 15),
    ('NY', 11),
    ('WI', 2),
    ('MI', 20),
    ('NY', 15)
]
df = spark.createDataFrame(data, ['State', 'salary'])

# Each sum(case when ...) counts the rows that fall into the corresponding salary range.
df = df.groupBy('State').agg(
    F.expr('sum(case when salary > 0 and salary <= 5 then 1 else 0 end)').alias('group1'),
    F.expr('sum(case when salary > 5 and salary <= 10 then 1 else 0 end)').alias('group2'),
    F.expr('sum(case when salary > 10 and salary <= 20 then 1 else 0 end)').alias('group3'))
df.show(truncate=False)
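The question also asks for the result sorted by state name; both approaches return the grouped rows in no particular order, so add an orderBy on the grouped DataFrame before showing it, for example:

df.orderBy('State').show(truncate=False)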