Remove all rows (for a given column value) from a dataframe if column max value is less than a defined threshold


Apologies if the question heading is a bit confusing. I am new to pyspark and am dealing with the following problem:

Let's say I have a dataframe with three columns (date, product and orders) covering a period of 3 days, something like:

date           product      orders

2022-01-01      whisky        11
2022-01-01      rum           100
2022-01-01      bourbon       5
2022-01-02      whisky        20
2022-01-02      rum           150
2022-01-02      bourbon       7 
2022-01-03      whisky        30
2022-01-03      rum           110
2022-01-03      bourbon       3
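
For reproducibility, here is a sketch that builds this sample dataframe (assuming an active SparkSession; the variable names are illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
data = [
    ("2022-01-01", "whisky", 11), ("2022-01-01", "rum", 100), ("2022-01-01", "bourbon", 5),
    ("2022-01-02", "whisky", 20), ("2022-01-02", "rum", 150), ("2022-01-02", "bourbon", 7),
    ("2022-01-03", "whisky", 30), ("2022-01-03", "rum", 110), ("2022-01-03", "bourbon", 3),
]
df = spark.createDataFrame(data, ["date", "product", "orders"])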

I want to filter out any product whose maximum number of orders is less than 10. In the dataframe above, every row with bourbon as the product should be removed, since max(orders) for bourbon is less than 10.

Output:

    date       product      orders

2022-01-01      whisky        11
2022-01-01      rum           100
2022-01-02      whisky        20
2022-01-02      rum           150
2022-01-03      whisky        30
2022-01-03      rum           110

What is the best way to go about this? I have been looking into the Window function in pyspark but have not been able to get it right.

I have created a windowSpec like this:

windowSpec = Window.partitionBy(groupedDf['product']).orderBy(groupedDf['orders'].desc())

but I am having trouble filtering the dataframe rows from there.

CodePudding user response:

This is exactly the kind of problem a window function solves: compute each product's maximum orders over a partition window, then filter on that value (ranking with row_number is not needed here).

from pyspark.sql.window import Window
import pyspark.sql.functions as F

# maximum orders per product, computed over the partition
window = Window.partitionBy("product")
df.withColumn("max_orders", F.max("orders").over(window)) \
    .filter("max_orders >= 10").drop("max_orders")

CodePudding user response:

You can first find the max orders for each product, and then filter based on that value.

df = df.selectExpr('*', 'max(orders) over (partition by product) as max_orders') \
    .filter('max_orders >= 10').drop('max_orders')
df.show(truncate=False)
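
For comparison, the same result can also be reached without a window function, by aggregating the per-product maximum and joining it back; this is only a sketch and the variable names are illustrative:

import pyspark.sql.functions as F

# per-product maximum orders
max_df = df.groupBy("product").agg(F.max("orders").alias("max_orders"))
# keep only products whose maximum is at least 10, then drop the helper column
df.join(max_df, on="product").filter("max_orders >= 10").drop("max_orders")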

CodePudding user response:

PySpark DataFrames provide a toPandas() method to convert to a pandas DataFrame, so first convert it:

df = pyspark_df.toPandas()

and then use pandas conditional indexing with loc. To match the requirement (drop every product whose maximum orders is less than 10), filter on the per-product maximum rather than on each row's own value:

df = df.loc[df.groupby("product")["orders"].transform("max") >= 10]
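
Keep in mind that toPandas() collects the full dataframe to the driver, so this only works for data that fits in memory. If the result needs to go back to Spark afterwards, the filtered pandas dataframe can be converted back (assuming the SparkSession is available as spark):

pyspark_df = spark.createDataFrame(df)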