I have a dataframe with department, item id, and the count of those ids. There are 127 departments, and based on the item count I want to list the top 10 items for each department separately. I have been trying to do this using groupBy and agg.max but was not able to; a sketch of that kind of attempt follows the example table below.
Department | Item id | count |
---|---|---|
A | 101 | 10 |
B | 102 | 5 |
A | 104 | 12 |
C | 101 | 5 |
D | 104 | 14 |
C | 108 | 10 |
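Roughly the kind of attempt I mean (a sketch, not my exact code): an aggregation collapses each department to a single row, so it cannot list the top 10 items.

import pyspark.sql.functions as F

# Returns one row per department with only the largest count;
# the item ids of the other top items are lost.
df.groupBy('Department').agg(F.max('count').alias('max_count')).show()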
Answer:

The solution is based on the `row_number()` window function.

- In this demo I returned the top 3 records per department; feel free to change it to 10.
- `qualify` is new to Spark SQL. If your Spark version doesn't support it, a wrapping query is needed and the filtering is done with a WHERE clause on the outer query (a sketch of that variant follows the Spark SQL solution below).
- I added `Item id` to the ORDER BY in order to break `count` ties in a deterministic way.
Data Sample Creation
df = spark.sql('''
    select char(ascii('A') + d.i) as Department,
           100 + i.i as `Item id`,
           int(rand()*100) as count
    from range(3) as d(i), range(7) as i(i)
    order by 1, 3 desc
''')
df.show(999)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| A| 100| 52|
| A| 105| 50|
| A| 102| 40|
| A| 101| 30|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| B| 105| 66|
| B| 102| 48|
| B| 100| 32|
| B| 103| 14|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
| C| 104| 82|
| C| 100| 9|
| C| 101| 6|
| C| 106| 3|
+----------+-------+-----+
Spark SQL Solution
df.createOrReplaceTempView('t')
sql_query = '''
select *
from t
qualify row_number() over (partition by Department order by count desc, `Item id`) <= 3
'''
spark.sql(sql_query).show(999)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
+----------+-------+-----+
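If your Spark version doesn't support qualify, here is the wrapped-query variant mentioned above: a sketch against the same temp view t, with the window filter moved to a WHERE clause on the outer query (sql_query_no_qualify is just an illustrative name).

sql_query_no_qualify = '''
select Department, `Item id`, count
from (
    select *,
           row_number() over (partition by Department order by count desc, `Item id`) as rn
    from t
) as ranked
where rn <= 3
'''
spark.sql(sql_query_no_qualify).show(999)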
PySpark Solution
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# Number rows within each department by descending count
# (Item id breaks ties), then keep the top 3 per department.
w = Window.partitionBy('Department').orderBy(df['count'].desc(), df['Item id'])
(df.withColumn('rn', F.row_number().over(w))
 .where('rn <= 3')
 .drop('rn')
 .show(999)
)
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
| A| 103| 89|
| A| 106| 68|
| A| 104| 54|
| B| 104| 94|
| B| 101| 87|
| B| 106| 74|
| C| 105| 95|
| C| 103| 94|
| C| 102| 90|
+----------+-------+-----+
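Side note: row_number() always breaks ties, so each department returns exactly 3 rows. If you'd rather keep every item that ties at the cutoff, rank() is a near drop-in replacement; a minimal sketch (a department may then return more than 3 rows):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# rank() gives equal counts the same rank, so ties at the
# boundary are all kept instead of being arbitrarily trimmed.
w = Window.partitionBy('Department').orderBy(F.col('count').desc())
(df.withColumn('rnk', F.rank().over(w))
 .where('rnk <= 3')
 .drop('rnk')
 .show(999)
)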