I understand that collect_set can return items in a random order. Is there a different way of ordering a collect_set by count? I want an array of the most popular items for a single column, based on a group by of a separate id column. Would you use a collect_list and then run a count on that?
CodePudding user response:
If I understand you correctly, you want to do a popularity ranking analysis. You need to use collect_list rather than collect_set to preserve duplicate values.
from collections import Counter
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, StringType

spark = SparkSession.builder.enableHiveSupport().getOrCreate()

def elem_cnt(arr):
    # format each element as "name(count)", most frequent first
    return ['{}({})'.format(*i) for i in Counter(arr).most_common()]

# register with an explicit array return type (the default is string)
spark.udf.register('elem_cnt_udf', elem_cnt, ArrayType(StringType()))
data = [
('AC Milan', 'Ronaldo Luiz'),
('AC Milan', 'Paolo Maldini'),
('AC Milan', 'Kaká'),
('AC Milan', 'Ronaldo Luiz'),
('AC Milan', 'Andriy Shevchenko'),
('AC Milan', 'Van Basten'),
('AC Milan', 'Ronaldo Luiz'),
('AC Milan', 'Andriy Shevchenko'),
('AC Milan', 'Van Basten'),
('Milan', 'Ronaldo Luiz'),
('Milan', 'Paolo Maldini'),
('Milan', 'Ronaldo Luiz'),
('Milan', 'Van Basten')
]
schema = """
id string,name string
"""
df = spark.createDataFrame(data, schema)
df.createOrReplaceTempView('tmp')
rank_sql = """
select id,elem_cnt_udf(collect_list(name)) rank from tmp
group by id
"""
rank_df = spark.sql(rank_sql)
rank_df.show(truncate=False)
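If you prefer to stay in the DataFrame API instead of registering the function for Spark SQL, here is a minimal sketch of the same idea, assuming the df and elem_cnt defined above (the variable names are just illustrative):
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

# wrap the same Python function as a DataFrame-API UDF
elem_cnt_col_udf = F.udf(elem_cnt, ArrayType(StringType()))

rank_df = df.groupBy('id') \
    .agg(elem_cnt_col_udf(F.collect_list('name')).alias('rank'))
rank_df.show(truncate=False)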
CodePudding user response:
No, there is no method to order collect_set by count: the collect aggregate functions don't count items, so the information needed to sort them is not available.
So, since Spark 3.1, given a dataframe with two columns id and item, you can:
- perform a count over a groupBy on columns id and item
- collect the (count, item) pairs into an array with collect_list and struct. Note: you could use collect_set here instead of collect_list, but it is useless since each (count, item) pair is already unique
- use sort_array to sort your array by descending count
- map your array with transform to drop the count.
Which can be translated to code as follows:
from pyspark.sql import functions as F
final_df = dataframe.groupBy('id', 'item').count() \
.groupBy('id') \
.agg(
F.transform(
F.sort_array(
F.collect_list(F.struct("count", "item")),
asc=False
),
lambda x: x.getItem('item')
).alias('popular_items')
)
Note: if your Spark version is lower than 3.1 but greater than 1.6, you can replace transform with withColumn as follows:
from pyspark.sql import functions as F
final_df = dataframe.groupBy('id', 'item').count() \
.groupBy('id') \
.agg(F.sort_array(F.collect_list(F.struct("count", "item")), asc=False).alias('popular_items')) \
.withColumn("popular_items", F.col('popular_items.item'))
Example
With the following input dataframe:
+---+-----+
|id |item |
+---+-----+
|1  |item1|
|1  |item2|
|1  |item2|
|1  |item2|
|1  |item3|
|2  |item3|
|2  |item3|
|2  |item1|
|3  |item1|
|3  |item1|
+---+-----+
You get the following output:
+---+---------------------+
|id |popular_items        |
+---+---------------------+
|1  |[item2, item3, item1]|
|3  |[item1]              |
|2  |[item3, item1]       |
+---+---------------------+
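If you also want to keep the count next to each item rather than only the ordered names, a minimal sketch (assuming the same dataframe variable as above) is to simply stop before the transform/withColumn step:
from pyspark.sql import functions as F

# keep the (count, item) structs instead of mapping the count away
with_counts_df = dataframe.groupBy('id', 'item').count() \
    .groupBy('id') \
    .agg(
        F.sort_array(
            F.collect_list(F.struct('count', 'item')),
            asc=False
        ).alias('popular_items')
    )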