How can I explode multiple array columns with variable lengths and potential nulls?
My input data looks like this:
+----+------------+--------------+--------------------+
|col1|        col2|          col3|                col4|
+----+------------+--------------+--------------------+
|   1|[id_1, id_2]|  [tim, steve]|       [apple, pear]|
|   2|[id_3, id_4]|       [jenny]|           [avocado]|
|   3|        null|[tommy, megan]| [apple, strawberry]|
|   4|        null|          null|[banana, strawberry]|
+----+------------+--------------+--------------------+
I need to explode this such that:
- Array items with the same index are mapped to the same row
- If a column's array has only one entry, that entry applies to every exploded row
- If a column's array is null, null fills that column in every exploded row
My output should look like this:
+----+----+-----+----------+
|col1|col2|col3 |col4      |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|jenny|avocado   |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+
I have been able to achieve this using the following code, but I feel like there must be a more straightforward approach:
df = spark.createDataFrame(
    [
        (1, ["id_1", "id_2"], ["tim", "steve"], ["apple", "pear"]),
        (2, ["id_3", "id_4"], ["jenny"], ["avocado"]),
        (3, None, ["tommy", "megan"], ["apple", "strawberry"]),
        (4, None, None, ["banana", "strawberry"]),
    ],
    ["col1", "col2", "col3", "col4"],
)
df.createOrReplaceTempView("my_table")
spark.sql("""
with cte as (
    -- the longest array in each row determines how many rows it explodes into
    select
        col1,
        col2,
        col3,
        col4,
        greatest(size(col2), size(col3), size(col4)) as max_array_len
    from my_table
), arrays_extended as (
    -- pad every array out to max_array_len: null arrays become arrays of
    -- nulls, and single-element arrays repeat their only element
    select
        col1,
        case
            when col2 is null then array_repeat(null, max_array_len)
            else col2
        end as col2,
        case
            when size(col3) = 1 then array_repeat(col3[0], max_array_len)
            when col3 is null then array_repeat(null, max_array_len)
            else col3
        end as col3,
        case
            when size(col4) = 1 then array_repeat(col4[0], max_array_len)
            when col4 is null then array_repeat(null, max_array_len)
            else col4
        end as col4
    from cte
), arrays_zipped as (
    -- zip the padded arrays element-wise, then explode one struct per index
    select *, explode(arrays_zip(col2, col3, col4)) as zipped
    from arrays_extended
)
select
    col1,
    zipped.col2,
    zipped.col3,
    zipped.col4
from arrays_zipped
""").show(truncate=False)
CodePudding user response:
You can use inline_outer in conjunction with selectExpr, together with coalesce to turn each null column into an empty array; arrays_zip then pads the shorter arrays with nulls, so rows whose arrays are null or of different lengths still explode cleanly.
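You can see the padding behaviour in isolation with a quick standalone query (illustrative only, not part of the answer's code):
# Illustrative check: once a null column is coalesced to an empty array,
# arrays_zip pads it with nulls up to the length of the longest array.
spark.sql("""
    SELECT inline_outer(arrays_zip(
        coalesce(cast(null as array<string>), array()),
        array('a', 'b')
    ))
""").show()
# -> two rows: (null, a) and (null, b)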
Data Preparation
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

inp_data = [
    (1, ['id_1', 'id_2'], ['tim', 'steve'], ['apple', 'pear']),
    (2, ['id_3', 'id_4'], ['jenny'], ['avocado']),
    (3, None, ['tommy', 'megan'], ['apple', 'strawberry']),
    (4, None, None, ['banana', 'strawberry']),
]

inp_schema = StructType([
    StructField('col1', IntegerType(), True),
    StructField('col2', ArrayType(StringType(), True)),
    StructField('col3', ArrayType(StringType(), True)),
    StructField('col4', ArrayType(StringType(), True)),
])

sparkDF = spark.createDataFrame(data=inp_data, schema=inp_schema)

sparkDF.show(truncate=False)
+----+------------+--------------+--------------------+
|col1|col2        |col3          |col4                |
+----+------------+--------------+--------------------+
|1   |[id_1, id_2]|[tim, steve]  |[apple, pear]       |
|2   |[id_3, id_4]|[jenny]       |[avocado]           |
|3   |null        |[tommy, megan]|[apple, strawberry] |
|4   |null        |null          |[banana, strawberry]|
+----+------------+--------------+--------------------+
Inline Outer
sparkDF.selectExpr(
    "col1",
    """inline_outer(arrays_zip(
        coalesce(col2, array()),
        coalesce(col3, array()),
        coalesce(col4, array())
    ))""",
).show(truncate=False)
+----+----+-----+----------+
|col1|0   |1    |2         |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|null |null      |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+
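The zipped struct fields come out as 0, 1 and 2 because the coalesce expressions are unnamed, and the single-element arrays are padded with nulls rather than repeated (see col1 = 2 above), so this does not exactly match the output requested in the question. If you want the original column names back, one option (a sketch, not part of the answer's code) is to coalesce in a separate step first, so that arrays_zip picks up the plain column names as field names:
# Hypothetical variant: coalesce in a first select so that arrays_zip
# keeps col2/col3/col4 as the struct field names.
sparkDF.selectExpr(
    "col1",
    "coalesce(col2, array()) as col2",
    "coalesce(col3, array()) as col3",
    "coalesce(col4, array()) as col4",
).selectExpr(
    "col1",
    "inline_outer(arrays_zip(col2, col3, col4))",
).show(truncate=False)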
CodePudding user response:
You can use a UDF:
from pyspark.sql import functions as F, types as T

cols_of_interest = [c for c in df.columns if c != 'col1']

@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
def get_sequences(*cols):
    """Equivalent of arrays_zip, but handles arrays of different lengths.

    Arrays shorter than the maximum length have their last element repeated;
    null columns become arrays of Nones.
    """
    # Length of the longest non-null array in the row
    max_len = max(map(len, filter(lambda x: x, cols)))
    return list(zip(*[
        # Build a list of length max_len for each column: repeat the last
        # element if the column is shorter, or use all Nones if it is null.
        [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len
        for c in cols
    ]))

df2 = (
    df
    .withColumn('temp', F.explode(get_sequences(*cols_of_interest)))
    .select('col1',
            *[F.col('temp').getItem(i).alias(c) for i, c in enumerate(cols_of_interest)])
)
df2 is the following DataFrame:
+----+----+-----+----------+
|col1|col2| col3|      col4|
+----+----+-----+----------+
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
+----+----+-----+----------+
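If you prefer to keep everything on the JVM side, the same padding logic can be sketched with Spark's built-in higher-order functions (Spark 2.4+). This is an assumption-level sketch rather than either answer's code: transform over a sequence clamps each index to the array's last position (so single entries repeat), and element_at on a null column yields null:
from pyspark.sql import functions as F

cols_of_interest = ['col2', 'col3', 'col4']

# Longest array in the row; greatest() skips nulls, so null columns are ignored.
max_len = "greatest(size(col2), size(col3), size(col4))"

padded = [
    # Clamp index i to the array's last position (repeats single elements);
    # element_at on a null array returns null, covering the null columns.
    F.expr(
        f"transform(sequence(1, {max_len}), "
        f"i -> element_at({c}, least(i, size({c}))))"
    ).alias(c)
    for c in cols_of_interest
]

(df.select('col1', *padded)
 .selectExpr('col1', 'inline_outer(arrays_zip(col2, col3, col4))')
 .show(truncate=False))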