I want to group on multiple columns and then aggregate several other columns with user-defined functions (UDFs) that calculate the mode of each column. Here is sample code that demonstrates the problem:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, IntegerType

spark = SparkSession.builder.getOrCreate()  # already defined in the pyspark shell and notebooks

# Sample data: group on A and B, take the mode of C (strings) and D (integers)
df = pd.DataFrame({
    "A": ["Mon", "Mon", "Mon", "Fri", "Fri", "Fri", "Fri"],
    "B": ["Feb", "Feb", "Feb", "May", "May", "May", "May"],
    "C": ["x", "y", "y", "m", "n", "r", "r"],
    "D": [3, 3, 5, 1, 1, 1, 9],
})
df_sdf = spark.createDataFrame(df)
df_sdf.show()
+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|Mon|Feb|  x|  3|
|Mon|Feb|  y|  3|
|Mon|Feb|  y|  5|
|Fri|May|  m|  1|
|Fri|May|  n|  1|
|Fri|May|  r|  1|
|Fri|May|  r|  9|
+---+---+---+---+
# Custom mode function: returns the most frequent value of a list (strings or integers)
def custom_mode(lst):
    return max(lst, key=lst.count)

custom_mode_str = udf(custom_mode, StringType())
custom_mode_int = udf(custom_mode, IntegerType())
grp_columns = ["A", "B"]
df_sdf.groupBy(grp_columns).agg(custom_mode_str(col("C")).alias("C"), custom_mode_int(col("D")).alias("D")).distinct().show()
However, the last line of the code above raises the following error:
AnalysisException: expression '`C`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;;
The expected output for this code is:
+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|Mon|Feb|  y|  3|
|Fri|May|  r|  1|
+---+---+---+---+
I searched extensively but couldn't find a similar problem for PySpark. Thanks for your time.
Answer:
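Your UDF expects a list, but you're passing it a DataFrame column. A regular UDF runs once per row on that row's single value, so it is not an aggregate function; that is why Spark reports that C is neither in the group by nor an aggregate. Pass the UDF a list instead, built per group with collect_list(), and it produces the desired result.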
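The failure is easier to see in plain Python. The sketch below (no Spark needed) mimics what custom_mode receives in each case, using the values of the ("Mon", "Feb") group from the sample data; rows_C and rows_D are illustrative names. Called row by row, as a regular UDF is, the function sees one value at a time and returns one result per row; called on the collected list, it returns one result per group.

```python
def custom_mode(lst):
    # Same mode function as in the question.
    return max(lst, key=lst.count)


# Values of C and D for the ("Mon", "Feb") group, taken from the sample data.
rows_C = ["x", "y", "y"]
rows_D = [3, 3, 5]

# Row-wise (what a regular UDF does): one output per row, no aggregation.
# A one-character string happens to be iterable, so this silently "works".
per_row_C = [custom_mode(value) for value in rows_C]
print(per_row_C)  # ['x', 'y', 'y']

# An int has no .count method, so the row-wise call fails outright.
try:
    custom_mode(rows_D[0])
except AttributeError as exc:
    print("row-wise on int:", exc)

# Group-wise (what collect_list makes possible): one output per group.
print(custom_mode(rows_C), custom_mode(rows_D))  # y 3
```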
import pyspark.sql.functions as func

df_sdf.groupBy(['A', 'B']). \
    agg(custom_mode_str(func.collect_list('C')).alias('C'),
        custom_mode_int(func.collect_list('D')).alias('D')
        ). \
    show()
# +---+---+---+---+
# |  A|  B|  C|  D|
# +---+---+---+---+
# |Mon|Feb|  y|  3|
# |Fri|May|  r|  1|
# +---+---+---+---+
collect_list() is the key here: it gathers each group's values into an array, which the Python UDF receives as a list. Here is what the collected columns look like:
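In plain Python terms, groupBy(...).agg(collect_list(...)) behaves roughly like the sketch below: rows are bucketed by the grouping key, and each column's values are appended to a per-group list that is then handed to the mode function. collect_by_group is a hypothetical helper for illustration only; unlike Spark, it keeps rows in input order, whereas Spark makes no ordering guarantee after a shuffle.

```python
from collections import defaultdict

# Sample rows (A, B, C, D) from the question.
rows = [
    ("Mon", "Feb", "x", 3), ("Mon", "Feb", "y", 3), ("Mon", "Feb", "y", 5),
    ("Fri", "May", "m", 1), ("Fri", "May", "n", 1), ("Fri", "May", "r", 1),
    ("Fri", "May", "r", 9),
]


def collect_by_group(rows):
    """Group rows by (A, B) and collect C and D into lists (hypothetical helper)."""
    groups = defaultdict(lambda: {"C": [], "D": []})
    for a, b, c, d in rows:
        groups[(a, b)]["C"].append(c)
        groups[(a, b)]["D"].append(d)
    return dict(groups)


def custom_mode(lst):
    # Same mode function as in the question.
    return max(lst, key=lst.count)


collected = collect_by_group(rows)
for key, cols in collected.items():
    print(key, cols["C"], cols["D"], custom_mode(cols["C"]), custom_mode(cols["D"]))
# ('Mon', 'Feb') ['x', 'y', 'y'] [3, 3, 5] y 3
# ('Fri', 'May') ['m', 'n', 'r', 'r'] [1, 1, 1, 9] r 1
```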
df_sdf.groupBy(['A', 'B']). \
    agg(func.collect_list('C').alias('C_collected'),
        func.collect_list('D').alias('D_collected')
        ). \
    show()
# +---+---+------------+------------+
# |  A|  B| C_collected| D_collected|
# +---+---+------------+------------+
# |Mon|Feb|   [x, y, y]|   [3, 3, 5]|
# |Fri|May|[m, n, r, r]|[1, 1, 1, 9]|
# +---+---+------------+------------+
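Two caveats about the mode function itself. First, max(lst, key=lst.count) calls lst.count once per element, so its cost grows quadratically with the group size; collections.Counter counts in a single pass. Second, the Spark documentation describes collect_list as non-deterministic, because the order of the collected values depends on row order, which can change after a shuffle. Since max returns the first of several tied values, a group such as [x, y] could give either answer from run to run. Below is a minimal sketch addressing both, assuming ties should go to the smallest value. stable_mode is a hypothetical name, not part of PySpark, and it returns None for an empty list instead of raising ValueError as max does.

```python
from collections import Counter


def stable_mode(values):
    """Most frequent value; ties go to the smallest value (hypothetical helper).

    The result no longer depends on the order in which collect_list
    happened to gather the values.
    """
    if not values:
        return None  # empty group: no mode
    counts = Counter(values)  # single pass over the list
    top = max(counts.values())
    return min(v for v, n in counts.items() if n == top)


print(stable_mode(["y", "x"]), stable_mode(["x", "y"]))  # x x
print(stable_mode([9, 1, 1, 1]))                         # 1
```

It could be wrapped the same way as the original, e.g. udf(stable_mode, StringType()). Alternatively, keeping custom_mode unchanged, the list can be sorted on the Spark side with func.sort_array(func.collect_list('C')), so that the first of any tied values is always the smallest.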