PySpark equivalent to a groupby categories in pandas?


In pandas, we can group by a categorical series and then, when aggregating, all the categories are displayed regardless of whether they contain any records.

import numpy as np
import pandas as pd

df = pd.DataFrame({"Age": [12, 20, 40, 60, 72]}, dtype=np.float64)
cuts = pd.cut(df.Age, bins=[0, 11, 30, 60])
df.Age.groupby(cuts).agg(mean="mean", occurrences="size")

#           mean  occurrences
# Age                        
# (0, 11]    NaN            0
# (11, 30]  16.0            2
# (30, 60]  50.0            2

As you can see, the first bin is displayed even though no record falls into it. How could I achieve the same behaviour in PySpark?

CodePudding user response:

The following is quite verbose, but I'm not aware of any nicer method.

from pyspark.sql import functions as F
df = spark.createDataFrame([(12,), (20,), (40,), (60,), (72,)], ['Age'])

bins = [0, 11, 30, 60]

# Build a chained when() expression that maps an integer age to its bin label.
conds = F
for i, b in enumerate(bins[1:]):
    conds = conds.when(F.col('id') <= b, f'({bins[i]}, {b}]')

# Helper table with one row per possible integer age, each labelled with its bin.
df2 = spark.range(1, bins[-1] + 1).withColumn('_grp', conds)

# Left join so that bins without any matching age are kept.
df = df2.join(df, df2.id == df.Age, 'left')
df = df.groupBy(F.col('_grp').alias('Age')).agg(
    F.mean('Age').alias('mean'),
    F.count('Age').alias('occurrences'),
)

df.show()
# +--------+----+-----------+
# |     Age|mean|occurrences|
# +--------+----+-----------+
# |(11, 30]|16.0|          2|
# | (0, 11]|null|          0|
# |(30, 60]|50.0|          2|
# +--------+----+-----------+
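
A different sketch of the same idea, not from the original answer: instead of materialising every integer age with spark.range (which only works for integer values), one could build a small DataFrame of bin labels and boundaries and left-join it to the data on a range condition. The names bins_df, bin, lo and hi are my own.

from pyspark.sql import functions as F

df = spark.createDataFrame([(12,), (20,), (40,), (60,), (72,)], ['Age'])
bins = [0, 11, 30, 60]

# One row per bin: label plus lower/upper boundary.
bin_rows = [(f'({lo}, {hi}]', lo, hi) for lo, hi in zip(bins, bins[1:])]
bins_df = spark.createDataFrame(bin_rows, ['bin', 'lo', 'hi'])

# Left join from the bins side so that empty bins survive;
# ages outside every bin (e.g. 72) drop out, as with pd.cut.
joined = bins_df.join(df, (df.Age > bins_df.lo) & (df.Age <= bins_df.hi), 'left')

result = joined.groupBy('bin').agg(
    F.mean('Age').alias('mean'),
    F.count('Age').alias('occurrences'),
)
result.show()

Either way, the key point is the left join from the complete set of bins, which is what preserves empty categories the way a categorical groupby does in pandas.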