I have the DataFrame below. I want to group it by company and date and, within each group, keep a single row based on category: prefer QF if available, otherwise SAF, otherwise AF. I am trying to assign ranks with a window function, but maybe there is an easier way.
company  date        value  category
------------------------------------
xyz      31-12-2020  12     AF
xyz      31-12-2020  10     SAF
xyz      31-12-2020  11     QF
xyz      30-06-2020  14     AF
xyz      30-06-2020  16     SAF
xyz      30-09-2020  13     SAF
xyz      31-03-2019  20     AF
Expected output:
company  date        value  category
------------------------------------
xyz      31-12-2020  11     QF
xyz      30-06-2020  16     SAF
xyz      30-09-2020  13     SAF
xyz      31-03-2019  20     AF
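For reference, the sample data can be built like this (a minimal sketch, assuming an active SparkSession and the dates kept as strings):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample data from the table above; dates are kept as plain strings here
df = spark.createDataFrame(
    [('xyz', '31-12-2020', 12, 'AF'),
     ('xyz', '31-12-2020', 10, 'SAF'),
     ('xyz', '31-12-2020', 11, 'QF'),
     ('xyz', '30-06-2020', 14, 'AF'),
     ('xyz', '30-06-2020', 16, 'SAF'),
     ('xyz', '30-09-2020', 13, 'SAF'),
     ('xyz', '31-03-2019', 20, 'AF')],
    ['company', 'date', 'value', 'category']
)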
CodePudding user response:
We can assign a rank to each category using chained when() conditions
and retain the records that have the minimum rank within each (company, date) group.
Note that this answer uses shortened column names (dt, val, cat).
import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

data_sdf. \
    withColumn('cat_rank',
               # map each category to a priority rank (1 = most preferred)
               func.when(func.col('cat') == 'QF', func.lit(1)).
               when(func.col('cat') == 'SAF', func.lit(2)).
               when(func.col('cat') == 'AF', func.lit(3))
               ). \
    withColumn('min_cat_rank',
               # best (smallest) rank available within each company/date group
               func.min('cat_rank').over(wd.partitionBy('company', 'dt'))
               ). \
    filter(func.col('min_cat_rank').isNotNull()). \
    filter(func.col('min_cat_rank') == func.col('cat_rank')). \
    drop('cat_rank', 'min_cat_rank'). \
    show()
# +-------+----------+---+---+
# |company|        dt|val|cat|
# +-------+----------+---+---+
# |    xyz|30-09-2020| 13|SAF|
# |    xyz|30-06-2020| 16|SAF|
# |    xyz|31-03-2019| 20| AF|
# |    xyz|31-12-2020| 11| QF|
# +-------+----------+---+---+
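If more categories get added later, the chained when() can also be replaced with a lookup map built via create_map(). A sketch of that variant, reusing the data_sdf from above; unknown categories get a null rank and are dropped by the equality filter:
from itertools import chain

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

# priority lookup: lower number = higher priority
priority = {'QF': 1, 'SAF': 2, 'AF': 3}
priority_map = func.create_map([func.lit(x) for x in chain(*priority.items())])

data_sdf. \
    withColumn('cat_rank', priority_map[func.col('cat')]). \
    withColumn('min_cat_rank',
               func.min('cat_rank').over(wd.partitionBy('company', 'dt'))
               ). \
    filter(func.col('cat_rank') == func.col('min_cat_rank')). \
    drop('cat_rank', 'min_cat_rank'). \
    show()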
CodePudding user response:
Assuming there are only a limited number of categories and no duplicated entries per category, I would suggest mapping the categories to integers, by which you can then order them. Afterwards you can simply partition, sort, and pick the first entry of each partition.
import pyspark.sql.functions as f
from pyspark.sql.window import Window

# map each category to a sortable priority (1 = most preferred); anything else stays null
df = df.withColumn('mapping',
                   f.when(f.col('category') == 'QF', f.lit(1))
                    .when(f.col('category') == 'SAF', f.lit(2))
                    .when(f.col('category') == 'AF', f.lit(3)))

# within each (company, date) group, sort by priority and keep the first row
w = Window.partitionBy('company', 'date').orderBy(f.col('mapping'))

df.withColumn('row', f.row_number().over(w))\
    .filter(f.col('row') == 1)\
    .drop('row', 'mapping')\
    .show()
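If the same category can appear more than once within a (company, date) group, row_number() will pick one of the ties arbitrarily; adding a secondary sort key makes the choice deterministic, e.g. keeping the highest value:
w = Window.partitionBy('company', 'date').orderBy(f.col('mapping'), f.col('value').desc())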
CodePudding user response:
Supposing there can be multiple values for the same category within a combination of company and date, and that we want to keep the maximum value for the preferred category, here is a solution with two window functions:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_company_date = Window.partitionBy('company', 'date')
w_company_date_category = Window.partitionBy('company', 'date', 'category')

df = (df
      # rank the categories: 1 = most preferred
      .withColumn('priority', F.when(F.col('category') == 'QF', 1)
                               .when(F.col('category') == 'SAF', 2)
                               .when(F.col('category') == 'AF', 3)
                               .otherwise(None))
      # flag rows that hold both the preferred category and its maximum value
      .withColumn('top_choice', F.when((F.col('priority') == F.min('priority').over(w_company_date))
                                       & (F.col('value') == F.max('value').over(w_company_date_category)), 1)
                                 .otherwise(0))
      .filter(F.col('top_choice') == 1)
      .drop('priority', 'top_choice')
      )
df.show()
+-------+----------+-----+--------+
|company|      date|value|category|
+-------+----------+-----+--------+
|    xyz|2020-03-31|   20|      AF|
|    xyz|2020-06-30|   16|     SAF|
|    xyz|2020-09-30|   13|     SAF|
|    xyz|2020-12-31|   11|      QF|
+-------+----------+-----+--------+