I have a DataFrame df as shown below:
VehNum   Control_circuit  control_circuit_status  partnumbers  errors          Flag
4234456  DOC              ok                      A567UR       Software Issue  0
4234456  DOC              not_okay                A568UR       Software Issue  1
4234456  DOC              not_okay                A569UR       Hardware issue  2
4234457  ACR              ok                      A234TY       Hardware issue  0
4234457  ACR              ok                      A235TY       Hardware issue  0
4234457  ACR              ok                      A234TY       Hardware issue  0
4234487  QWR              ok                      A276TY       Hardware issue  0
4234487  QWR              not_okay                A872UR       Hardware issue  1
3423448  QWR              not_okay                A872UR       Hardware issue  1
I want to add a new column called "Control_Flag", computed per group: for each (VehNum, Control_circuit) combination, if every Flag value in the group is 0, Control_Flag should be 0; if the group contains any non-zero Flag value (1 or 2), Control_Flag should be 1 for every row in that group.
The result should look like this:
VehNum   Control_circuit  control_circuit_status  partnumbers  errors          Flag  Control_Flag
4234456  DOC              ok                      A567UR       Software Issue  0     1
4234456  DOC              not_okay                A568UR       Software Issue  1     1
4234456  DOC              not_okay                A569UR       Hardware issue  2     1
4234457  ACR              ok                      A234TY       Hardware issue  0     0
4234457  ACR              ok                      A235TY       Hardware issue  0     0
4234457  ACR              ok                      A234TY       Hardware issue  0     0
4234487  QWR              ok                      A276TY       Hardware issue  0     1
4234487  QWR              not_okay                A872UR       Hardware issue  1     1
3423448  QWR              not_okay                A872UR       Hardware issue  1     1
How can I achieve this using PySpark?
CodePudding user response:
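A window aggregate with sum() over each (VehNum, Control_circuit) group will do this: the group sum is greater than 0 exactly when at least one Flag in the group is non-zero (assuming Flag is never negative).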
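To make the idea concrete before the Spark code, here is a minimal plain-Python sketch of the same rule. It assumes each row is reduced to a (VehNum, Control_circuit, Flag) tuple and that Flag is never negative; the helper name control_flags is made up for illustration and is not part of any library.

```python
from collections import defaultdict


def control_flags(rows):
    """Return one Control_Flag per input row, in input order.

    rows: iterable of (veh_num, control_circuit, flag) tuples.
    Assumes flags are non-negative, so a group sum > 0 means the
    group contains at least one non-zero flag.
    """
    rows = list(rows)

    # First pass: sum Flag per (VehNum, Control_circuit) group.
    group_sum = defaultdict(int)
    for veh, circuit, flag in rows:
        group_sum[(veh, circuit)] += flag

    # Second pass: every row gets its group's verdict.
    return [1 if group_sum[(veh, circuit)] > 0 else 0 for veh, circuit, _ in rows]


sample = [
    ("4234456", "DOC", 0),
    ("4234456", "DOC", 1),
    ("4234456", "DOC", 2),
    ("4234457", "ACR", 0),
    ("4234457", "ACR", 0),
    ("4234457", "ACR", 0),
    ("4234487", "QWR", 0),
    ("4234487", "QWR", 1),
    ("3423448", "QWR", 1),
]

print(control_flags(sample))  # [1, 1, 1, 0, 0, 0, 1, 1, 1]
```

The two passes correspond to what the window does in Spark: compute one aggregate per group, then attach it to every row of that group.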
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

# Reuse the active session if one exists (e.g. in a notebook or the pyspark shell)
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ("4234456", "DOC", "ok", "A567UR", "Software Issue", 0),
        ("4234456", "DOC", "not_okay", "A568UR", "Software Issue", 1),
        ("4234456", "DOC", "not_okay", "A569UR", "Hardware Issue", 2),
        ("4234457", "ACR", "ok", "A234TY", "Hardware Issue", 0),
        ("4234457", "ACR", "ok", "A234TY", "Hardware Issue", 0),
        ("4234457", "ACR", "ok", "A234TY", "Hardware Issue", 0),
        ("4234487", "QWR", "ok", "A276TY", "Hardware Issue", 0),
        ("4234487", "QWR", "not_okay", "A872UR", "Hardware Issue", 1),
        ("3423448", "QWR", "not_okay", "A872UR", "Hardware Issue", 1),
    ],
    ["VehNum", "Control_circuit", "control_circuit_status", "partnumbers", "errors", "Flag"],
)

# One window per (VehNum, Control_circuit) group
df_agg_window = Window.partitionBy("VehNum", "Control_circuit")

df = (
    df
    # Sum of Flag over the whole group, repeated on every row of that group
    .withColumn("flag_sum", F.sum("Flag").over(df_agg_window))
    # Compare the numeric sum directly; no string function such as F.lower is needed
    .withColumn(
        "Control_Flag",
        F.when(F.col("flag_sum") > 0, F.lit(1)).otherwise(F.lit(0)),
    )
    # .drop("flag_sum")  # uncomment to remove the helper column
)

df.show()
Output:
+-------+---------------+----------------------+-----------+--------------+----+--------+------------+
| VehNum|Control_circuit|control_circuit_status|partnumbers|        errors|Flag|flag_sum|Control_Flag|
+-------+---------------+----------------------+-----------+--------------+----+--------+------------+
|4234457|            ACR|                    ok|     A234TY|Hardware Issue|   0|       0|           0|
|4234457|            ACR|                    ok|     A234TY|Hardware Issue|   0|       0|           0|
|4234457|            ACR|                    ok|     A234TY|Hardware Issue|   0|       0|           0|
|4234487|            QWR|              not_okay|     A872UR|Hardware Issue|   1|       1|           1|
|4234487|            QWR|                    ok|     A276TY|Hardware Issue|   0|       1|           1|
|4234456|            DOC|                    ok|     A567UR|Software Issue|   0|       3|           1|
|4234456|            DOC|              not_okay|     A569UR|Hardware Issue|   2|       3|           1|
|4234456|            DOC|              not_okay|     A568UR|Software Issue|   1|       3|           1|
|3423448|            QWR|              not_okay|     A872UR|Hardware Issue|   1|       1|           1|
+-------+---------------+----------------------+-----------+--------------+----+--------+------------+