get index of all True elements in array column in pyspark

Time:11-15

What I have:

country | sources        |  infer_from_source   
-----------------------------------------------
null    | ["LUX", "CZE", |  ["FALSE", "TRUE",   
        |  "CHN", "FRA"] |   "FALSE", "TRUE"]   
"DEU"   | ["DEU"]        |  ["FALSE"]          

What I want after a function:

country | sources        |  infer_from_source   | inferred_country
--------------------------------------------------------------------
null    | ["LUX", "CZE", |  ["FALSE", "TRUE",   | 
        |  "CHN", "FRA"] |   "FALSE", "TRUE"]   | ["CZE", "FRA"]
"DEU"   | ["DEU"]        |  ["FALSE"]           | "DEU"

I need a function that, when the country column is null, extracts the countries from the sources array according to the boolean flags in the infer_from_source array, and otherwise returns the country value.

I created this function

from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F


@F.udf(returnType=StringType())
def determine_entity_country(country, sources, infer_from_source):
    # Keep the existing country when there is one.
    if country:
        return country
    if "TRUE" in infer_from_source:
        # list.index only gives the position of the first "TRUE".
        idx = infer_from_source.index("TRUE")
        return sources[idx]
    return None

But this yields the result below, because .index("TRUE") only returns the index of the first element that matches its argument:

country | sources        |  infer_from_source   | inferred_country
--------------------------------------------------------------------
null    | ["LUX", "CZE", |  ["FALSE", "TRUE",   | 
        |  "CHN", "FRA"] |   "FALSE", "TRUE"]   | "CZE"
"DEU"   | ["DEU"]        |  ["FALSE"]           | "DEU"
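
To see why only "CZE" comes back, here is a minimal plain-Python sketch (no Spark needed) using the values from the first example row. The variable names are only illustrative. list.index stops at the first match, while enumerate visits every position, so all matching indexes can be collected.

```python
# Values taken from the first example row; names are illustrative only.
sources = ["LUX", "CZE", "CHN", "FRA"]
infer_from_source = ["FALSE", "TRUE", "FALSE", "TRUE"]

# list.index returns the position of the first match only.
first_true = infer_from_source.index("TRUE")

# enumerate yields (position, value) pairs, so every match can be collected.
all_true = [i for i, flag in enumerate(infer_from_source) if flag == "TRUE"]

print(first_true)                      # 1
print(all_true)                        # [1, 3]
print([sources[i] for i in all_true])  # ['CZE', 'FRA']
```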

CodePudding user response:

Fixed it! It was simply a matter of using a list comprehension:

@F.udf(returnType=ArrayType(StringType()))
def determine_entity_country(country, sources, infer_from_source):
    # Wrap an existing country in a list so the column has a single type.
    if country:
        return [country]
    # Collect the index of every "TRUE" flag, not just the first one.
    true_index_array = [ix for ix, flag in enumerate(infer_from_source) if flag == "TRUE"]
    # An empty result is returned as None, as before.
    return [sources[ix] for ix in true_index_array] or None
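
One way to check this logic without starting a Spark session is to keep it in a plain Python function and wrap it as a UDF separately. The sketch below is an assumption about how that could look, not part of the original answer: pick_countries is a hypothetical name, and it returns a list in both branches so that the resulting column would have a single type.

```python
from typing import List, Optional


def pick_countries(country: Optional[str],
                   sources: Optional[List[str]],
                   infer_from_source: Optional[List[str]]) -> Optional[List[str]]:
    """Hypothetical helper: plain-Python version of the UDF body."""
    if country:
        # Wrap the known country so both branches return a list.
        return [country]
    if not sources or not infer_from_source:
        # Nothing to infer from.
        return None
    # zip pairs each source with its flag; keep the sources flagged "TRUE".
    picked = [s for s, flag in zip(sources, infer_from_source) if flag == "TRUE"]
    # Mirror the original behaviour: no "TRUE" flag means None.
    return picked or None


print(pick_countries(None, ["LUX", "CZE", "CHN", "FRA"],
                     ["FALSE", "TRUE", "FALSE", "TRUE"]))  # ['CZE', 'FRA']
print(pick_countries("DEU", ["DEU"], ["FALSE"]))            # ['DEU']
```

With PySpark available, the same function could then be wrapped with F.udf(pick_countries, ArrayType(StringType())) and applied to the three columns.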

CodePudding user response:

You should avoid UDFs whenever you can achieve the same result with Spark built-in functions alone, especially PySpark UDFs.

Here's another way, using the higher-order functions transform and filter on arrays:

import pyspark.sql.functions as F

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("""filter(
                    transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
                    x -> x is not null
                )""")
    )
)

df1.show()
#+-------+--------------------+--------------------+----------------+
#|country|             sources|   infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#|   null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...|      [CZE, FRA]|
#|    DEU|               [DEU]|             [FALSE]|           [DEU]|
#+-------+--------------------+--------------------+----------------+

Starting from Spark 3, you can also use the index in the filter lambda function:

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
    )
)