I have a dataframe like this:
id  item_id  score
1   6        1.1
2   6        1
3   6        1.4
7   6        1.3
8   2        1.2
9   2        1.8
1   4        2
10  4        1.1
2   4        1.9
8   4        1.2
.   .        .
The combination of the id and item_id columns is the primary key, but each column on its own contains duplicates.
Total unique id: 67689
Total unique item_id: 123123
Total records: 8334072747 (67689*123123)
Now I want to drop 50% of the data based on score, while keeping all unique values of item_id. For example: if I have 10 records with the same item_id, I want to drop the 5 records with the lowest score. My unique item_id values will remain the same, but I'll lose some ids. So basically, for each item_id I'll keep 50% of the original records.
Expected Output:
id  item_id  score
3   6        1.4
7   6        1.3
9   2        1.8
1   4        2
2   4        1.9
.   .        .
Try:
I can use a window function over the item_id column, but I'm not sure how to filter afterwards based on a percentage instead of a fixed value.
from pyspark.sql import Window
from pyspark.sql.functions import row_number, col

window = Window.partitionBy(df['item_id']).orderBy(df['score'].desc())
df.select('*', row_number().over(window).alias('rank')) \
    .filter(col('rank') <= 2)  # hardcoded cutoff; I want 50% per group instead
CodePudding user response:
This should work using the row_number() and count() window functions: take the count() per item_id and divide it by 2.
The filter below is written to also handle the case where a group has only one record.
There's still the question of how you want to handle odd record counts: for instance, 50% of 3 records is 1.5. You can make row_num_val a whole number by taking either the ceiling or the floor of that decimal.
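For illustration, a minimal sketch of the two rounding choices as column expressions (half, keep_fewer, and keep_more are just illustrative names; note that the cast to IntegerType used in the code below truncates, which is the floor for positive counts):

from pyspark.sql import functions as F

# 50% of an odd count, e.g. cnt = 3, is 1.5; two ways to make it whole:
half = F.col("cnt") / 2
keep_fewer = F.floor(half)  # 1.5 -> 1 (drop the middle record too)
keep_more = F.ceil(half)    # 1.5 -> 2 (keep the middle record)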
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql import Window

df = spark.createDataFrame(
    [
        (1, 6, 1.1),
        (2, 6, 1.0),
        (3, 6, 1.4),
        (7, 6, 1.3),
        (8, 2, 1.2),
        (9, 2, 1.8),
        (1, 4, 2.0),
        (10, 4, 1.1),
        (2, 4, 1.9),
        (8, 4, 1.2),
    ],
    ["id", "item_id", "score"],
)

# unordered window, used to count the rows in each item_id group
df_cnt_window = Window.partitionBy("item_id")

# ordered window, used to rank rows by score within each item_id group
df_row_window = Window.partitionBy("item_id").orderBy(F.col("score").desc())

df = (
    df
    # total number of records in this item_id group
    .withColumn("cnt", F.count("*").over(df_cnt_window))
    # 1-based rank of this record within the group, highest score first
    .withColumn("row_num", F.row_number().over(df_row_window))
    # 50% of the group size; the integer cast truncates (floor)
    .withColumn("row_num_val", (F.col("cnt") / 2).cast(IntegerType()))
    # keep the top half; the == 0 clause keeps single-record groups
    .filter((F.col("row_num") <= F.col("row_num_val")) | (F.col("row_num_val") == 0))
    .drop("row_num", "row_num_val", "cnt")
)

df.show()
Output:
+---+-------+-----+
| id|item_id|score|
+---+-------+-----+
|  3|      6|  1.4|
|  7|      6|  1.3|
|  9|      2|  1.8|
|  1|      4|  2.0|
|  2|      4|  1.9|
+---+-------+-----+
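Side note: since the question asked how to filter by a percentage directly, here is an alternative sketch (not part of the answer above) using the percent_rank() window function, which avoids the count column entirely. percent_rank() computes (rank - 1) / (rows - 1) within the partition, so the highest-scoring row in each group always gets 0.0 and a single-record group is kept automatically; filtering on < 0.5 keeps the same rows as the floor-based filter above.

from pyspark.sql import functions as F
from pyspark.sql import Window

# assumes df is the original (unfiltered) sample DataFrame from above
pct_window = Window.partitionBy("item_id").orderBy(F.col("score").desc())

df_pct = (
    df
    # 0.0 for the highest score in each group, 1.0 for the lowest
    .withColumn("pct", F.percent_rank().over(pct_window))
    .filter(F.col("pct") < 0.5)  # keep the top half of each group
    .drop("pct")
)
df_pct.show()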