Groupby and get just top 50% record based on one column pyspark


I have a dataframe like this:

id  item_id   score
1      6       1.1
2      6        1
3      6       1.4
7      6       1.3
8      2       1.2
9      2       1.8
1      4        2
10     4       1.1
2      4       1.9
8      4       1.2
.      .        .

The combination of id and item_id is the primary key, but each column individually contains duplicates.

Total unique id:       67689
Total unique item_id: 123123
Total records:        8334072747 (67689*123123)

Now I want to drop 50% of the data based on score while keeping all unique values from column item_id. For example:

Let's say I have 10 records with the same item_id; I want to drop the 5 records with the lowest score. The set of unique item_id values stays the same, but I'll lose some ids. So, for each item_id, I'll keep 50% of its original records.

Expected Output:

id  item_id   score
3      6       1.4
7      6       1.3
9      2       1.8
1      4        2
2      4       1.9
.      .        .

Try:

I can use a window function over the item_id column, but I'm not sure how to filter afterwards based on a percentage instead of a fixed value.

from pyspark.sql import Window
from pyspark.sql.functions import col, row_number

window = Window.partitionBy(df['item_id']).orderBy(df['score'].desc())

(df.select('*', row_number().over(window).alias('rank'))
   .filter(col('rank') <= 2))  # filters on a fixed rank, not a percentage
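
If the cutoff should be a fraction of each partition rather than a fixed rank, one option (not part of the original attempt) is percent_rank(); a minimal sketch, assuming df is the DataFrame above and a 0.5 threshold:

from pyspark.sql import functions as F
from pyspark.sql import Window

# Rank each row within its item_id, highest score first.
window = Window.partitionBy("item_id").orderBy(F.col("score").desc())

# percent_rank() returns (rank - 1) / (rows_in_partition - 1), so values
# below 0.5 are the top half of each partition (a lone row gets 0.0 and is kept).
top_half = (
    df.withColumn("pct", F.percent_rank().over(window))
      .filter(F.col("pct") < 0.5)
      .drop("pct")
)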

CodePudding user response:

This should work using the row_number() and count() window functions: take the count() per item_id and divide it by 2.

The filter is updated to handle the case where a partition has only one record. You also need to decide how to handle odd record counts: 50% of 3 records is 1.5, so you can turn row_num_val into a whole number by taking either the ceiling or the floor of that decimal.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql import Window

df = spark.createDataFrame(
    [
        (1, 6, 1.1),
        (2, 6, 1.0),
        (3, 6, 1.4),
        (7, 6, 1.3),
        (8, 2, 1.2),
        (9, 2, 1.8),
        (1, 4, 2.0),
        (10, 4, 1.1),
        (2, 4, 1.9),
        (8, 4, 1.2),
    ],
    ["id", "item_id", "score"],
)

# Window for counting the rows in each item_id partition (no ordering needed).
df_cnt_window = Window.partitionBy("item_id")

# Window for ranking rows within each item_id, highest score first.
df_row_window = Window.partitionBy("item_id").orderBy(F.col("score").desc())

df = (
    df
    # total number of rows for this item_id
    .withColumn("cnt", F.count("*").over(df_cnt_window))
    # rank within the item_id by descending score
    .withColumn("row_num", F.row_number().over(df_row_window))
    # cutoff: half the count, truncated to a whole number (floor)
    .withColumn("row_num_val", (F.col("cnt") / 2).cast(IntegerType()))
    # keep the top half; the second condition keeps single-record partitions
    .filter((F.col("row_num") <= F.col("row_num_val")) | (F.col("row_num_val") == 0))
    .drop("row_num", "row_num_val", "cnt")
)

df.show()

Output:

+---+-------+-----+
| id|item_id|score|
+---+-------+-----+
|  3|      6|  1.4|
|  7|      6|  1.3|
|  9|      2|  1.8|
|  1|      4|  2.0|
|  2|      4|  1.9|
+---+-------+-----+
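
If you would rather round the cutoff up for odd partition sizes (so, for example, 3 records keep 2 instead of 1), one variation is to compute row_num_val with F.ceil; a sketch, assuming df still holds the unfiltered data and the two windows defined above:

df_top_half = (
    df
    .withColumn("cnt", F.count("*").over(df_cnt_window))
    .withColumn("row_num", F.row_number().over(df_row_window))
    # ceil(cnt / 2): 3 rows -> cutoff 2, 1 row -> cutoff 1, so single-record
    # partitions are kept without needing an extra filter condition.
    .withColumn("row_num_val", F.ceil(F.col("cnt") / 2))
    .filter(F.col("row_num") <= F.col("row_num_val"))
    .drop("cnt", "row_num", "row_num_val")
)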