I would like to aggregate by the ID key and, for each ID, select the row with the maximum Day value.
ID | col1 | col2 | month | Day |
---|---|---|---|---|
AI1 | 5 | 2 | janv | 15 |
AI2 | 6 | 0 | Dec | 16 |
AI1 | 1 | 7 | March | 16 |
AI3 | 9 | 4 | Nov | 18 |
AI2 | 3 | 20 | Fev | 20 |
AI3 | 10 | 8 | June | 06 |
Desired result:
ID | col1 | col2 | month | Day |
---|---|---|---|---|
AI1 | 1 | 7 | March | 16 |
AI2 | 3 | 20 | Fev | 20 |
AI3 | 9 | 4 | Nov | 18 |
CodePudding user response:
The only solution that comes to my mind is to:
- Get the highest Day for each ID (using groupBy)
- Append that highest Day to every row with a matching ID (using a join)
- Then apply a simple filter keeping only the rows where Day equals that maximum
# select the max Day for each ID
# (GroupedData.max assumes Day is a numeric column)
maxDayForIDs = df.groupBy("ID").max("Day").withColumnRenamed("max(Day)", "maxDay")
# add the per-ID maximum to every row with a matching ID
df = df.join(maxDayForIDs, "ID")
# keep only the rows where Day equals maxDay, then drop the helper column
df = df.filter(df.Day == df.maxDay).drop("maxDay")
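One caveat: the groupBy().max() shortcut only accepts numeric columns, so if Day is stored as a string (as in the reproduction in the next answer) you would either cast it first or write the aggregation with F.max, which also handles strings. A minimal sketch, assuming df is the DataFrame above with Day held as a string:

from pyspark.sql import functions as F

# assumption: Day is stored as a string like '06' or '16'
# option 1: cast it to an integer so the shortcut above works unchanged
df = df.withColumn("Day", F.col("Day").cast("int"))

# option 2: use F.max in agg(), which also works on string columns
maxDayForIDs = df.groupBy("ID").agg(F.max("Day").alias("maxDay"))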
CodePudding user response:
Usually this kind of operation is done using window functions like rank, dense_rank or row_number.
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
[('AI1', 5, 2, 'janv', '15'),
('AI2', 6, 0, 'Dec', '16'),
('AI1', 1, 7, 'March', '16'),
('AI3', 9, 4, 'Nov', '18'),
('AI2', 3, 20, 'Fev', '20'),
('AI3', 10, 8, 'June', '06')],
['ID', 'col1', 'col2', 'month', 'Day']
)
# number the rows within each ID, highest Day first
w = W.partitionBy('ID').orderBy(F.desc('Day'))
df = df.withColumn('_rn', F.row_number().over(w))
# keep only the top row per ID and drop the helper column
df = df.filter('_rn=1').drop('_rn')
df.show()
# +---+----+----+-----+---+
# | ID|col1|col2|month|Day|
# +---+----+----+-----+---+
# |AI1|   1|   7|March| 16|
# |AI2|   3|  20|  Fev| 20|
# |AI3|   9|   4|  Nov| 18|
# +---+----+----+-----+---+
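Note that row_number keeps exactly one row per ID even when two rows tie for the maximum Day; the rank (and dense_rank) functions mentioned above give tied rows the same rank, so filtering on rank 1 keeps all of them. A minimal sketch, assuming it is run against the original six-row DataFrame built above (i.e., before the row_number filter):

from pyspark.sql import functions as F, Window as W

# same window: rows within each ID, highest Day first
w = W.partitionBy('ID').orderBy(F.desc('Day'))

# rank() assigns tied rows the same rank, so rank 1 keeps every row
# sharing the per-ID maximum Day (row_number would keep only one of them)
ties_kept = df.withColumn('_rk', F.rank().over(w)).filter('_rk = 1').drop('_rk')
ties_kept.show()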