I have this dataframe:
se_cols = [
    "id",
    "name",
    "associated_countries",
]
person = (
    1,
    "GABRIELE",
    ["ITA", "BEL", "BVI"],
)
company = (
    2,
    "Bad Company",
    ["CYP", "RUS", "ITA"],
)
se_data = [person, company]
se = spark.createDataFrame(se_data).toDF(*se_cols)
Now, what I want is to iterate over each array in the "associated_countries" column and, as soon as I find one country that belongs to a certain set, select that row.
The way I could think of was to use F.exists with a dictionary whose keys are the ISO codes of the target countries I'm looking for.
secrecy = {"CYP":"cyprus", "BVI":"british virgin island"}
def at_least_one_secrecy(x_arr, secrecy_map=secrecy):
    for x in x_arr:
        if secrecy_map.get(x, False) is False:
            continue
        else:
            return True
    return False
se.withColumn(
    "linked_to_secrecy",
    F.exists("associated_countries", lambda x_arr: at_least_one_secrecy(x_arr=x_arr)),
).show()
But this returns the error:
TypeError: Column is not iterable
PS: I know this could be solved by adding a column "target_countries" where each row would contain my target ISO codes as an array, and then applying some sort of arrays_overlap condition between "associated_countries" and "target_countries". But consider that I have a huge dataset, and that would be very expensive.
CodePudding user response:
You can use the arrays_overlap function with a literal array that contains your ISO country codes:
from pyspark.sql import functions as F
secrecy_array = F.array(*[F.lit(x) for x in secrecy.keys()])
se.withColumn(
"linked_to_secrecy",
F.arrays_overlap(F.col("associated_countries"), secrecy_array)
).show()
#+---+-----------+--------------------+-----------------+
#| id|       name|associated_countries|linked_to_secrecy|
#+---+-----------+--------------------+-----------------+
#|  1|   GABRIELE|     [ITA, BEL, BVI]|             true|
#|  2|Bad Company|     [CYP, RUS, ITA]|             true|
#+---+-----------+--------------------+-----------------+