Pyspark: How to chain Column.when() using a dictionary with reduce?


I'm trying to chain conditions from a dictionary into a series of when() calls using reduce(), so that I can pass the result to DataFrame.withColumn().

For example:

conditions = {
    "0": (col("a") == 1.0) & (col("b") != 1.0),
    "1": (col("c") == 1.0) & (col("d") == 1.0)
}

Using reduce(), I implemented this:

when_stats = reduce(lambda key, value: when(conditions[key], lit(key)), conditions)

and later used it in withColumn():

df2 = df1.withColumn("result", when_stats)

The problem is that it only picks up the first condition, "0", and never chains the second one. Printing when_stats gives me:

Column<'CASE WHEN ((a = 1.0) AND (NOT (b = 1.0))) THEN 0 END'>

When I add a third condition, it throws an error instead:

TypeError: unhashable type: 'Column'

So the question is: how can I loop through the dictionary and build the complete when().when().when() chain? Is there a better solution, especially if I want an otherwise() at the end?

CodePudding user response:

When you use reduce with a dict object, you're actually iterating over the keys of the dict, so the lambda takes two arguments: acc, the accumulator, and key, the key currently being processed. Your version ignores the accumulator and rebuilds when(conditions[key], lit(key)) on every call. With two keys, reduce makes only a single call (the first key becomes the initial accumulator), which is why only condition "0" survives; with three keys, the Column returned by the first call is passed back in as key, and looking up conditions[<Column>] raises TypeError: unhashable type: 'Column'.
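
You can see the same iteration behavior in plain Python (a tiny sketch, independent of Spark):

from functools import reduce

d = {"0": "x", "1": "y", "2": "z"}

# reduce walks the dict's keys in insertion order;
# acc is the running result, key is the current dict key
print(reduce(lambda acc, key: acc + key, d, ""))
# 012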

You can use this instead:

from functools import reduce
from pyspark.sql import functions as F

conditions = {
    "0": (F.col("a") == 1.0) & (F.col("b") != 1.0),
    "1": (F.col("c") == 1.0) & (F.col("d") == 1.0)
}

when_stats = reduce(
    lambda acc, key: acc.when(conditions[key], key),  # acc is the Column built so far
    conditions,
    F  # seed value: the first call is F.when(...), every later call is Column.when(...)
)  # .otherwise("default_value")  -- append this if you want a fallback value

print(when_stats)
# Column<'CASE WHEN ((a = 1.0) AND (NOT (b = 1.0))) THEN 0 WHEN ((c = 1.0) AND (d = 1.0)) THEN 1 END'>
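
For completeness, here is a minimal end-to-end sketch with an otherwise() fallback (the sample DataFrame, the "result" column name, and the "default" fallback value are illustrative assumptions):

from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame(
    [(1.0, 0.0, 0.0, 0.0), (0.0, 0.0, 1.0, 1.0), (0.0, 0.0, 0.0, 0.0)],
    ["a", "b", "c", "d"],
)

conditions = {
    "0": (F.col("a") == 1.0) & (F.col("b") != 1.0),
    "1": (F.col("c") == 1.0) & (F.col("d") == 1.0),
}

# chain every condition, then add a fallback for rows matching none of them
when_stats = reduce(
    lambda acc, key: acc.when(conditions[key], key),
    conditions,
    F
).otherwise("default")

df2 = df1.withColumn("result", when_stats)
df2.show()  # result is "0", "1", "default" for the three rows, respectively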