Chain several WHEN conditions in a scalable way in PySpark

Time:07-10

I have a dictionary (variable pats) holding the arguments for many when calls: the regex conditions as keys and the values to return as values.

from pyspark.sql import functions as F
df = spark.createDataFrame([("ė",), ("2",), ("",), ("@",)], ["col1"])

pats = {
  r"^\d$"          :"digit",
  r"^\p{L}$"       :"letter",
  r"^[\p{P}\p{S}]$":"spec_char",
  r"^$"            :"empty"
}

whens = (
    F.when(F.col("col1").rlike(list(pats.keys())[0]), pats[list(pats.keys())[0]])
     .when(F.col("col1").rlike(list(pats.keys())[1]), pats[list(pats.keys())[1]])
     .when(F.col("col1").rlike(list(pats.keys())[2]), pats[list(pats.keys())[2]])
     .when(F.col("col1").rlike(list(pats.keys())[3]), pats[list(pats.keys())[3]])
     .otherwise(F.col("col1"))
)
df = df.withColumn("col2", whens)

df.show()
# +----+---------+
# |col1|     col2|
# +----+---------+
# |   ė|   letter|
# |   2|    digit|
# |    |    empty|
# |   @|spec_char|
# +----+---------+

I'm looking for a scalable way to chain all the when conditions, so I wouldn't need to write a line for every key.

Answer:

functools.reduce can be used to fold all the conditions into a single chained when expression.

from functools import reduce

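# Seed reduce with a dummy branch that never matches, then add one
# .when(...) per pattern in dict order; unmatched rows keep col1.
whens = reduce(
    lambda acc, p: acc.when(F.col("col1").rlike(p), pats[p]),
    list(pats.keys()),
    F.when(F.lit(False), "1")
).otherwise(F.col("col1"))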
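To see what reduce builds here, the sketch below runs the same fold in plain Python, accumulating a string instead of a PySpark Column. It is only an illustration: no Spark is needed, the fold is seeded with the text "F" rather than a dummy branch, and col1 / F in the output are just text. Because Python 3.7+ dictionaries preserve insertion order, the keys are visited in the order they were written, so the result has the same branch order as the hand-written chain in the question.

```python
from functools import reduce

# Same pattern -> label mapping as in the question. Dicts keep insertion
# order (Python 3.7+), so the fold visits the patterns in this order.
pats = {
    r"^\d$": "digit",
    r"^\p{L}$": "letter",
    r"^[\p{P}\p{S}]$": "spec_char",
    r"^$": "empty",
}

# Fold the keys into a string instead of a Column, only to show the shape
# of the expression reduce builds: one .when(...) per key, in dict order.
chain = reduce(
    lambda acc, p: f'{acc}.when(col1.rlike("{p}"), "{pats[p]}")',
    pats,
    "F",
) + ".otherwise(col1)"

print(chain)
```

Since when returns the first matching branch, the order of the keys in pats decides which label wins when several patterns could match the same value.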

Full code:

from pyspark.sql import functions as F
from functools import reduce
df = spark.createDataFrame([("ė",), ("2",), ("",), ("@",)], ["col1"])

pats = {
  r"^\d$"          :"digit",
  r"^\p{L}$"       :"letter",
  r"^[\p{P}\p{S}]$":"spec_char",
  r"^$"            :"empty"
}

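# Seed reduce with the F module itself: the first step calls F.when(...),
# which returns a Column, and every later step calls Column.when(...).
whens = reduce(
    lambda acc, p: acc.when(F.col("col1").rlike(p), pats[p]),
    pats.keys(),
    F
).otherwise(F.col("col1"))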

df = df.withColumn("col2", whens)

df.show()
# +----+---------+
# |col1|     col2|
# +----+---------+
# |   ė|   letter|
# |   2|    digit|
# |    |    empty|
# |   @|spec_char|
# +----+---------+
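The full code passes the functions module F itself as the initial value of reduce, so no dummy F.lit(False) branch is needed. This works because the first step calls F.when(...), which returns a Column, and every later step calls Column.when(...) on that result. The pure-Python sketch below models only this call sequence with hypothetical stand-ins (an Expr class playing the role of Column, a SimpleNamespace playing the role of F); it is not PySpark code.

```python
from functools import reduce
from types import SimpleNamespace


class Expr:
    """Hypothetical stand-in for a PySpark Column produced by when()."""

    def __init__(self, branches):
        self.branches = branches  # list of (condition, value) pairs, in order

    def when(self, cond, value):
        # Like Column.when: return a new expression with one more branch.
        return Expr(self.branches + [(cond, value)])

    def otherwise(self, default):
        # Close the chain; this stand-in simply returns the collected branches.
        return self.branches + [("otherwise", default)]


# Hypothetical stand-in for pyspark.sql.functions: its when() starts a chain.
F = SimpleNamespace(when=lambda cond, value: Expr([(cond, value)]))

pats = {
    r"^\d$": "digit",
    r"^\p{L}$": "letter",
    r"^[\p{P}\p{S}]$": "spec_char",
    r"^$": "empty",
}

# Step 1: acc is F, so F.when(...) creates the first Expr.
# Later steps: acc is an Expr, so Expr.when(...) appends a branch.
branches = reduce(lambda acc, p: acc.when(p, pats[p]), pats, F).otherwise("col1")

for cond, value in branches:
    print(cond, "->", value)
```

The same duck typing is what makes the real version work: reduce never checks the type of its accumulator, it only needs the accumulator to have a when method at every step.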