I have a Spark dataframe df that looks like this:
+------------+
|      values|
+------------+
|      [a, b]|
|[a, b, c, d]|
|   [a, e, f]|
|   [w, x, y]|
|      [x, z]|
+------------+
And I want to be able to get another dataframe that looks like this:
+------------------+
|            values|
+------------------+
|[a, b, c, d, e, f]|
|      [w, x, y, z]|
+------------------+
In other words, I'm merging together all rows that share at least one common value.
I'm aware that this thread exists: "Spark get all rows with same values in array in column", but I don't think it gives the answer I'm looking for.
I also saw this one: "Pyspark merge dataframe rows one array is contained in another", so I tried copying the code from its accepted answer, but unfortunately I'm still not getting my desired output:
from pyspark.sql.functions import expr

df_sub = df.alias('d1').join(
    df.alias('d2'),
    expr('size(array_except(d2.values, d1.values)) == 0 AND size(d2.values) < size(d1.values)')
).select('d2.values').distinct()

df.join(df_sub, on=['values'], how='left_anti') \
    .withColumn('values', expr('sort_array(values)')) \
    .distinct() \
    .show()
Output:
+------------+
|      values|
+------------+
|   [a, e, f]|
|   [w, x, y]|
|[a, b, c, d]|
|      [x, z]|
+------------+
This is probably because the original problem assumes a bound on the maximum length of the arrays, which mine doesn't have. How can I solve this?
CodePudding user response:
Given an input dataframe (say, data_sdf) as follows:
# +------------+---+
# |vals        |id |
# +------------+---+
# |[a, b]      |1  |
# |[a, b, c, d]|2  |
# |[a, e, f]   |3  |
# |[k, l, m]   |4  |
# |[w, x, y]   |5  |
# |[x, z]      |6  |
# +------------+---+
Notice the id field that I added; it encodes the data's sort order. I also added a new row (see id = 4) that should not be merged with any other.
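For reference, here's a minimal sketch of how this input could be built (my assumption, not part of the original answer; it presumes an active SparkSession named spark):

# setup sketch (assumption): builds the example input with an existing SparkSession `spark`
data_sdf = spark.createDataFrame(
    [(['a', 'b'], 1),
     (['a', 'b', 'c', 'd'], 2),
     (['a', 'e', 'f'], 3),
     (['k', 'l', 'm'], 4),
     (['w', 'x', 'y'], 5),
     (['x', 'z'], 6)],
    ['vals', 'id']
)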
from pyspark.sql import functions as func
from pyspark.sql import Window as wd

data_sdf. \
    withColumn('lead_vals', func.lead('vals').over(wd.orderBy('id'))). \
    withColumn('vals_nonoverlap_flg',
               # 1 if the current and next arrays share no common element, else 0
               func.abs(func.arrays_overlap('vals', 'lead_vals').cast('int') - 1)
               ). \
    withColumn('blah',
               # running sum of the flag acts as a group id for runs of overlapping rows
               func.sum('vals_nonoverlap_flg').over(wd.orderBy('id'))
               ). \
    withColumn('fnl_val_to_merge',
               # each row contributes the next row's array; the first row also
               # contributes its own values via the union
               func.when(func.row_number().over(wd.orderBy('id')) == 1,
                         func.array_union('vals', 'lead_vals')
                         ).
               otherwise(func.col('lead_vals'))
               ). \
    groupBy('blah'). \
    agg(func.array_distinct(func.flatten(func.collect_list('fnl_val_to_merge'))).alias('merged_val')). \
    drop('blah'). \
    show(truncate=False)
# +------------------+
# |merged_val        |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
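For intuition, here is the intermediate state just before the groupBy, reconstructed by hand for the example data (note the last row: its lead is null, so its flag is null too and the running sum simply carries the previous group id forward):

# +------------+---+------------+-------------------+----+----------------+
# |vals        |id |lead_vals   |vals_nonoverlap_flg|blah|fnl_val_to_merge|
# +------------+---+------------+-------------------+----+----------------+
# |[a, b]      |1  |[a, b, c, d]|0                  |0   |[a, b, c, d]    |
# |[a, b, c, d]|2  |[a, e, f]   |0                  |0   |[a, e, f]       |
# |[a, e, f]   |3  |[k, l, m]   |1                  |1   |[k, l, m]       |
# |[k, l, m]   |4  |[w, x, y]   |1                  |2   |[w, x, y]       |
# |[w, x, y]   |5  |[x, z]      |0                  |2   |[x, z]          |
# |[x, z]      |6  |null        |null               |2   |null            |
# +------------+---+------------+-------------------+----+----------------+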
P.S. Add a partitionBy() within the window wherever it is used, as sketched below.
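A hypothetical sketch (part_col is a stand-in, not a column from the example):

# hypothetical: `part_col` stands in for whatever column partitions your data
wd.partitionBy('part_col').orderBy('id')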
You could also use the aggregate higher-order function.
import pyspark.sql.functions as func

data_sdf. \
    groupBy(func.lit('gk').alias('gk')  # constant key collapses everything into one group
            ). \
    agg(func.collect_list('vals').alias('vals_arr')). \
    withColumn('blah',
               # fold over the collected arrays: merge the current array into the
               # last group if they overlap, otherwise start a new group
               func.expr('''
                   aggregate(slice(vals_arr, 2, size(vals_arr)),
                             array(vals_arr[0]),
                             (x, y) -> if(arrays_overlap(x[size(x)-1], y),
                                          array_union(slice(x, 1, size(x)-1), array(array_union(x[size(x)-1], y))),
                                          array_union(x, array(y))
                                          )
                             )
               ''')
               ). \
    selectExpr('explode(blah) as merged_vals'). \
    show(truncate=False)
# +------------------+
# |merged_vals       |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
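To see how the fold builds the result, here is a hand trace of the accumulator x after each element y of the example data:

# start           : x = [[a, b]]
# y = [a, b, c, d]: overlaps last group -> x = [[a, b, c, d]]
# y = [a, e, f]   : overlaps last group -> x = [[a, b, c, d, e, f]]
# y = [k, l, m]   : no overlap, append  -> x = [[a, b, c, d, e, f], [k, l, m]]
# y = [w, x, y]   : no overlap, append  -> x = [[a, b, c, d, e, f], [k, l, m], [w, x, y]]
# y = [x, z]      : overlaps last group -> x = [[a, b, c, d, e, f], [k, l, m], [w, x, y, z]]

The explode(blah) then turns each merged group into its own output row.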