PySpark dataframe How to filter data?


I have a dataframe with department, item id, and a count of those ids. There are 127 departments, and I want to get the top 10 items for each department and list them. That is, based on the item count, I want to list the top 10 items for each department separately. I have been trying to do this using groupBy and agg.max but was not able to. An example of the dataframe is shown below.

Department  Item id  count
A           101      10
B           102      5
A           104      12
C           101      5
D           104      14
C           108      10

CodePudding user response:

The solution is based on the row_number() window function.

  • In this demo I return the top 3 records per department; feel free to change it to 10.
  • qualify is relatively new in Spark SQL. If your Spark version doesn't support it, wrap the query and do the filtering with a WHERE clause on the outer query (see the wrapped-query sketch after the Spark SQL solution below).
  • I added Item id to the ORDER BY to break count ties deterministically.

Data Sample Creation

df = spark.sql('''select char(ascii('A') + d.i) as Department, 100 + i.i as `Item id`, int(rand()*100) as count from range(3) as d(i), range(7) as i(i) order by 1, 3 desc''')

df.show(999)

+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         A|    100|   52|
|         A|    105|   50|
|         A|    102|   40|
|         A|    101|   30|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         B|    105|   66|
|         B|    102|   48|
|         B|    100|   32|
|         B|    103|   14|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
|         C|    104|   82|
|         C|    100|    9|
|         C|    101|    6|
|         C|    106|    3|
+----------+-------+-----+

Spark SQL Solution

df.createOrReplaceTempView('t')

sql_query = '''
select  *
from    t
qualify row_number() over (partition by Department order by count desc, `Item id`) <= 3
'''

spark.sql(sql_query).show(999)
  
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
+----------+-------+-----+
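Wrapped-query alternative (no QUALIFY)

If your Spark version doesn't support QUALIFY, here is a minimal sketch of the wrapped-query approach mentioned above. It reuses the same temp view t and column names as the demo: the row number is computed in a subquery and filtered with a WHERE clause in the outer query.

# Sketch, assuming QUALIFY is unavailable: rank in a subquery,
# then filter on the outer query with WHERE.
sql_query_no_qualify = '''
select  Department, `Item id`, count
from (
    select  t.*,
            row_number() over (partition by Department order by count desc, `Item id`) as rn
    from    t
) ranked
where   rn <= 3
'''

spark.sql(sql_query_no_qualify).show(999)

The result is the same as with QUALIFY; only the filtering moves from the windowed query itself to the outer query.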

PySpark Solution

import pyspark.sql.functions as F
import pyspark.sql.window as W

# Rank items within each department by descending count (Item id breaks ties),
# then keep the top 3 rows per department and drop the helper column.
(df.withColumn('rn', F.row_number().over(W.Window.partitionBy('Department').orderBy(df['count'].desc(), df['Item id'])))
 .where('rn <= 3')
 .drop('rn')
 .show(999)
)

+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
+----------+-------+-----+
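If, as in the original question, you also want the top items gathered into a single list per department, one possible follow-up is a groupBy with collect_list on the filtered result. This is a sketch, not part of the original answer; it reuses the F and W imports above, and note that collect_list does not guarantee the order of elements in the list.

# Sketch: aggregate the filtered rows into one list of item ids per department.
# Note: collect_list does not guarantee element order.
w = W.Window.partitionBy('Department').orderBy(df['count'].desc(), df['Item id'])

(df.withColumn('rn', F.row_number().over(w))
 .where('rn <= 3')
 .groupBy('Department')
 .agg(F.collect_list('Item id').alias('top_items'))
 .show(truncate=False)
)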