How to case when pyspark dataframe array based on multiple values

Time:11-06

I can use array_contains to check whether an array contains a value.

from pyspark.sql import functions as F

test = test.withColumn("my_boolean",
    F.when(F.expr("array_contains(check_variable, 'a')"), 1)
     .otherwise(0))

Instead of testing for one value, I would like to test for multiple values. I could chain several conditions:

test = test.withColumn("my_boolean",
    F.when(F.expr("array_contains(check_variable, 'a')"), 1)
     .when(F.expr("array_contains(check_variable, 'b')"), 1)
     .otherwise(0))
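A chained when() behaves like SQL CASE WHEN: branches are tried in order and the first true condition supplies the value. When every branch returns the same value, the chain is equivalent to a single OR of the conditions. The sketch below is a plain-Python model of that evaluation, not Spark code; the helper `case_when` is hypothetical and nulls are not modelled.

```python
# Plain-Python model (not Spark code) of how a chained
# F.when(...).when(...).otherwise(...) expression is evaluated per row.
# Hypothetical helper; null handling is deliberately left out.

def case_when(row_array, branches, otherwise):
    """branches is a list of (predicate, value) pairs, checked in order."""
    for predicate, value in branches:
        if predicate(row_array):
            return value  # first true condition wins
    return otherwise


# Two branches with the same result value (1), as in the question.
branches = [
    (lambda arr: "a" in arr, 1),  # models array_contains(check_variable, 'a')
    (lambda arr: "b" in arr, 1),  # models array_contains(check_variable, 'b')
]

rows = [["a", "x"], ["y", "b"], ["x", "y"]]
chained = [case_when(arr, branches, 0) for arr in rows]

# Equivalent single condition: OR of the membership tests.
single = [int(any(v in arr for v in ["a", "b"])) for arr in rows]
print(chained, single)  # [1, 1, 0] [1, 1, 0]
```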

Is there a way to do this in one statement? Pseudocode:

test = test.withColumn("my_boolean", 
    F.when(expr("array_contains('check_variable', ['a','b'])"),
           1)
 .otherwise(0))

Answer:

You could use array_intersect on the two arrays: if the intersection has a size greater than 0, the array contains at least one of the values.

Example:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
data = [
    {"id": 1, "test": ["A", "B"]},
    {"id": 2, "test": ["E", "C"]},
    {"id": 3, "test": ["D", "B"]},
]
df = spark.createDataFrame(data)
df = df.withColumn(
    "result",
    F.when(
        F.size(F.array_intersect(F.col("test"), F.array(F.lit("A"), F.lit("B"))))
        > 0,
        1,
    ).otherwise(0),
)
df.show()

Result:

+---+------+------+
| id|  test|result|
+---+------+------+
|  1|[A, B]|     1|
|  2|[E, C]|     0|
|  3|[D, B]|     1|
+---+------+------+
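To see why this yields the result column above, here is a plain-Python model of `size(array_intersect(test, array('A', 'B'))) > 0`. It is a sketch under the assumption that the arrays contain no nulls. Because array_intersect returns distinct elements and only the size is used, a set intersection gives the same count. The helper `intersect_size` is hypothetical.

```python
# Plain-Python model of:
#   when(size(array_intersect(test, array('A', 'B'))) > 0, 1).otherwise(0)
# Assumption: no null arrays or null elements (Spark's null handling is not modelled).

def intersect_size(arr, values):
    # array_intersect drops duplicates, so its size equals the number
    # of distinct elements shared by both arrays.
    return len(set(arr) & set(values))


data = [
    {"id": 1, "test": ["A", "B"]},
    {"id": 2, "test": ["E", "C"]},
    {"id": 3, "test": ["D", "B"]},
]
wanted = ["A", "B"]

result = [1 if intersect_size(row["test"], wanted) > 0 else 0 for row in data]
print(result)  # [1, 0, 1]
```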

If you want to use expr:

F.when(
    F.expr("size(array_intersect(test, array('A', 'B'))) > 0"),
    1,
).otherwise(0)

Answer:

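In Spark >= 2.4, you could use array_intersect. To check that at least one of the values is present, test that the intersection is non-empty (size > 0); to check that all of them are present, compare its size with the number of values you are looking for (2 in your example).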

pyspark.sql.functions.array_intersect(col1, col2)

Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.

The code could be as follows:

from pyspark.sql import functions as f

test = test.withColumn("my_boolean",
    f.expr("size(array_intersect(check_variable, array('a', 'b'))) > 0").cast("int"))

Note that casting the boolean to int is another way to turn it into a 0/1 value.
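The two size checks answer different questions. Below is a plain-Python sketch of both, assuming no nulls; the helpers `contains_any` and `contains_all` are hypothetical. Because array_intersect removes duplicates, the "all" check should compare against the number of *distinct* values; otherwise a value list such as `array('a', 'a')` could never match.

```python
# Plain-Python model contrasting two checks built on array_intersect:
#   "at least one": size(array_intersect(arr, values)) > 0
#   "all of them":  size(array_intersect(arr, values)) == number of distinct values
# Assumption: no nulls. Helper names are hypothetical.

def intersect_size(arr, values):
    # array_intersect drops duplicates, so count distinct shared elements.
    return len(set(arr) & set(values))


def contains_any(arr, values):
    # int(...) plays the role of .cast("int") on the boolean
    return int(intersect_size(arr, values) > 0)


def contains_all(arr, values):
    # Compare with the distinct count of the wanted values.
    return int(intersect_size(arr, values) == len(set(values)))


rows = [["a", "b", "c"], ["a", "x"], ["x", "y"]]
print([contains_any(r, ["a", "b"]) for r in rows])  # [1, 1, 0]
print([contains_all(r, ["a", "b"]) for r in rows])  # [1, 0, 0]
```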

In case anyone is interested in a Spark < 2.4 solution, one could build the condition from array_contains by iterating over the list of values and combining the results with OR:

from functools import reduce
from pyspark.sql import functions as f

def contains_at_least_one(values):
    # One array_contains test per literal value, combined with OR (|)
    contains = map(lambda v: f.array_contains('check_variable', v), values)
    return reduce(lambda x, y: x | y, contains)

test = test.withColumn("my_boolean",
    contains_at_least_one(['a', 'b']).cast("int"))
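The same fold can be modelled in plain Python: one membership test per value, folded with OR by functools.reduce. This is a sketch of the logic only, not Spark code. One design note: reduce without an initial value raises TypeError on an empty list, so the model passes False as the start value. In the Column version, F.lit(False) could serve the same purpose.

```python
# Plain-Python model of contains_at_least_one: one membership test per value,
# folded together with OR via functools.reduce (the Column version uses `|`).
from functools import reduce
import operator


def contains_at_least_one(arr, values):
    tests = map(lambda v: v in arr, values)  # models array_contains per value
    # False as the initial value keeps an empty `values` list from raising TypeError.
    return reduce(operator.or_, tests, False)


print(int(contains_at_least_one(["x", "b"], ["a", "b"])))  # 1
print(int(contains_at_least_one(["x", "y"], ["a", "b"])))  # 0
```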