How to use UDFs with pandas on pyspark groupby?


I am struggling to use pandas UDFs with a pandas-on-Spark (pyspark.pandas) groupby. Can you please help me understand how this can be achieved? Below is my attempt:

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark import pandas as ps
spark = SparkSession.builder.getOrCreate()
df = ps.DataFrame({'A': 'a a b'.split(),
                   'B': [1, 2, 3],
                   'C': [4, 6, 5]}, columns=['A', 'B', 'C'])
@pandas_udf('float')
def agg_a(x):
    return (x**2).mean()
@pandas_udf('float')
def agg_b(x):
    return x.mean()
spark.udf.register('agg_a_',agg_a)
spark.udf.register('agg_b_',agg_b)
df_means = df.groupby('A')
dfout=df_means.agg({'B':'agg_a_','C':'agg_b_'})

This results in an exception I am struggling to understand:

AnalysisException: expression 'B' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
Aggregate [__index_level_0__#14], [__index_level_0__#14, agg_a_(B#2L) AS B#15, agg_b_(C#3L) AS C#16]
+- Project [A#1 AS __index_level_0__#14, A#1, B#2L, C#3L]
   +- Project [__index_level_0__#0L, A#1, B#2L, C#3L, monotonically_increasing_id() AS __natural_order__#8L]
      +- LogicalRDD [__index_level_0__#0L, A#1, B#2L, C#3L], false

I tried using udf instead of pandas_udf, but that fails with the same exception.

I also tried using groupby with a UDF on only one column, but that fails too:

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark import pandas as ps
spark = SparkSession.builder.getOrCreate()
df = ps.DataFrame({'A': 'a a b'.split(),
                   'B': [1, 2, 3],
                   'C': [4, 6, 5]}, columns=['A', 'B', 'C'])
@udf('float')
def agg_a(x):
    return (x**2).mean()
@udf('float')
def agg_b(x):
    return x.mean()
spark.udf.register('agg_a_',agg_a)
spark.udf.register('agg_b_',agg_b)
df_means = df.groupby('A')['B']
dfout=df_means.agg('agg_a_')

output:

PandasNotImplementedError: The method `pd.groupby.GroupBy.agg()` is not implemented yet.

This does not seem to be entirely true: groupby works fine for me when I use built-in aggregations such as 'min' or 'max' instead of UDFs.
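For example, this kind of aggregation with built-in functions runs without errors on the same df (just to illustrate the point, reusing the df defined above; the choice of 'max' and 'min' is arbitrary):

df.groupby('A').agg({'B': 'max', 'C': 'min'})
# per-group maximum of B and minimum of C, no exception raised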

I also tried calling agg with a single UDF, without specifying different UDFs per column, and that failed as well:

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark import pandas as ps
spark = SparkSession.builder.getOrCreate()
df = ps.DataFrame({'A': 'a a b'.split(),
                   'B': [1, 2, 3],
                   'C': [4, 6, 5]}, columns=['A', 'B', 'C'])
@udf('float')
def agg_a(x):
    return (x**2).mean()
@udf('float')
def agg_b(x):
    return x.mean()
spark.udf.register('agg_a_',agg_a)
spark.udf.register('agg_b_',agg_b)
df_means = df.groupby('A')
dfout=df_means.agg('agg_a_')

output:

AnalysisException: expression 'B' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
Aggregate [__index_level_0__#14], [__index_level_0__#14, agg_a_(B#2L) AS B#15, agg_a_(C#3L) AS C#16]
+- Project [A#1 AS __index_level_0__#14, A#1, B#2L, C#3L]
   +- Project [__index_level_0__#0L, A#1, B#2L, C#3L, monotonically_increasing_id() AS __natural_order__#8L]
      +- LogicalRDD [__index_level_0__#0L, A#1, B#2L, C#3L], false

CodePudding user response:

According to the GroupedData.agg documentation, you need to define your pandas_udf with a PandasUDFType. Since you want an aggregation, that is PandasUDFType.GROUPED_AGG.

from pyspark.sql.functions import pandas_udf, PandasUDFType

# Declare the UDFs as grouped aggregate pandas UDFs so Spark
# accepts them as aggregate expressions inside groupby().agg()
@pandas_udf('float', PandasUDFType.GROUPED_AGG)
def agg_a(x):
    return (x**2).mean()

@pandas_udf('float', PandasUDFType.GROUPED_AGG)
def agg_b(x):
    return x.mean()

# Register them so they can be referenced by name in the agg dict
spark.udf.register('agg_a_', agg_a)
spark.udf.register('agg_b_', agg_b)

df.groupby('A').agg({'B': 'agg_a_', 'C': 'agg_b_'}).show()

# +---+---------+---------+
# |  A|agg_a_(B)|agg_b_(C)|
# +---+---------+---------+
# |  b|      9.0|      5.0|
# |  a|      2.5|      5.0|
# +---+---------+---------+
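As a side note, on Spark 3.0+ the same grouped-aggregate UDFs can also be declared with Python type hints instead of the PandasUDFType constant (the type-hint style is the one recommended in the Spark 3 documentation). A minimal sketch of that variant, assuming the same spark session and df as above:

import pandas as pd
from pyspark.sql.functions import pandas_udf

# A Series-to-scalar pandas UDF is inferred as a grouped aggregate UDF
@pandas_udf('float')
def agg_a(x: pd.Series) -> float:
    return float((x ** 2).mean())

@pandas_udf('float')
def agg_b(x: pd.Series) -> float:
    return float(x.mean())

spark.udf.register('agg_a_', agg_a)
spark.udf.register('agg_b_', agg_b)

df.groupby('A').agg({'B': 'agg_a_', 'C': 'agg_b_'})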