How can I count different groups and group them into one column in PySpark?

Time:08-09

In this example, I have the following dataframe:

client_id   rule_1   rule_2   rule_3   rule_4   rule_5
    1         1        0         1       0        0
    2         0        1         0       0        0
    3         0        1         1       1        0
    4         1        0         1       1        1

It shows each client_id and whether that client obeys a given rule (1) or not (0).

How can I count, for each rule, how many clients obey it and how many do not, so that all of this information ends up in a single dataframe like this?

rule    obeys    count
rule_1    0      23852
rule_1    1      95102
rule_2    0      12942
rule_2    1      45884
rule_3    0      29319
rule_3    1       9238
rule_4    0      55321
rule_4    1      23013
rule_5    0      96842
rule_5    1      86739

Answer:

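Turning column names into row values is called unpivoting. In Spark, you can do it with the `stack` SQL function: it emits one (rule, obeys) row per rule column for every input row, and you then group by both columns and count.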
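The logic itself does not depend on Spark. The following is a minimal pure-Python sketch of the same two steps on the sample rows from the question, with no Spark involved. It unpivots each row into (rule, obeys) pairs and then counts the pairs, much as `groupBy("rule", "obeys").count()` does:

```python
from collections import Counter

# Sample data from the question: client_id followed by rule_1..rule_5.
columns = ["client_id", "rule_1", "rule_2", "rule_3", "rule_4", "rule_5"]
rows = [
    (1, 1, 0, 1, 0, 0),
    (2, 0, 1, 0, 0, 0),
    (3, 0, 1, 1, 1, 0),
    (4, 1, 0, 1, 1, 1),
]

# Unpivot: every client contributes one (rule, obeys) pair per rule column.
long_rows = [
    (rule, obeys)
    for row in rows
    for rule, obeys in zip(columns[1:], row[1:])
]

# Group by (rule, obeys) and count each combination.
counts = Counter(long_rows)

for (rule, obeys), n in sorted(counts.items()):
    print(rule, obeys, n)
```

With 4 clients and 5 rules, the unpivoted data has 20 rows, which collapse into the 10 (rule, obeys) groups shown in the Spark output.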

Input:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, 1, 0, 1, 0, 0),
     (2, 0, 1, 0, 0, 0),
     (3, 0, 1, 1, 1, 0),
     (4, 1, 0, 1, 1, 1)],
    ["client_id", "rule_1", "rule_2", "rule_3", "rule_4", "rule_5"])

Script:

to_unpivot = [f"'{c}', `{c}`" for c in df.columns if c != "client_id"]
stack_str = ",".join(to_unpivot)
df = (df
    .select(F.expr(f"stack({len(to_unpivot)}, {stack_str}) as (rule, obeys)"))
    .groupBy("rule", "obeys")
    .count()
)
df.show()
# +------+-----+-----+
# |  rule|obeys|count|
# +------+-----+-----+
# |rule_1|    1|    2|
# |rule_2|    1|    2|
# |rule_1|    0|    2|
# |rule_3|    1|    3|
# |rule_2|    0|    2|
# |rule_4|    0|    2|
# |rule_3|    0|    1|
# |rule_5|    0|    3|
# |rule_5|    1|    1|
# |rule_4|    1|    2|
# +------+-----+-----+
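To see the expression that the script passes to `F.expr`, you can build the string in plain Python without a Spark session. The column list below is copied from the example dataframe. Each rule contributes a quoted literal (the value placed in `rule`) followed by a backticked column reference (the value placed in `obeys`):

```python
# Column names of the example dataframe.
columns = ["client_id", "rule_1", "rule_2", "rule_3", "rule_4", "rule_5"]

# Same construction as in the script above.
to_unpivot = [f"'{c}', `{c}`" for c in columns if c != "client_id"]
stack_str = ",".join(to_unpivot)
expr = f"stack({len(to_unpivot)}, {stack_str}) as (rule, obeys)"

print(expr)
```

The first argument of `stack` is the number of output rows per input row (5 here). The remaining arguments are consumed two at a time, which is why the alias names two columns.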

Answer:

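We can unpivot the rule columns by packing them into an array of (key, val) structs, expanding that array into rows with `inline`, and then counting per (key, val). Here's an example using the sample data from your question.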

from pyspark.sql import functions as func

# data_sdf is the question's dataframe: client_id plus the rule columns
rule_cols = [k for k in data_sdf.columns if 'rule' in k]
# ['rule_1', 'rule_2', 'rule_3', 'rule_4', 'rule_5']

data_sdf. \
    withColumn('arr_rule_structs',
               # one (key, val) struct per rule column
               func.array(*[func.struct(func.lit(k).alias('key'), func.col(k).alias('val')) for k in rule_cols])
               ). \
    selectExpr('client_id', 'inline(arr_rule_structs)'). \
    groupBy('key', 'val'). \
    agg(func.count('client_id').alias('cnt')). \
    orderBy('key', 'val'). \
    show()

# +------+---+---+
# |   key|val|cnt|
# +------+---+---+
# |rule_1|  0|  2|
# |rule_1|  1|  2|
# |rule_2|  0|  2|
# |rule_2|  1|  2|
# |rule_3|  0|  1|
# |rule_3|  1|  3|
# |rule_4|  0|  2|
# |rule_4|  1|  2|
# |rule_5|  0|  3|
# |rule_5|  1|  1|
# +------+---+---+

Feel free to use your own field names within the struct instead of key and val.
