Aggregate GroupBy columns with "all"-like function pyspark

Time:05-24

I have a dataframe with a primary key, date, variable, and value. I want to group by the primary key and determine if all values are equal to a provided value. Example data:

import pandas as pd
from datetime import date
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = pd.DataFrame({
    "pk": [1, 1, 1, 1, 2, 2, 2, 2, 3, 4],
    # datetime.date takes (year, month, day) integers, not a string
    "date": [
        date(2022, 5, 6),
        date(2022, 5, 13),
        date(2022, 5, 6),
        date(2022, 5, 6),
        date(2022, 5, 14),
        date(2022, 5, 15),
        date(2022, 5, 5),
        date(2022, 5, 5),
        date(2022, 5, 11),
        date(2022, 5, 12)
    ],
    "variable": ["A", "B", "C", "D", "A", "A", "E", "F", "A", "G"],
    "value": [2, 3, 2, 2, 1, 1, 1, 1, 5, 4]
})

df = spark.createDataFrame(df)

df.show()

# +---+----------+--------+-----+
# | pk|      date|variable|value|
# +---+----------+--------+-----+
# |  1|2022-05-06|       A|    2|
# |  1|2022-05-13|       B|    3|
# |  1|2022-05-06|       C|    2|
# |  1|2022-05-06|       D|    2|
# |  2|2022-05-14|       A|    1|
# |  2|2022-05-15|       A|    1|
# |  2|2022-05-05|       E|    1|
# |  2|2022-05-05|       F|    1|
# |  3|2022-05-11|       A|    5|
# |  4|2022-05-12|       G|    4|
# +---+----------+--------+-----+

So if I want to know whether, for a given primary key pk, all the values are equal to 1 (or satisfy any arbitrary Boolean test), how should I do this? I've tried applyInPandas, but that is not very efficient, and it seems like there is probably a pretty simple way to do this.
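For reference, the expected per-key result on the sample data can be worked out without Spark. The sketch below is plain Python (the helper name all_by_key is hypothetical, not a Spark API): it groups the (pk, value) pairs and applies Python's all() to each group, which is the "all"-like behaviour being asked for. With the test v == 1, only pk 2 should come out true.

```python
from typing import Callable, Dict, Hashable, Iterable, List, Tuple


def all_by_key(
    rows: Iterable[Tuple[Hashable, int]],
    pred: Callable[[int], bool],
) -> Dict[Hashable, bool]:
    """Group (pk, value) pairs by pk, then test whether every value passes pred."""
    groups: Dict[Hashable, List[int]] = {}
    for pk, value in rows:
        groups.setdefault(pk, []).append(value)
    # all() over each key's list mirrors "collect the values, then check them all"
    return {pk: all(pred(v) for v in vals) for pk, vals in groups.items()}


# The pk and value columns from the sample data above
sample_pks = [1, 1, 1, 1, 2, 2, 2, 2, 3, 4]
sample_values = [2, 3, 2, 2, 1, 1, 1, 1, 5, 4]

print(all_by_key(zip(sample_pks, sample_values), lambda v: v == 1))
# {1: False, 2: True, 3: False, 4: False}
```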

Answer:

For Spark 3+, you can use the forall higher-order function to check whether all the values collected by collect_list satisfy the Boolean test.

import pyspark.sql.functions as F

df1 = (df
       .groupby("pk")
       .agg(F.expr("forall(collect_list(value), v -> v == 1)").alias("value"))
       )
df1.show()
# +---+-----+
# | pk|value|
# +---+-----+
# |  1|false|
# |  3|false|
# |  2| true|
# |  4|false|
# +---+-----+

# Or, to keep every row, compute the same test over a window partitioned by pk
df2 = df.withColumn("test", F.expr("forall(collect_list(value) over (partition by pk), v -> v == 1)"))
df2.show()
# +---+----------+--------+-----+-----+
# | pk|      date|variable|value| test|
# +---+----------+--------+-----+-----+
# |  1|2022-05-06|       A|    2|false|
# |  1|2022-05-13|       B|    3|false|
# |  1|2022-05-06|       C|    2|false|
# |  1|2022-05-06|       D|    2|false|
# |  3|2022-05-11|       A|    5|false|
# |  2|2022-05-14|       A|    1| true|
# |  2|2022-05-15|       A|    1| true|
# |  2|2022-05-05|       E|    1| true|
# |  2|2022-05-05|       F|    1| true|
# |  4|2022-05-12|       G|    4|false|
# +---+----------+--------+-----+-----+
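forall(collect_list(value), ...) first builds the full array of values for each pk, which can be costly for very large groups. Spark SQL 3.0 and later also has a bool_and aggregate (alias every) that folds the test row by row, so an expression along the lines of df.groupby("pk").agg(F.expr("bool_and(value = 1)").alias("value")) should give the same true/false per pk on this data; that variant is a suggestion, not part of the original answer. The plain-Python sketch below (hypothetical helper all_by_key_streaming, no Spark involved) illustrates the running-AND idea, where each key keeps a single Boolean rather than a list, and then copies the per-key answer onto every row the way the window version does.

```python
from typing import Callable, Dict, Hashable, Iterable, Tuple


def all_by_key_streaming(
    rows: Iterable[Tuple[Hashable, int]],
    pred: Callable[[int], bool],
) -> Dict[Hashable, bool]:
    """Running AND per key: one Boolean of state per key, no per-key lists."""
    state: Dict[Hashable, bool] = {}
    for pk, value in rows:
        # Each key starts at True; once it is False, `and` short-circuits
        # and pred is not evaluated again for that key.
        state[pk] = state.get(pk, True) and pred(value)
    return state


sample_pks = [1, 1, 1, 1, 2, 2, 2, 2, 3, 4]
sample_values = [2, 3, 2, 2, 1, 1, 1, 1, 5, 4]

per_key = all_by_key_streaming(zip(sample_pks, sample_values), lambda v: v == 1)
print(per_key)  # {1: False, 2: True, 3: False, 4: False}

# Copy the per-key answer back onto each row, like the window-based "test" column
test_column = [per_key[pk] for pk in sample_pks]
print(test_column)
# [False, False, False, False, True, True, True, True, False, False]
```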

You might want to wrap the forall expression in a CASE clause to handle NULL values.
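With collect_list, NULL values never reach the lambda because they are not collected, and a pk whose values are all NULL collects an empty array, for which forall returns true. If a NULL should instead count as a failure, one possible CASE clause (an untested sketch, not from the original answer) is F.expr("CASE WHEN count(value) = count(*) THEN forall(collect_list(value), v -> v == 1) ELSE false END") inside .agg(...). The plain-Python model below contrasts the two behaviours; None stands in for SQL NULL, and the helper names are hypothetical.

```python
from typing import Callable, Dict, Hashable, Iterable, List, Optional, Tuple

Pair = Tuple[Hashable, Optional[int]]  # (pk, value); None plays the role of SQL NULL


def forall_collect_list(rows: Iterable[Pair], pred: Callable[[int], bool]) -> Dict[Hashable, bool]:
    """Model of forall(collect_list(value), pred) per key: NULLs are dropped first."""
    groups: Dict[Hashable, List[int]] = {}
    for pk, value in rows:
        kept = groups.setdefault(pk, [])  # the key exists even if every value is NULL
        if value is not None:
            kept.append(value)
    # all() of an empty list is True, just like forall on an empty array
    return {pk: all(pred(v) for v in vals) for pk, vals in groups.items()}


def forall_strict(rows: Iterable[Pair], pred: Callable[[int], bool]) -> Dict[Hashable, bool]:
    """Model of the CASE variant: any NULL in the group makes the result False."""
    rows = list(rows)
    has_null: Dict[Hashable, bool] = {}
    for pk, value in rows:
        has_null[pk] = has_null.get(pk, False) or value is None
    lenient = forall_collect_list(rows, pred)
    return {pk: lenient[pk] and not has_null[pk] for pk in lenient}


def is_one(v: int) -> bool:
    return v == 1


rows = [(1, 1), (1, None), (2, None), (2, None), (3, 1), (3, 2), (4, 1)]

print(forall_collect_list(rows, is_one))  # {1: True, 2: True, 3: False, 4: True}
print(forall_strict(rows, is_one))        # {1: False, 2: False, 3: False, 4: True}
```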
