Result of a when chain in Spark

Time:09-15

I have a chained when condition in a Spark DataFrame which looks something like this:

from pyspark.sql.functions import when, lower

df = df.withColumn("some_column", when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), df.some_qty)
                                 .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 != 0), df.some_qty2)
                                 .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 == 0), 0)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 != 0), df.some_qty3)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 != 0), df.some_qty4)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 == 0), 0))

In the expression, I'm trying to set the value of a column based on the values of other columns, and I want to understand how the statement above is executed. Are all the conditions checked for every row of the DataFrame, and if so, what happens when more than one when condition is true? Or is the order of the chain followed, so that the first condition that is true is used?

CodePudding user response:

Yes, every row is checked. But Spark compiles the whole chain into a single CASE WHEN expression and optimizes its evaluation, so it is not like looping over each cell in Python.

As for the order, the example below shows that only the first true condition is used:

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import when, lower, lit

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
    ('id2','70.07','22.1','0','1'),
    ('id1','0','0','1','3'),
    ('id2','80.7','0','1','3'),
    ('id2','0','0','1','3'),
    ('id1','22.2','0','1','3')
    ],
    ['transaction_id','some_qty','some_qty2', 'some_qty3','some_qty4']
)\
    .withColumn('some_qty', F.col('some_qty').cast('double'))\
    .withColumn('some_qty2', F.col('some_qty2').cast('double'))\
    .withColumn('some_qty3', F.col('some_qty3').cast('double'))\
    .withColumn('some_qty4', F.col('some_qty4').cast('double'))

# Both branches test the same condition, so every row that matches
# the first one also matches the second.
df = df.withColumn("some_column",when((lower(df.transaction_id) == "id1") & (df.some_qty != 0),lit('first_true'))
                                 .when((lower(df.transaction_id) == "id1") & (df.some_qty != 0),lit('second_true')))
df.show()

# +--------------+--------+---------+---------+---------+-----------+
# |transaction_id|some_qty|some_qty2|some_qty3|some_qty4|some_column|
# +--------------+--------+---------+---------+---------+-----------+
# |           id2|   70.07|     22.1|      0.0|      1.0|       null|
# |           id1|     0.0|      0.0|      1.0|      3.0|       null|
# |           id2|    80.7|      0.0|      1.0|      3.0|       null|
# |           id2|     0.0|      0.0|      1.0|      3.0|       null|
# |           id1|    22.2|      0.0|      1.0|      3.0| first_true|
# +--------------+--------+---------+---------+---------+-----------+
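The first-match behavior in the output above can be pictured with a plain-Python mental model. This is a hypothetical sketch, not Spark code: case_when and is_id1_nonzero are illustrative names, and the model ignores Spark's null semantics and code generation. It only captures that branches are tried in order and the first true condition supplies the value.

```python
from typing import Any, Callable, Dict, List, Tuple

Row = Dict[str, Any]
Branch = Tuple[Callable[[Row], bool], Any]


def case_when(row: Row, branches: List[Branch], otherwise: Any = None) -> Any:
    """Toy model of when(...).when(...).otherwise(...) applied to one row."""
    for condition, value in branches:
        if condition(row):
            return value  # first true condition wins; later branches are skipped
    return otherwise  # None stands in for null when there is no otherwise


def is_id1_nonzero(r: Row) -> bool:
    return r["transaction_id"].lower() == "id1" and r["some_qty"] != 0


rows = [
    {"transaction_id": "id2", "some_qty": 70.07},
    {"transaction_id": "id1", "some_qty": 0.0},
    {"transaction_id": "id1", "some_qty": 22.2},
]
# Both branches share the same condition, as in the example above.
branches = [(is_id1_nonzero, "first_true"), (is_id1_nonzero, "second_true")]

print([case_when(r, branches) for r in rows])  # [None, None, 'first_true']
```

Passing otherwise="no_match" to the model corresponds to ending a real chain with .otherwise("no_match"): rows that match no branch get that value instead of null.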

CodePudding user response:

You can create a sample DataFrame and check it. As the output shows,
when both conditions are true, only the first when supplies the value.

from pyspark.sql import functions as F
df = spark.createDataFrame([(2,), (2,), (0,)], ['col_1'])

df = df.withColumn("some_column", F.when(F.col('col_1') == 2, 'condition1')
                                   .when(F.col('col_1') > 1, 'condition2'))
df.show()
# +-----+-----------+
# |col_1|some_column|
# +-----+-----------+
# |    2| condition1|
# |    2| condition1|
# |    0|       null|
# +-----+-----------+

Since there is no otherwise clause, any row that matches none of the when conditions gets null in the new column.
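Because only the first true branch is used, each when can rely on the earlier ones having failed, so the question's chain could drop its == 0 checks. The sketch below is a plain-Python equivalence check, not Spark code: original and simplified are hypothetical helpers that mirror the two chains row by row, and the PySpark form in the comments is untested here. It assumes the quantity columns contain no nulls; in Spark a comparison with null is never true, so with a null some_qty the shorter chain would return some_qty2 or 0 where the original returns null.

```python
from itertools import product

# Possible PySpark form of the shorter chain (assumes no nulls; not run here):
# tid = F.lower(F.col("transaction_id"))
# F.when((tid == "id1") & (F.col("some_qty") != 0), F.col("some_qty"))
#  .when((tid == "id1") & (F.col("some_qty2") != 0), F.col("some_qty2"))
#  .when(tid == "id1", 0)
#  .when((tid == "id2") & (F.col("some_qty3") != 0), F.col("some_qty3"))
#  .when((tid == "id2") & (F.col("some_qty4") != 0), F.col("some_qty4"))
#  .when(tid == "id2", 0)


def original(tid, q, q2, q3, q4):
    """Mirror of the question's chain: the first true branch wins, else None."""
    t = tid.lower()
    if t == "id1" and q != 0:
        return q
    if t == "id1" and q == 0 and q2 != 0:
        return q2
    if t == "id1" and q == 0 and q2 == 0:
        return 0
    if t == "id2" and q3 != 0:
        return q3
    if t == "id2" and q3 == 0 and q4 != 0:
        return q4
    if t == "id2" and q3 == 0 and q4 == 0:
        return 0
    return None  # no branch matched and there is no otherwise


def simplified(tid, q, q2, q3, q4):
    """Same chain, relying on earlier branches having failed."""
    t = tid.lower()
    if t == "id1" and q != 0:
        return q
    if t == "id1" and q2 != 0:
        return q2
    if t == "id1":
        return 0
    if t == "id2" and q3 != 0:
        return q3
    if t == "id2" and q4 != 0:
        return q4
    if t == "id2":
        return 0
    return None


values = [0.0, 1.5, -2.0]  # non-null sample quantities
ids = ["id1", "ID2", "id3"]
mismatches = [
    args
    for args in product(ids, values, values, values, values)
    if original(*args) != simplified(*args)
]
print(len(mismatches))  # 0
```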

CodePudding user response:

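Adding to the above, every branch of the chain must be valid, not only the branches that match. The following raises an AnalysisException because undefined_col does not exist: Spark resolves all columns referenced in the expression when withColumn is called, before any row is evaluated, so the error does not depend on whether any row would reach that branch: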

df = spark.createDataFrame(
    [
    ('id2','70.07','22.1','0','1'),
    ('id1','0','0','1','3'),
    ('id2','80.7','0','1','3'),
    ('id2','0','0','1','3'),
    ('id1','22.2','0','1','3')
    ],
    ['transaction_id','some_qty','some_qty2', 'some_qty3','some_qty4']
)

df = df.withColumn("some_column", F.when(F.col('some_qty4') == 2, F.lit(1))
                                   .when(F.col('some_qty4') > 1, F.col('undefined_col')))
df.show()

This differs from Python's if...else, where the unused branch is never evaluated, so the following runs without error:

if False:
    print(undefined_variable)
else:
    print("ok")
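One way to picture the difference is as two phases. In the toy model below (build_case_when and its tuple-based branches are hypothetical, not PySpark APIs), building the expression checks every referenced column against the schema before any row is seen, which is why a bad reference fails even in a branch that might never match. Python's if, by contrast, looks up undefined_variable only when that line actually runs.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical toy model, not Spark's implementation. A branch is
# (condition column, value to compare against, result column).
Branch = Tuple[str, Any, str]


def build_case_when(
    branches: List[Branch], schema: List[str]
) -> Callable[[Dict[str, Any]], Optional[Any]]:
    # "Analysis" step: resolve every referenced column up front, whether or
    # not any row will ever reach the branch that uses it.
    for cond_col, _, result_col in branches:
        for name in (cond_col, result_col):
            if name not in schema:
                raise LookupError(f"cannot resolve column {name!r}")

    def evaluate(row: Dict[str, Any]) -> Optional[Any]:
        # "Execution" step: try branches in order; the first match wins.
        for cond_col, expected, result_col in branches:
            if row[cond_col] == expected:
                return row[result_col]
        return None

    return evaluate


schema = ["some_qty4", "some_qty"]
try:
    build_case_when(
        [("some_qty4", 2, "some_qty"), ("some_qty4", 3, "undefined_col")], schema
    )
except LookupError as err:
    print(err)  # cannot resolve column 'undefined_col'
```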