Pyspark - array column: combine all rows having at least one same value


I have a spark dataframe df that looks like this:

+------------+
|      values|
+------------+
|      [a, b]|
|[a, b, c, d]|
|   [a, e, f]|
|   [w, x, y]|
|      [x, z]|
+------------+

And I want to be able to get another dataframe that looks like this:

+-------------------+
|             values|
+-------------------+
| [a, b, c, d, e, f]|
|       [w, x, y, z]|
+-------------------+

In other words, I want to merge all rows that share at least one common value, and the merge is transitive: [a, b] and [a, e, f] join through a, and [x, z] joins [w, x, y] through x.

I'm aware that this thread exists: Spark get all rows with same values in array in column, but I don't think it answers my question.

I also saw this one: Pyspark merge dataframe rows one array is contained in another. I tried copying the code from its accepted answer, but unfortunately I'm still not getting my desired output:

from pyspark.sql.functions import expr

# find every row whose array is a strict subset of another row's array
df_sub = df.alias('d1').join(
    df.alias('d2'),
    expr('size(array_except(d2.values, d1.values)) == 0 AND size(d2.values) < size(d1.values)')
).select('d2.values').distinct()

# drop the subset rows, sort the arrays, and deduplicate
df.join(df_sub, on=['values'], how='left_anti') \
  .withColumn('values', expr('sort_array(values)')) \
  .distinct() \
  .show()

Output:

+------------+
|      values|
+------------+
|   [a, e, f]|
|   [w, x, y]|
|[a, b, c, d]|
|      [x, z]|
+------------+

This is probably because that answer solves a different problem: it only removes arrays that are complete subsets of another array, and it never merges arrays that merely overlap. How can I solve this?

CodePudding user response:

Given an input dataframe (say, data_sdf) as follows:

# +------------+---+
# |vals        |id |
# +------------+---+
# |[a, b]      |1  |
# |[a, b, c, d]|2  |
# |[a, e, f]   |3  |
# |[k, l, m]   |4  |
# |[w, x, y]   |5  |
# |[x, z]      |6  |
# +------------+---+

Notice the id field that I added; it defines the data's sort order. I also added a new row (see id = 4) that overlaps with no other row and should therefore come through unmerged.
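
If you want to reproduce this, here is a minimal sketch of building that dataframe (the spark session variable and the import aliases func and wd, which the code below relies on, are my assumptions; any way of attaching a sequential id works):

from pyspark.sql import SparkSession, functions as func
from pyspark.sql.window import Window as wd

spark = SparkSession.builder.getOrCreate()

# sample data with an explicit id column that defines the sort order
data_sdf = spark.createDataFrame(
    [(['a', 'b'], 1), (['a', 'b', 'c', 'd'], 2), (['a', 'e', 'f'], 3),
     (['k', 'l', 'm'], 4), (['w', 'x', 'y'], 5), (['x', 'z'], 6)],
    ['vals', 'id']
)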

data_sdf. \
    withColumn('lead_vals', func.lead('vals').over(wd.orderBy('id'))). \
    withColumn('vals_nonoverlap_flg', 
               # 1 if this row shares no element with the next row, else 0
               func.abs(func.arrays_overlap('vals', 'lead_vals').cast('int') - 1)
               ). \
    withColumn('blah', 
               # running sum of the flag -> a group id that increments at every break
               func.sum('vals_nonoverlap_flg').over(wd.orderBy('id'))
               ). \
    withColumn('fnl_val_to_merge', 
               # the first row contributes its own values plus the next row's;
               # every other row contributes only the next row's values
               func.when(func.row_number().over(wd.orderBy('id')) == 1, 
                         func.array_union('vals', 'lead_vals')
                         ).
               otherwise(func.col('lead_vals'))
               ). \
    groupBy('blah'). \
    agg(func.array_distinct(func.flatten(func.collect_list('fnl_val_to_merge'))).alias('merged_val')). \
    drop('blah'). \
    show(truncate=False)

# +------------------+
# |merged_val        |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
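
For intuition, here is how the intermediate columns work out on the sample data (worked out by hand from the logic above, not actual Spark output):

# +------------+------------+-------------------+----+----------------+
# |vals        |lead_vals   |vals_nonoverlap_flg|blah|fnl_val_to_merge|
# +------------+------------+-------------------+----+----------------+
# |[a, b]      |[a, b, c, d]|0                  |0   |[a, b, c, d]    |
# |[a, b, c, d]|[a, e, f]   |0                  |0   |[a, e, f]       |
# |[a, e, f]   |[k, l, m]   |1                  |1   |[k, l, m]       |
# |[k, l, m]   |[w, x, y]   |1                  |2   |[w, x, y]       |
# |[w, x, y]   |[x, z]      |0                  |2   |[x, z]          |
# |[x, z]      |null        |null               |2   |null            |
# +------------+------------+-------------------+----+----------------+

Grouping by blah and flattening then yields the three merged arrays (the null flag on the last row is harmless: sum ignores nulls, and collect_list drops the null fnl_val_to_merge).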

P.S. On real data, add a partitionBy() to every window used above; otherwise Spark pulls all rows into a single partition.
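
For example, if a column were available that splits the rows into independent chunks (grp below is hypothetical; the partitioning column must not separate rows that should be merged), each window would become:

# 'grp' is a hypothetical column that independently partitions the rows
w = wd.partitionBy('grp').orderBy('id')

data_sdf.withColumn('lead_vals', func.lead('vals').over(w))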


You could also use the aggregate higher-order function to fold all the arrays into a list of merged groups in one pass.

data_sdf. \
    groupBy(func.lit('gk').alias('gk')). \
    agg(func.collect_list('vals').alias('vals_arr')). \
    withColumn('blah',
               # fold the remaining arrays into a list of merged groups:
               # if the incoming array (y) overlaps the last group, union it in;
               # otherwise append it as a new group
               func.expr('''
                         aggregate(slice(vals_arr, 2, size(vals_arr)), 
                                   array(vals_arr[0]),
                                   (x, y) -> if(arrays_overlap(x[size(x)-1], y), 
                                                array_union(slice(x, 1, size(x)-1), array(array_union(x[size(x)-1], y))), 
                                                array_union(x, array(y))
                                                )
                                   )
                         ''')
               ). \
    selectExpr('explode(blah) as merged_vals'). \
    show(truncate=False)

# +------------------+
# |merged_vals       |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
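
To see what that fold is doing, here is the same merge rule written as a plain Python reduce (a sketch for intuition only; it assumes the arrays arrive already sorted by id):

from functools import reduce

def merge_step(groups, nxt):
    # same rule as the SQL lambda: extend the last group if it overlaps nxt,
    # otherwise start a new group
    if set(groups[-1]) & set(nxt):
        return groups[:-1] + [sorted(set(groups[-1]) | set(nxt))]
    return groups + [nxt]

arrays = [['a', 'b'], ['a', 'b', 'c', 'd'], ['a', 'e', 'f'],
          ['k', 'l', 'm'], ['w', 'x', 'y'], ['x', 'z']]
print(reduce(merge_step, arrays[1:], [arrays[0]]))
# [['a', 'b', 'c', 'd', 'e', 'f'], ['k', 'l', 'm'], ['w', 'x', 'y', 'z']]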