I have a chained when condition in a Spark DataFrame which looks something like this:
from pyspark.sql.functions import when, lower

df = df.withColumn("some_column", when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), df.some_qty)
                                  .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 != 0), df.some_qty2)
                                  .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 == 0), 0)
                                  .when((lower(df.transaction_id) == "id2") & (df.some_qty3 != 0), df.some_qty3)
                                  .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 != 0), df.some_qty4)
                                  .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 == 0), 0))
In the expression, I'm trying to modify the value of a column based on the values of other columns. I want to understand how the above statement is executed: are all the conditions checked for every row of the DataFrame, and if so, what happens when more than one when condition is true? Or is the order of the chain followed, so that the first condition to be true is used?
CodePudding user response:
Yes, every row is going to be checked, but Spark takes care of optimizing that, so it's not like looping over each cell.
As for the order, an example shows that the first matching condition is the one taken into account:
from pyspark.sql import functions as F
from pyspark.sql.functions import when, lower, lit

df = spark.createDataFrame(
    [
        ('id2','70.07','22.1','0','1'),
        ('id1','0','0','1','3'),
        ('id2','80.7','0','1','3'),
        ('id2','0','0','1','3'),
        ('id1','22.2','0','1','3')
    ],
    ['transaction_id','some_qty','some_qty2','some_qty3','some_qty4']
)\
.withColumn('some_qty', F.col('some_qty').cast('double'))\
.withColumn('some_qty2', F.col('some_qty2').cast('double'))\
.withColumn('some_qty3', F.col('some_qty3').cast('double'))\
.withColumn('some_qty4', F.col('some_qty4').cast('double'))
# both when branches deliberately use the same condition, to see which value wins
df = df.withColumn("some_column", when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('first_true'))
                                  .when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('second_true')))
df.show()
# +--------------+--------+---------+---------+---------+-----------+
# |transaction_id|some_qty|some_qty2|some_qty3|some_qty4|some_column|
# +--------------+--------+---------+---------+---------+-----------+
# |           id2|   70.07|     22.1|      0.0|      1.0|       null|
# |           id1|     0.0|      0.0|      1.0|      3.0|       null|
# |           id2|    80.7|      0.0|      1.0|      3.0|       null|
# |           id2|     0.0|      0.0|      1.0|      3.0|       null|
# |           id1|    22.2|      0.0|      1.0|      3.0| first_true|
# +--------------+--------+---------+---------+---------+-----------+
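A side note on the "optimizing" part: the chained when calls build a single CASE WHEN ... END expression, which is why the branches are evaluated in order and the first match wins for each row. A small sketch (reusing the columns from the example above) that prints the generated expression:
from pyspark.sql.functions import when, lower, lit

expr = (when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('first_true'))
        .when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('second_true')))

# printing the Column shows the underlying CASE WHEN ... END expression,
# which makes the ordered, first-match-wins evaluation explicit
print(expr)
# Column<'CASE WHEN (...) THEN first_true WHEN (...) THEN second_true END'> (approximately)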
CodePudding user response:
You can create a sample DataFrame and check it. As can be seen, when both conditions are true, only the first one's result is returned.
from pyspark.sql import functions as F
df = spark.createDataFrame([(2,), (2,), (0,)], ['col_1'])
df = df.withColumn("some_column", F.when(F.col('col_1') == 2, 'condition1')
                                   .when(F.col('col_1') > 1, 'condition2'))
df.show()
# +-----+-----------+
# |col_1|some_column|
# +-----+-----------+
# |    2| condition1|
# |    2| condition1|
# |    0|       null|
# +-----+-----------+
Since there's no otherwise clause, any row that is not caught by one of the when conditions gets null in the column.
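If you want a default value instead of null, an otherwise can be appended at the end of the chain; a minimal sketch based on the same sample DataFrame:
df = df.withColumn("some_column", F.when(F.col('col_1') == 2, 'condition1')
                                   .when(F.col('col_1') > 1, 'condition2')
                                   .otherwise('no_match'))
df.show()
# the col_1 == 0 row now shows 'no_match' instead of null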
CodePudding user response:
Adding to the above, every branch of the chained .when methods still needs to resolve, even if it is never matched. The code below gives an error, even though the second when is never used:
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [
        ('id2','70.07','22.1','0','1'),
        ('id1','0','0','1','3'),
        ('id2','80.7','0','1','3'),
        ('id2','0','0','1','3'),
        ('id1','22.2','0','1','3')
    ],
    ['transaction_id','some_qty','some_qty2','some_qty3','some_qty4']
)
# the second when references a column that doesn't exist, so Spark cannot resolve it and raises an error
df = df.withColumn("some_column", F.when(F.col('some_qty4') == 2, F.lit(1))
                                   .when(F.col('some_qty4') > 1, F.col('undefined_col')))
df.show()
This differs from Python if...else logic, where the following runs quite happily:
if False:
    print(undefined_variable)
else:
    print("ok")