Currently I'm gathering the top 5 most frequent values per column with a UDF.
The goal is to achieve the same result without using a UDF,
and with the most efficient solution possible (avoiding groupBy in loops).
Here's the code I'm using to get the result:
from pyspark.sql import functions as F

df = df.select('A', 'B', ...)

@F.udf
def get_top_5_udf(x):
    from collections import Counter
    return [elem[0] for elem in Counter(x).most_common(5)]

agg_expr = [get_top_5_udf(F.collect_list(col)).alias(col) for col in df.columns]
df_top5 = df.agg(*agg_expr)
The result looks like the following:
# result
# +-----------------+--------------+---------------+
# | A               | B            | ...           |
# +-----------------+--------------+---------------+
# | [1, 2, 3, 4, 5] | [...]        | ...           |
# +-----------------+--------------+---------------+
CodePudding user response:
You can try using count over a window partitioned by each column before aggregating:
from pyspark.sql import functions as F, Window

result = df.select(*[
    # Pair each value with its frequency in the column, computed with
    # a count over a window partitioned by that value.
    F.struct(
        F.count(c).over(Window.partitionBy(c)).alias("cnt"),
        F.col(c).alias("val")
    ).alias(c) for c in df.columns
]).agg(*[
    # Deduplicate the (cnt, val) pairs, sort them descending (the struct
    # compares on cnt first), keep only the values and take the first 5.
    F.slice(
        F.expr(f"transform(sort_array(collect_set({c}), false), x -> x.val)"),
        1, 5
    ).alias(c) for c in df.columns
])
result.show()
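
For reference, here is a minimal self-contained sketch of the same approach on hypothetical toy data (the SparkSession setup, the column names A and B and the sample values are invented purely for illustration), which makes the frequency ordering and the tie-breaking behavior visible:

from pyspark.sql import SparkSession, functions as F, Window

spark = SparkSession.builder.getOrCreate()

# Hypothetical toy data: in column A the value 1 appears 3 times,
# 2 appears twice and 3 once; in column B, x and y are tied with 3 each.
df = spark.createDataFrame(
    [(1, "x"), (1, "y"), (2, "x"), (1, "x"), (3, "y"), (2, "y")],
    ["A", "B"],
)

result = df.select(*[
    F.struct(
        F.count(c).over(Window.partitionBy(c)).alias("cnt"),
        F.col(c).alias("val")
    ).alias(c) for c in df.columns
]).agg(*[
    F.slice(
        F.expr(f"transform(sort_array(collect_set({c}), false), x -> x.val)"),
        1, 5
    ).alias(c) for c in df.columns
])

result.show(truncate=False)
# Expected single row: A -> [1, 2, 3] (ordered by count 3, 2, 1) and
# B -> [y, x] (both appear 3 times; the tie is broken by the value itself,
# descending, because sort_array compares the whole (cnt, val) struct).

Note that ties between equally frequent values fall back to comparing the values themselves, since the whole struct is sorted; if you need a different tie-breaking rule, you could add another field to the struct.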