Find next minimum value not used in df and assign it to current row


I have a df in this format:

+---+---+---+
| id|  A|  B|
+---+---+---+
| a1| 10| 15|
| a1| 10| 18|
| a1| 12| 15|
| a1| 12| 18|
| a1| 18| 18|
| a1| 18| 24|
| a1| 18| 27|
| a1| 55| 57|
| a1| 55| 59|
| a1| 55| 61|
| a1| 67| 75|
+---+---+---+

And I would like the output to be:

+---+---+---+
| id|  A|  B|
+---+---+---+
| a1| 10| 15|
| a1| 12| 18|
| a1| 18| 24|
| a1| 55| 57|
| a1| 67| 75|
+---+---+---+

The final df should pair each value in column A with the minimum value in column B that has not already been used by the row above. The first row keeps the minimum value 15 in column B for 10 in column A; each following row then takes the smallest B value not equal to the one just assigned (here, 15 cannot be paired with 12 in column A since it was already assigned to 10 in the row above, so 18 in column B gets assigned to 12 in column A). I have tried obtaining the distinct values of A and of B per id, assigning row_number(), and then joining, but row_number() skews the data and the join loses records, such as 55 and 67 in column A and 57 and 75 in column B:

from pyspark.sql import Window
from pyspark.sql.functions import row_number

dfa = df.select('id', 'A').distinct()
dfa = dfa.withColumn('rnk', row_number().over(Window.partitionBy('id').orderBy('A')))

dfb = df.select('id', 'B').distinct()
dfb = dfb.withColumnRenamed('id', 'id1')
dfb = dfb.withColumn('rnk1', row_number().over(Window.partitionBy('id1').orderBy('B')))

a_join_b = dfa.join(dfb, 
                    ((dfa['id'] == dfb['id1']) 
                     & (dfa['rnk'] == dfb['rnk1']) 
                     & (dfb['B'] > dfa['A'])), 
                    'inner').drop('id1', 'rnk1')

+---+---+---+---+
| id|  A|  B|rnk|
+---+---+---+---+
| a1| 10| 15|  1|
| a1| 12| 18|  2|
| a1| 18| 24|  3|
+---+---+---+---+
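To make the intended logic concrete, here is a plain-Python sketch of the row-by-row assignment (a hypothetical helper, not Spark code; the grouped input is derived from the sample df):

```python
def assign_next_min(groups):
    """For each (A, list-of-B) pair, pick the smallest B that is not
    the value assigned to the previous row."""
    result = []
    prev = None
    for a, bs in groups:
        chosen = min(b for b in bs if b != prev)
        result.append((a, chosen))
        prev = chosen
    return result

pairs = assign_next_min([
    (10, [15, 18]),
    (12, [15, 18]),
    (18, [18, 24, 27]),
    (55, [57, 59, 61]),
    (67, [75]),
])
# pairs == [(10, 15), (12, 18), (18, 24), (55, 57), (67, 75)]
```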

CodePudding user response:

I was able to do the operation using arrays and the aggregate function.

import pyspark.sql.functions as func

data_sdf. \
    groupBy('id', 'c1'). \
    agg(func.collect_list('c2').alias('c2_arr')). \
    withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
    groupBy('id'). \
    agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
    withColumn('c1_c2arr_struct_arr_new',
               func.expr('''
                         aggregate(slice(c1_c2arr_struct_arr, 2, size(c1_c2arr_struct_arr)),
                                   array(struct(c1_c2arr_struct_arr[0].c1 as c1, array_min(c1_c2arr_struct_arr[0].c2_arr) as c2_arr_custom_min)),
                                   (x, y) -> array_union(x, 
                                                         array(struct(y.c1 as c1, 
                                                                      array_min(array_remove(y.c2_arr, element_at(x, -1).c2_arr_custom_min)) as c2_arr_custom_min
                                                                      )
                                                               )
                                                         )
                                   )
                         ''')
               ). \
    selectExpr('id', 'inline(c1_c2arr_struct_arr_new)'). \
    show(truncate=False)

# +---+---+-----------------+
# |id |c1 |c2_arr_custom_min|
# +---+---+-----------------+
# |a1 |10 |15               |
# |a1 |12 |18               |
# |a1 |18 |24               |
# |a1 |55 |57               |
# |a1 |67 |75               |
# +---+---+-----------------+

Explanation

Given the following input data

+---+---+---+
| id| c1| c2|
+---+---+---+
| a1| 10| 15|
| a1| 10| 18|
| a1| 12| 15|
| a1| 12| 18|
| a1| 18| 18|
| a1| 18| 24|
| a1| 18| 27|
| a1| 55| 57|
| a1| 55| 59|
| a1| 55| 61|
| a1| 67| 75|
+---+---+---+

You create a list of the values you want the min from within each id and c1 partition (c1 and c2 here correspond to A and B in the question). Then, create a struct from the c1 column and this list field.

data_sdf. \
    groupBy('id', 'c1'). \
    agg(func.collect_list('c2').alias('c2_arr')). \
    withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
    show(truncate=False)

# +---+---+------------+------------------+
# |id |c1 |c2_arr      |c1_c2arr_struct   |
# +---+---+------------+------------------+
# |a1 |18 |[18, 24, 27]|{18, [18, 24, 27]}|
# |a1 |55 |[57, 59, 61]|{55, [57, 59, 61]}|
# |a1 |67 |[75]        |{67, [75]}        |
# |a1 |10 |[15, 18]    |{10, [15, 18]}    |
# |a1 |12 |[15, 18]    |{12, [15, 18]}    |
# +---+---+------------+------------------+
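In plain Python terms, this grouping step is roughly equivalent to the following (a stdlib-only sketch using the sample rows, not Spark code):

```python
from collections import defaultdict

rows = [('a1', 10, 15), ('a1', 10, 18), ('a1', 12, 15), ('a1', 12, 18),
        ('a1', 18, 18), ('a1', 18, 24), ('a1', 18, 27),
        ('a1', 55, 57), ('a1', 55, 59), ('a1', 55, 61), ('a1', 67, 75)]

# mirror groupBy('id', 'c1') + collect_list('c2')
c2_lists = defaultdict(list)
for _id, c1, c2 in rows:
    c2_lists[(_id, c1)].append(c2)

# c2_lists[('a1', 18)] == [18, 24, 27]
```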

Create a list of these structs for each id.

data_sdf. \
    groupBy('id', 'c1'). \
    agg(func.collect_list('c2').alias('c2_arr')). \
    withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
    groupBy('id'). \
    agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
    show(truncate=False)

# +---+------------------------------------------------------------------------------------+
# |id |c1_c2arr_struct_arr                                                                 |
# +---+------------------------------------------------------------------------------------+
# |a1 |[{10, [15, 18]}, {12, [15, 18]}, {18, [18, 24, 27]}, {55, [57, 59, 61]}, {67, [75]}]|
# +---+------------------------------------------------------------------------------------+

Use aggregate on the list of structs. Spark's aggregate SQL function works like Python's functools.reduce: it takes a lambda function that is applied cumulatively to the elements of an array (which in this case are structs), carrying an accumulator along.
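The same fold can be written with Python's functools.reduce to show the shape of the lambda (plain Python, not Spark; the sorted struct array is hardcoded from the example above):

```python
from functools import reduce

structs = [(10, [15, 18]), (12, [15, 18]), (18, [18, 24, 27]),
           (55, [57, 59, 61]), (67, [75])]

# initial accumulator: first group paired with its plain minimum
init = [(structs[0][0], min(structs[0][1]))]

# (x, y) -> append y's c1 with the min of y's list, excluding the last assigned value
acc = reduce(
    lambda x, y: x + [(y[0], min(v for v in y[1] if v != x[-1][1]))],
    structs[1:],
    init,
)
# acc == [(10, 15), (12, 18), (18, 24), (55, 57), (67, 75)]
```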

The idea of the function to be applied is:

  • calculate the plain minimum for the first c1 group
    • the second argument to aggregate (the initial accumulator value)
  • remove the previously assigned minimum from the next c1 group's list
    • the array_remove in the third argument of aggregate
  • calculate the minimum of the remaining values in that c1 group
    • the array_min in the third argument of aggregate

data_sdf. \
    groupBy('id', 'c1'). \
    agg(func.collect_list('c2').alias('c2_arr')). \
    withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
    groupBy('id'). \
    agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
    withColumn('c1_c2arr_struct_arr_new',
               func.expr('''
                         aggregate(slice(c1_c2arr_struct_arr, 2, size(c1_c2arr_struct_arr)),
                                   array(struct(c1_c2arr_struct_arr[0].c1 as c1, array_min(c1_c2arr_struct_arr[0].c2_arr) as c2_arr_custom_min)),
                                   (x, y) -> array_union(x, array(struct(y.c1 as c1, array_min(array_remove(y.c2_arr, x[size(x)-1].c2_arr_custom_min)) as c2_arr_custom_min)))
                                   )
                         ''')
               ). \
    show(truncate=False)

# +---+------------------------------------------------------------------------------------+--------------------------------------------------+
# |id |c1_c2arr_struct_arr                                                                 |c1_c2arr_struct_arr_new                           |
# +---+------------------------------------------------------------------------------------+--------------------------------------------------+
# |a1 |[{10, [15, 18]}, {12, [15, 18]}, {18, [18, 24, 27]}, {55, [57, 59, 61]}, {67, [75]}]|[{10, 15}, {12, 18}, {18, 24}, {55, 57}, {67, 75}]|
# +---+------------------------------------------------------------------------------------+--------------------------------------------------+

The inline SQL function explodes the array of structs and creates new columns using the struct fields.
