I have a DataFrame with a column "grades" containing a list of Grade objects with two fields: name (String) and value (Double). I would like to add the word PASS to the tags list if the list contains a Grade with the name HOME and a value of at least 20.0. Example below:
INPUT:
+-----+---+---+------+-------------------------------------------------------+
|model|cnd|age|tags  |grades                                                 |
+-----+---+---+------+-------------------------------------------------------+
|foo1 |xx |10 |[]    |[{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}]|
|foo2 |xz |12 |[]    |[{name:"ATW", value: 70.0}]                            |
|foo3 |xc |13 |[]    |[{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}]|
+-----+---+---+------+-------------------------------------------------------+
OUTPUT:
+-----+---+---+------+-------------------------------------------------------+
|model|cnd|age|tags  |grades                                                 |
+-----+---+---+------+-------------------------------------------------------+
|foo1 |xx |10 |[PASS]|[{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}]|
|foo2 |xz |12 |[]    |[{name:"ATW", value: 70.0}]                            |
|foo3 |xc |13 |[]    |[{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}]|
+-----+---+---+------+-------------------------------------------------------+
I haven't been able to find a reasonable solution. So far I have got this:
import org.apache.spark.sql.functions._

dataFrame.withColumn("tags",
  when(
    array_contains(
      col("grades.name"),
      lit("HOME")
    ) && col("grades.value") >= lit(20.0),
    array_union(col("tags"), lit(Array("PASS")))
  ).otherwise(col("tags"))
)
But for some reason this code throws:
org.apache.spark.sql.AnalysisException: cannot resolve '(`grades`.`value` >= 20.0D)' due to data type mismatch: differing types in '(`grades`.`value` >= 20.0D)' (array<double> and double).;;
The data is read from BigQuery, and there is no way the value field holds an array of doubles.
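A note on where the `array<double>` comes from: when `grades` is an array of structs, `col("grades.value")` extracts the `value` field from every element, so the resulting column is an array with one double per grade. Each individual `value` is a double, but the projected column is not, which is why comparing it with `lit(20.0)` fails. The plain-Scala sketch below models this with a hypothetical `Grade` case class and `GradesProjection` object (it is not Spark code). It also shows a second pitfall: even if the comparison were rewritten as "any value >= 20", combining it with `array_contains` on the names would check the two conditions on possibly different grades, which would wrongly pass `foo3`.

```scala
// Hypothetical plain-Scala model of one row's `grades` column; this is not Spark code.
final case class Grade(name: String, value: Double)

object GradesProjection {
  // Same shape as col("grades.value") on an array of structs:
  // one Double per element, i.e. a Seq[Double] rather than a single Double.
  def values(grades: Seq[Grade]): Seq[Double] = grades.map(_.value)

  // Same shape as col("grades.name"): one String per element.
  def names(grades: Seq[Grade]): Seq[String] = grades.map(_.name)

  // "Some name is HOME" AND "some value is >= 20"; the two matches
  // may come from two different grades.
  def naiveCheck(grades: Seq[Grade]): Boolean =
    names(grades).contains("HOME") && values(grades).exists(_ >= 20.0)

  // The intended check: a single grade satisfies both conditions.
  def perElementCheck(grades: Seq[Grade]): Boolean =
    grades.exists(g => g.name == "HOME" && g.value >= 20.0)

  def main(args: Array[String]): Unit = {
    val foo1 = Seq(Grade("ATW", 10.0), Grade("HOME", 20.0))
    val foo3 = Seq(Grade("ATW", 90.0), Grade("HOME", 10.0))
    println(values(foo1))          // List(10.0, 20.0)
    println(perElementCheck(foo1)) // true
    println(naiveCheck(foo3))      // true, even though HOME is only 10.0
    println(perElementCheck(foo3)) // false
  }
}
```

In Spark terms, the per-element check has to run inside a higher-order function over the array rather than on the projected `grades.value` column.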
Answer:
Assume your dataset is called data (simplified below to the relevant columns):
+----+---------------------------+
|tags|grades                     |
+----+---------------------------+
|[]  |[{ATW, 10.0}, {HOME, 20.0}]|
|[]  |[{ATW, 70.0}]              |
|[]  |[{ATW, 90.0}, {HOME, 10.0}]|
+----+---------------------------+
If your grades column happens to be a JSON string, first parse it into an array of structs as below (skip this step if it already is one):
data = data.withColumn("grades",
expr("from_json(grades, 'array<struct<name:string,value:double>>')")
)
Once this is in place, apply the following:
data = data.withColumn("tags",
  when(
    // condition: at least one grade has name == HOME and value >= 20
    // (the filter higher-order function requires Spark 2.4+)
    expr("size(filter(grades, x -> x.name == 'HOME' and x.value >= 20))").geq(1),
    // union the existing tags with array("PASS"); array_union also drops duplicates
    array_union(col("tags"), array(lit("PASS")))
    // otherwise, leave the tags column unchanged
  ).otherwise(col("tags")))
The final output looks like this:
+------+---------------------------+
|tags  |grades                     |
+------+---------------------------+
|[PASS]|[{ATW, 10.0}, {HOME, 20.0}]|
|[]    |[{ATW, 70.0}]              |
|[]    |[{ATW, 90.0}, {HOME, 10.0}]|
+------+---------------------------+
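To make the when/otherwise logic easier to follow, here is a plain-Scala model of it applied to the three sample rows. `Grade`, `Record`, and `TagPass` are hypothetical names for this sketch only; it mirrors the Spark expressions above on ordinary collections and is not Spark code.

```scala
// Hypothetical plain-Scala model of the answer's when/otherwise logic; not Spark code.
final case class Grade(name: String, value: Double)
final case class Record(tags: Seq[String], grades: Seq[Grade])

object TagPass {
  // Mirrors size(filter(grades, x -> x.name == 'HOME' and x.value >= 20)) >= 1
  def condition(grades: Seq[Grade]): Boolean =
    grades.count(g => g.name == "HOME" && g.value >= 20.0) >= 1

  // Mirrors array_union(tags, array("PASS")): union of both arrays without duplicates
  def addPass(tags: Seq[String]): Seq[String] = (tags ++ Seq("PASS")).distinct

  // Mirrors when(condition, addPass(tags)).otherwise(tags)
  def transform(r: Record): Record =
    if (condition(r.grades)) r.copy(tags = addPass(r.tags)) else r

  def main(args: Array[String]): Unit = {
    val data = Seq(
      Record(Seq.empty, Seq(Grade("ATW", 10.0), Grade("HOME", 20.0))),
      Record(Seq.empty, Seq(Grade("ATW", 70.0))),
      Record(Seq.empty, Seq(Grade("ATW", 90.0), Grade("HOME", 10.0)))
    )
    // Prints List(PASS), then List(), then List()
    data.map(transform).foreach(r => println(r.tags))
  }
}
```

Because the union drops duplicates, running the transformation twice does not add a second PASS. Spark SQL also has an exists(array, x -> predicate) higher-order function, so expr("exists(grades, x -> x.name == 'HOME' and x.value >= 20)") should express the same condition without size(...); check that your Spark version supports it.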
Good luck!