How can we conditionally explode multiple array columns in Spark SQL?
My input looks like this:
col_1  col2              col3
123    ["id_1","id_2"]   ["tim","steve"]
456    ["id_3","id_4"]   ["jenny"]
I need to transform this such that:
- The array items with the same index are mapped to the same row
- If there is only one entry in col3, it applies to every row
The output should look like:
col_1  col2  col3
123    id_1  tim
123    id_2  steve
456    id_3  jenny
456    id_4  jenny
I have tried various combinations of explode and lateral view, but each attempt has returned either a set of rows that doesn't match the desired output, or an error message.
df = spark.createDataFrame(
    [
        (123, ["id_1", "id_2"], ["tim", "steve"]),
        (456, ["id_3", "id_4"], ["jenny"]),
    ],
    ["col1", "col2", "col3"]
)
df.createOrReplaceTempView("my_table")
spark.sql("""
select
col1,
col1_d,
col2_d
from my_table
lateral view explode(col2) exploded_col as col1_d
lateral view explode(col3) exploded_col_2 as col2_d
""").show()
+----+------+------+
|col1|col1_d|col2_d|
+----+------+------+
| 123|  id_1|   tim|
| 123|  id_1| steve|
| 123|  id_2|   tim|
| 123|  id_2| steve|
| 456|  id_3| jenny|
| 456|  id_4| jenny|
+----+------+------+
CodePudding user response:
Assuming that the data has been validated for the conditions below (a quick sanity check is sketched after the list):
- If len("col3") > 1, then len("col2") == len("col3")
- Otherwise, len("col3") == 1
You can then achieve the transformation with the array_repeat function, repeating the single "col3" value size("col2") times. For the corner case where "col2" is null or empty, simply replace it with an array holding a single empty string:
df = df.withColumn("col3",
F.when((F.size("col2") > 0)&(F.size("col3") == 1), F.array_repeat(F.element_at("col3", 1), F.size("col2"))) \
.otherwise(F.col("col3")) \
)
df = df.withColumn("col2",
F.when((F.col("col2").isNull())|(F.size("col2") == 0), F.array(F.lit(""))) \
.otherwise(F.col("col2")) \
)
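After these two steps, df.show() gives: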
+----+------------+--------------+
|col1|        col2|          col3|
+----+------------+--------------+
| 123|[id_1, id_2]|  [tim, steve]|
| 456|[id_3, id_4]|[jenny, jenny]|
| 789|          []|       [harry]|
+----+------------+--------------+
Then zip "col2" and "col3" together using arrays_zip
and finally explode
:
# Zip the arrays element-wise, explode one row per pair, then unpack the struct fields
df = df.withColumn("col2_col3", F.explode(F.arrays_zip("col2", "col3"))) \
       .select("col1",
               F.col("col2_col3.col2").alias("col2"),
               F.col("col2_col3.col3").alias("col3"))
+----+----+-----+
|col1|col2| col3|
+----+----+-----+
| 123|id_1|  tim|
| 123|id_2|steve|
| 456|id_3|jenny|
| 456|id_4|jenny|
| 789|    |harry|
+----+----+-----+
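Since the question asks about Spark SQL specifically, the same logic can also be written as a single query over the my_table temp view registered in the question. This is only a sketch of the equivalent query (it assumes Spark 2.4+ for arrays_zip, array_repeat and element_at):

spark.sql("""
    select
        col1,
        zipped.col2 as col2,
        zipped.col3 as col3
    from (
        select
            col1,
            -- replace a null or empty col2 with a single empty string
            case when col2 is null or size(col2) = 0
                 then array('') else col2 end as col2,
            -- repeat a single col3 entry size(col2) times
            case when size(col2) > 0 and size(col3) = 1
                 then array_repeat(element_at(col3, 1), size(col2))
                 else col3 end as col3
        from my_table
    ) prepared
    lateral view explode(arrays_zip(col2, col3)) zipped_tbl as zipped
""").show()

On the question's two input rows this should return the same four rows as the DataFrame version above.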
Dataset used:
df = spark.createDataFrame(
    [
        (123, ["id_1", "id_2"], ["tim", "steve"]),
        (456, ["id_3", "id_4"], ["jenny"]),
        (789, None, ["harry"]),
    ],
    ["col1", "col2", "col3"]
)
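which shows as: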
+----+------------+------------+
|col1|        col2|        col3|
+----+------------+------------+
| 123|[id_1, id_2]|[tim, steve]|
| 456|[id_3, id_4]|     [jenny]|
| 789|        null|     [harry]|
+----+------------+------------+