Iterate over array column and stop as soon as condition is met in pyspark

Time:06-24

I have this dataframe:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

se_cols = [
  "id",
  "name",
  "associated_countries"
]

person = (
  1,
  "GABRIELE",
  ["ITA", "BEL", "BVI"],
)

company = (
  2,
  "Bad Company",
  ["CYP", "RUS", "ITA"],
)
se_data = [person, company]

se = spark.createDataFrame(se_data).toDF(*se_cols)

What I want is to iterate over each array in the "associated_countries" column and, as soon as I find a country that belongs to a certain set, select that row.

The only approach I could think of was to use F.exists together with a dictionary whose keys are the ISO codes of the target countries I'm looking for:

secrecy = {"CYP":"cyprus", "BVI":"british virgin island"}

def at_least_one_secrecy(x_arr, secrecy_map=secrecy):
  for x in x_arr:
    if secrecy_map.get(x, False) is False:
      continue
    else:
      return True
  return False

se.withColumn("linked_to_secrecy", F.exists("associated_countries", lambda x_arr: at_least_one_secrecy(x_arr=x_arr))).show()

But this returns the error:

TypeError: Column is not iterable

PS: I know this could be solved by adding a column "target_countries" in which every row contains my target ISO codes as an array, and then checking for an overlap (e.g. with arrays_overlap) between "associated_countries" and "target_countries". But I have a huge dataset, so that would be very expensive.

Answer:

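You can use the arrays_overlap function with a literal array that contains your ISO country codes. The literal array is a single constant expression in the query, so you don't need to store a "target_countries" column for every row: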

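from pyspark.sql import functions as F

# Build a literal array column from the dictionary keys ("CYP", "BVI")
secrecy_array = F.array(*[F.lit(x) for x in secrecy.keys()])

# true when the two arrays share at least one non-null element
se.withColumn(
    "linked_to_secrecy",
    F.arrays_overlap(F.col("associated_countries"), secrecy_array)
).show()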

#+---+-----------+--------------------+-----------------+
#| id|       name|associated_countries|linked_to_secrecy|
#+---+-----------+--------------------+-----------------+
#|  1|   GABRIELE|     [ITA, BEL, BVI]|             true|
#|  2|Bad Company|     [CYP, RUS, ITA]|             true|
#+---+-----------+--------------------+-----------------+