What I have:
country | sources                      | infer_from_source
--------|------------------------------|-----------------------------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"]
"DEU"   | ["DEU"]                      | ["FALSE"]
What I want after a function:
country | sources                      | infer_from_source                  | inferred_country
--------|------------------------------|------------------------------------|-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | ["CZE", "FRA"]
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
I need to create a function that, if the country column is null, extracts the countries from the sources array based on the boolean values in the infer_from_source array; otherwise it should return the country value as-is.
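For reference, a DataFrame with this shape can be built like so (a minimal sketch; the column names match the tables above, and the flags are stored as plain strings):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample rows matching the tables above
df = spark.createDataFrame(
    [
        (None, ["LUX", "CZE", "CHN", "FRA"], ["FALSE", "TRUE", "FALSE", "TRUE"]),
        ("DEU", ["DEU"], ["FALSE"]),
    ],
    ["country", "sources", "infer_from_source"],
)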
I created this function:
from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F
from pyspark.sql.functions import udf

@udf
def determine_entity_country(country: StringType, sources: ArrayType,
                             infer_from_source: ArrayType) -> ArrayType:
    # Keep the country if it is already set
    if country:
        return country
    else:
        # Otherwise pick the source whose flag is "TRUE"
        if "TRUE" in infer_from_source:
            idx = infer_from_source.index("TRUE")
            return sources[idx]
        return None
But this yields the result below: the .index("TRUE") method only returns the index of the first element that matches its argument.
country | sources                      | infer_from_source                  | inferred_country
--------|------------------------------|------------------------------------|-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | "CZE"
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
CodePudding user response:
Fixed it! It was simply a matter of using a list comprehension:
# Declare the return type so Spark produces a real array column
@udf(returnType=ArrayType(StringType()))
def determine_entity_country(country: StringType, sources: ArrayType,
                             infer_from_source: ArrayType) -> ArrayType:
    if country:
        # Wrap in a list so both branches return the declared array type
        return [country]
    else:
        if "TRUE" in infer_from_source:
            # Collect every index whose flag is "TRUE", not just the first
            true_index_array = [ix for ix in range(len(infer_from_source))
                                if infer_from_source[ix] == "TRUE"]
            return [sources[ix] for ix in true_index_array]
        return None
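Applying it is then just a matter of passing the three columns (a sketch, assuming the input DataFrame is named df):

df1 = df.withColumn(
    "inferred_country",
    determine_entity_country("country", "sources", "infer_from_source"),
)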
CodePudding user response:
You should avoid UDFs whenever you can achieve the same result with Spark built-in functions, especially Python UDFs in PySpark, which pay a serialization cost for every row.
Here's another way, using the higher-order functions transform and filter on arrays:
import pyspark.sql.functions as F

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        # Keep each source whose flag casts to true, then drop the nulls
        F.expr("""filter(
            transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
            x -> x is not null
        )""")
    )
)
df1.show()
# +-------+--------------------+--------------------+----------------+
# |country|             sources|   infer_from_source|inferred_country|
# +-------+--------------------+--------------------+----------------+
# |   null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...|      [CZE, FRA]|
# |    DEU|               [DEU]|             [FALSE]|           [DEU]|
# +-------+--------------------+--------------------+----------------+
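To see what the inner transform contributes before filter strips the nulls, it can be run on its own (for illustration only):

df.select(
    F.expr("transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null))")
    .alias("kept_or_null")
).show(truncate=False)
# First row: [null, CZE, null, FRA] -- filter then drops the nulls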
And starting from Spark 3, you can use the element index directly in the filter lambda function:
df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
    )
)
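Both versions wrap country in F.array because a Spark column cannot hold a plain string in some rows and an array in others. If single strings are preferred in the output (as in the desired table above), one option is to join the arrays back afterwards (a sketch):

df2 = df1.withColumn("inferred_country", F.array_join("inferred_country", ", "))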