Explode multiple array columns with variable lengths

Time:10-28

How can I explode multiple array columns with variable lengths and potential nulls?

My input data looks like this:

 ---- ------------ -------------- -------------------- 
|col1|        col2|          col3|                col4|
 ---- ------------ -------------- -------------------- 
|   1|[id_1, id_2]|  [tim, steve]|       [apple, pear]|
|   2|[id_3, id_4]|       [jenny]|           [avocado]|
|   3|        null|[tommy, megan]| [apple, strawberry]|
|   4|        null|          null|[banana, strawberry]|
 ---- ------------ -------------- -------------------- 

I need to explode this such that:

  1. Array items with the same index are mapped to the same row
  2. If an array has only 1 entry, that entry is repeated on every exploded row
  3. If an array is null, null is used on every exploded row
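
For reference, the three rules can be written down for a single row in plain Python, without Spark. This is only a sketch of the intended semantics; the function name explode_row is hypothetical, and padding arrays that are longer than one entry but shorter than the maximum with None is an assumption, since the rules above do not cover that case.

```python
def explode_row(*arrays):
    """Apply the three rules above to the array columns of one row.

    Each argument is a list or None. Returns one tuple per exploded row.
    """
    # Length of the longest non-null array (0 if every array is null)
    n = max((len(a) for a in arrays if a is not None), default=0)

    def pad(a):
        if a is None:
            # Rule 3: a null array yields null on every exploded row
            return [None] * n
        if len(a) == 1:
            # Rule 2: a single entry is repeated on every exploded row
            return list(a) * n
        # Assumption: other short arrays are padded with None
        return list(a) + [None] * (n - len(a))

    # Rule 1: items with the same index end up in the same row
    return list(zip(*(pad(a) for a in arrays)))


rows = explode_row(["id_3", "id_4"], ["jenny"], ["avocado"])
print(rows)
```

For the row with col1 = 2 this gives two tuples, both carrying jenny and avocado, which matches the expected output below.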

My output should look like this:

 ---- ---- ----- ---------- 
|col1|col2|col3 |col4      |
 ---- ---- ----- ---------- 
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|jenny|avocado   |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
 ---- ---- ----- ---------- 

I have been able to achieve this using the following code, but I feel like there must be a more straightforward approach:

df = spark.createDataFrame(
    [
        (1, ["id_1", "id_2"], ["tim", "steve"], ["apple", "pear"]),
        (2, ["id_3", "id_4"], ["jenny"], ["avocado"]),
        (3, None, ["tommy", "megan"], ["apple", "strawberry"]),
        (4, None, None, ["banana", "strawberry"])
    ],
    ["col1", "col2", "col3", "col4"]
)

df.createOrReplaceTempView("my_table")


spark.sql("""
with cte as (
  SELECT
  col1,
  col2,
  col3,
  col4,
  greatest(size(col2), size(col3), size(col4)) as max_array_len
  FROM my_table
), arrays_extended as (
select 
col1,
case 
  when col2 is null then array_repeat(null, max_array_len) 
  else col2 
end as col2, 
case
  when size(col3) = 1 then array_repeat(col3[0], max_array_len)
  when col3 is null then array_repeat(null, max_array_len)
  else col3
end as col3,
case
  when size(col4) = 1 then array_repeat(col4[0], max_array_len)
  when col4 is null then array_repeat(null, max_array_len)
  else col4
end as col4
from cte), 
arrays_zipped as (
select *, explode(arrays_zip(col2, col3, col4)) as zipped
from arrays_extended
)
select 
  col1,
  zipped.col2,
  zipped.col3,
  zipped.col4
from arrays_zipped
""").show(truncate=False)


CodePudding user response:

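You can use inline_outer in conjunction with selectExpr, wrapping each array in coalesce so that a null array becomes an empty one. arrays_zip then pads the shorter arrays with null, so a single-entry array is not repeated (see the second col1 = 2 row in the output below).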

Data Preparation

from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

inp_data = [
    (1, ['id_1', 'id_2'], ['tim', 'steve'], ['apple', 'pear']),
    (2, ['id_3', 'id_4'], ['jenny'], ['avocado']),
    (3, None, ['tommy', 'megan'], ['apple', 'strawberry']),
    (4, None, None, ['banana', 'strawberry'])
]

inp_schema = StructType([
    StructField('col1', IntegerType(), True),
    StructField('col2', ArrayType(StringType(), True)),
    StructField('col3', ArrayType(StringType(), True)),
    StructField('col4', ArrayType(StringType(), True))
])

# `spark` is assumed to be the active SparkSession, as in the question
sparkDF = spark.createDataFrame(data=inp_data, schema=inp_schema)


sparkDF.show(truncate=False)

 ---- ------------ -------------- -------------------- 
|col1|col2        |col3          |col4                |
 ---- ------------ -------------- -------------------- 
|1   |[id_1, id_2]|[tim, steve]  |[apple, pear]       |
|2   |[id_3, id_4]|[jenny]       |[avocado]           |
|3   |null        |[tommy, megan]|[apple, strawberry] |
|4   |null        |null          |[banana, strawberry]|
 ---- ------------ -------------- -------------------- 

Inline Outer

sparkDF.selectExpr("col1"
                   ,"""inline_outer(arrays_zip(
                                       coalesce(col2,array()),
                                       coalesce(col3,array()),
                                       coalesce(col4,array())
                                    )
                )""").show(truncate=False)

 ---- ---- ----- ---------- 
|col1|0   |1    |2         |
 ---- ---- ----- ---------- 
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|null |null      |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
 ---- ---- ----- ---------- 
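
To also repeat single-entry arrays (requirement 2 in the question), each array could be wrapped in a CASE expression before zipping. Below is a minimal sketch that only builds the SQL string; the helper names broadcast_expr and build_inline_expr are hypothetical, and it assumes Spark 2.4 or later (for array_repeat and arrays_zip) and at least two array columns, because greatest needs two arguments.

```python
def broadcast_expr(col, n_expr):
    """SQL for `col` where a single-entry array is repeated `n_expr` times
    and a null array is replaced by an empty one."""
    return (
        f"CASE WHEN size({col}) = 1 THEN array_repeat({col}[0], {n_expr}) "
        f"ELSE coalesce({col}, array()) END"
    )


def build_inline_expr(cols):
    """Build an inline_outer(arrays_zip(...)) expression over the given array columns."""
    if len(cols) < 2:
        raise ValueError("need at least two array columns")
    # Longest array length in the row, computed the same way as in the question
    n_expr = "greatest(" + ", ".join(f"size({c})" for c in cols) + ")"
    padded = ", ".join(broadcast_expr(c, n_expr) for c in cols)
    return f"inline_outer(arrays_zip({padded}))"


expr = build_inline_expr(["col2", "col3", "col4"])
print(expr)

# Intended usage with the DataFrame from this answer:
# sparkDF.selectExpr("col1", expr).show(truncate=False)
```

As in the output above, the resulting columns would be named 0, 1 and 2, because arrays_zip is given expressions rather than plain columns; they can be renamed afterwards, for example with toDF.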

CodePudding user response:

You can use a UDF:

from pyspark.sql import functions as F, types as T

cols_of_interest = [c for c in df.columns if c != 'col1']

@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
def get_sequences(*cols):
    """Equivalent of arrays_zip, but handles arrays of different lengths.
       For an array shorter than the maximum length, the last element is repeated.
    """

    # Get the length of the longest array in the row (0 if every array is null or empty)
    max_len = max((len(c) for c in cols if c), default=0)

    return list(zip(*[
        # Create a list for each column with a length equal to max_len.
        # If the original column has fewer elements than needed, repeat the last one.
        # Null (or empty) columns are filled with a list of Nones of length max_len.
        [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len for c in cols
    ]))

df2 = (
    df
    .withColumn('temp', F.explode(get_sequences(*cols_of_interest)))
    .select('col1', 
            *[F.col('temp').getItem(i).alias(c) for i, c in enumerate(cols_of_interest)])
)

df2 is the following DataFrame:

 ---- ---- ----- ---------- 
|col1|col2| col3|      col4|
 ---- ---- ----- ---------- 
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
 ---- ---- ----- ---------- 