I have a large PySpark dataframe that includes these two columns:
highway | speed_kph |
---|---|
Road | 70 |
Service | 30 |
Road | null |
Road | 70 |
Service | null |
I'd like to fill the null values with the mean for that highway category.
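For reference, a minimal frame matching that sample could be built like this (a sketch; the spark session and the literal rows are assumed from the table above):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# None entries become null speed_kph values, as in the sample table
df = spark.createDataFrame(
    [('Road', 70), ('Service', 30), ('Road', None), ('Road', 70), ('Service', None)],
    ['highway', 'speed_kph'],
)
```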
I've tried creating another dataframe with groupBy, and ended up with this second one:
highway | avg(speed_kph) |
---|---|
Road | 65 |
Service | 30 |
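That second frame comes from a groupBy roughly like this (a sketch, assuming the original dataframe is named df):

```python
# groupBy/avg yields the avg(speed_kph) column shown above
avg_df = df.groupBy('highway').avg('speed_kph')
avg_df.show()
```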
But I don't know how to use this to fill in only the null values, without losing the original values where they exist.
The expected result for the first table would be:
highway | speed_kph |
---|---|
Road | 70 |
Service | 30 |
Road | 65 |
Road | 70 |
Service | 30 |
CodePudding user response:
The combination of a case/when expression and a window function partitioned by highway solves this fairly easily.
```python
from pyspark.sql import functions as F
from pyspark.sql import Window as W

(df
    .withColumn(
        'speed_kph',
        # fill nulls with the mean speed of the same highway category;
        # keep the original value otherwise
        F.when(F.col('speed_kph').isNull(),
               F.mean('speed_kph').over(W.partitionBy('highway')))
         .otherwise(F.col('speed_kph'))
    )
    .show()
)
```
```
# Output
# +-------+---------+
# |highway|speed_kph|
# +-------+---------+
# |   Road|     70.0|
# |   Road|     70.0|
# |   Road|     70.0|
# |Service|     30.0|
# |Service|     30.0|
# +-------+---------+
```
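If you would rather reuse the aggregated dataframe from the question, an equivalent approach (a sketch, assuming the averages live in a frame called avg_df) is to join it back on highway and coalesce, so original values are kept and only the nulls fall back to the group mean:

```python
from pyspark.sql import functions as F

avg_df = df.groupBy('highway').agg(F.mean('speed_kph').alias('avg_speed_kph'))

filled = (df
    .join(avg_df, on='highway', how='left')
    # coalesce returns the first non-null value, so existing speeds win over the mean
    .withColumn('speed_kph', F.coalesce(F.col('speed_kph'), F.col('avg_speed_kph')))
    .drop('avg_speed_kph')
)
filled.show()
```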