I have a df in this format:
+---+---+---+
| id|  A|  B|
+---+---+---+
| a1| 10| 15|
| a1| 10| 18|
| a1| 12| 15|
| a1| 12| 18|
| a1| 18| 18|
| a1| 18| 24|
| a1| 18| 27|
| a1| 55| 57|
| a1| 55| 59|
| a1| 55| 61|
| a1| 67| 75|
+---+---+---+
And I would like the output to be:
+---+---+---+
| id|  A|  B|
+---+---+---+
| a1| 10| 15|
| a1| 12| 18|
| a1| 18| 24|
| a1| 55| 57|
| a1| 67| 75|
+---+---+---+
For each A value (processed in ascending order within an id), the final df should assign the minimum value in B that was not assigned to the row above. So the first row (A = 10) takes the minimum B value, 15. For the next row (A = 12), the smallest B value is also 15, but 15 was already assigned to A = 10 in the row above, so A = 12 gets the next smallest value, 18. I tried getting the distinct values of A and of B per id, numbering each with row_number(), and joining on that number. Because the two columns are numbered independently, the numbering skews the data and the join loses records such as 55 and 67 in column A and 57 and 75 in column B:
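A plain-Python sketch of the intended rule (not Spark code), assuming the data fits in memory as (id, A, B) tuples and that only the B value assigned to the immediately preceding A group is excluded, as described above. The names rows and assign_min_unused are hypothetical.

```python
from collections import defaultdict

# Sample rows from the question as (id, A, B) tuples
rows = [
    ("a1", 10, 15), ("a1", 10, 18), ("a1", 12, 15), ("a1", 12, 18),
    ("a1", 18, 18), ("a1", 18, 24), ("a1", 18, 27), ("a1", 55, 57),
    ("a1", 55, 59), ("a1", 55, 61), ("a1", 67, 75),
]

def assign_min_unused(rows):
    """For each (id, A) group, in ascending A order, pick the smallest B
    that differs from the B assigned to the previous A group of the same id."""
    groups = defaultdict(list)
    for id_, a, b in rows:
        groups[(id_, a)].append(b)

    result = []
    prev = {}  # id -> B value assigned to the previous A group
    for id_, a in sorted(groups):
        candidates = [b for b in groups[(id_, a)] if b != prev.get(id_)]
        chosen = min(candidates) if candidates else None
        result.append((id_, a, chosen))
        prev[id_] = chosen
    return result

for row in assign_min_unused(rows):
    print(row)
```

This reproduces the desired output table: (a1, 10, 15), (a1, 12, 18), (a1, 18, 24), (a1, 55, 57), (a1, 67, 75).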
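from pyspark.sql import Window
from pyspark.sql.functions import row_number

# Distinct A values per id, numbered in ascending order
dfa = df.select('id', 'A').distinct()
dfa = dfa.withColumn('rnk', row_number().over(Window.partitionBy('id').orderBy('A')))
# Distinct B values per id, numbered in ascending order
# (partition by 'id1', since 'id' was renamed on the line before)
dfb = df.select('id', 'B').distinct()
dfb = dfb.withColumnRenamed('id', 'id1')
dfb = dfb.withColumn('rnk1', row_number().over(Window.partitionBy('id1').orderBy('B')))
# Pair A and B values that share the same id and row number, keeping B > A
a_join_b = dfa.join(dfb,
                    ((dfa['id'] == dfb['id1'])
                     & (dfa['rnk'] == dfb['rnk1'])
                     & (dfb['B'] > dfa['A'])),
                    'inner').drop('id1', 'rnk1')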
+---+---+---+---+
| id|  A|  B|rnk|
+---+---+---+---+
| a1| 10| 15|  1|
| a1| 12| 18|  2|
| a1| 18| 24|  3|
+---+---+---+---+
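To see why the join drops rows, here is a plain-Python simulation of the ranking-and-join approach above (an illustration, not Spark itself). A has 5 distinct values but B has 8, so once B = 27 (which belongs to the A = 18 group) takes rank 4, rank 4 pairs A = 55 with B = 27 and rank 5 pairs A = 67 with B = 57; both fail the B > A condition, and B values at ranks 6 to 8 have no A to join with.

```python
# (A, B) pairs from the question's sample data (single id, a1)
pairs = [(10, 15), (10, 18), (12, 15), (12, 18), (18, 18), (18, 24),
         (18, 27), (55, 57), (55, 59), (55, 61), (67, 75)]

# distinct() followed by row_number() over an ordered window: 1-based ranks
rank_a = dict(enumerate(sorted({a for a, _ in pairs}), start=1))
rank_b = dict(enumerate(sorted({b for _, b in pairs}), start=1))

# Inner join on equal rank; the B > A condition decides which pairs survive
joined, rejected = [], []
for rnk, a in rank_a.items():
    b = rank_b.get(rnk)
    if b is None:
        continue
    (joined if b > a else rejected).append((a, b, rnk))

print("kept:", joined)
print("rejected:", rejected)
```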
Answer:
I was able to do this using arrays and the aggregate higher-order function. In my code, columns A and B are named c1 and c2.
from pyspark.sql import functions as func

# data_sdf holds the question's data, with columns A and B named c1 and c2
data_sdf. \
groupBy('id', 'c1'). \
agg(func.collect_list('c2').alias('c2_arr')). \
withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
groupBy('id'). \
agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
withColumn('c1_c2arr_struct_arr_new',
func.expr('''
aggregate(slice(c1_c2arr_struct_arr, 2, size(c1_c2arr_struct_arr)),
array(struct(c1_c2arr_struct_arr[0].c1 as c1, array_min(c1_c2arr_struct_arr[0].c2_arr) as c2_arr_custom_min)),
(x, y) -> array_union(x,
array(struct(y.c1 as c1,
array_min(array_remove(y.c2_arr, element_at(x, -1).c2_arr_custom_min)) as c2_arr_custom_min
)
)
)
)
''')
). \
selectExpr('id', 'inline(c1_c2arr_struct_arr_new)'). \
show(truncate=False)
# +---+---+-----------------+
# |id |c1 |c2_arr_custom_min|
# +---+---+-----------------+
# |a1 |10 |15               |
# |a1 |12 |18               |
# |a1 |18 |24               |
# |a1 |55 |57               |
# |a1 |67 |75               |
# +---+---+-----------------+
Explanation
Given the following input data:
+---+---+---+
| id| c1| c2|
+---+---+---+
| a1| 10| 15|
| a1| 10| 18|
| a1| 12| 15|
| a1| 12| 18|
| a1| 18| 18|
| a1| 18| 24|
| a1| 18| 27|
| a1| 55| 57|
| a1| 55| 59|
| a1| 55| 61|
| a1| 67| 75|
+---+---+---+
First, collect the values you want the minimum of into a list within each id and A (c1) partition, and then create a struct from this list field and the A column.
data_sdf. \
groupBy('id', 'c1'). \
agg(func.collect_list('c2').alias('c2_arr')). \
withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
show(truncate=False)
# +---+---+------------+------------------+
# |id |c1 |c2_arr      |c1_c2arr_struct   |
# +---+---+------------+------------------+
# |a1 |18 |[18, 24, 27]|{18, [18, 24, 27]}|
# |a1 |55 |[57, 59, 61]|{55, [57, 59, 61]}|
# |a1 |67 |[75]        |{67, [75]}        |
# |a1 |10 |[15, 18]    |{10, [15, 18]}    |
# |a1 |12 |[15, 18]    |{12, [15, 18]}    |
# +---+---+------------+------------------+
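As a rough plain-Python analogy (not how Spark executes it), the groupBy/collect_list/struct step builds one (c1, list of c2) pair per (id, c1) group. Spark does not guarantee the order of the grouped rows, as the output above shows, so the order matters only after the next step sorts them. The names rows, c2_arr and structs are hypothetical.

```python
from collections import defaultdict

# Sample rows as (id, c1, c2) tuples
rows = [
    ("a1", 10, 15), ("a1", 10, 18), ("a1", 12, 15), ("a1", 12, 18),
    ("a1", 18, 18), ("a1", 18, 24), ("a1", 18, 27), ("a1", 55, 57),
    ("a1", 55, 59), ("a1", 55, 61), ("a1", 67, 75),
]

# Rough analogue of groupBy('id', 'c1').agg(collect_list('c2'))
c2_arr = defaultdict(list)
for id_, c1, c2 in rows:
    c2_arr[(id_, c1)].append(c2)

# Rough analogue of struct('c1', 'c2_arr'): pair each c1 with its list
structs = [(c1, values) for (id_, c1), values in c2_arr.items()]
print(structs)
```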
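Next, collect these structs into a list for each id and sort it. array_sort compares structs field by field, so the list is ordered by c1 (that is, by A), which is the order in which the groups must be processed.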
data_sdf. \
groupBy('id', 'c1'). \
agg(func.collect_list('c2').alias('c2_arr')). \
withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
groupBy('id'). \
agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
show(truncate=False)
# +---+------------------------------------------------------------------------------------+
# |id |c1_c2arr_struct_arr                                                                 |
# +---+------------------------------------------------------------------------------------+
# |a1 |[{10, [15, 18]}, {12, [15, 18]}, {18, [18, 24, 27]}, {55, [57, 59, 61]}, {67, [75]}]|
# +---+------------------------------------------------------------------------------------+
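The sort matters because aggregate walks the array from left to right. Python tuples also compare element by element, so sorting (c1, c2_arr) tuples orders them by c1 first, much like array_sort on the structs. A small plain-Python illustration using the unordered groups from the previous output (the names are hypothetical):

```python
# Groups in the arbitrary order Spark returned them in the previous step
structs = [(18, [18, 24, 27]), (55, [57, 59, 61]), (67, [75]), (10, [15, 18]), (12, [15, 18])]

# Tuples compare element-wise, so this sorts by c1 first
ordered = sorted(structs)
print(ordered)
```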
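Use aggregate on the list of structs. The aggregate function works like Python's functools.reduce: it takes an array, an initial value for the accumulator, and a lambda function that is applied to each element of the array in turn (here, each element is a struct), carrying the accumulator from one element to the next.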
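For comparison, Python's functools.reduce with an initial value follows the same pattern as aggregate(array, start, merge). The second reduce below uses a list accumulator seeded from the first element and folds over the rest, which is the same shape as the Spark expression (this is an unrelated toy example on plain integers):

```python
from functools import reduce

values = [3, 1, 4, 1, 5]

# reduce(function, iterable, initial): acc starts at 0, then acc = acc + x for each x
total = reduce(lambda acc, x: acc + x, values, 0)

# List accumulator: start from the first element, then append one new entry
# per remaining element, computed from the last entry added so far
running = reduce(lambda acc, x: acc + [acc[-1] + x], values[1:], [values[0]])
print(total, running)
```

This prints 14 [3, 4, 8, 9, 14].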
The idea of the function to be applied is:
- calculate the initial min for the first A group; this is the second input to aggregate (the initial value), which is why the array passed to aggregate is sliced to start at its second element
- remove the min value assigned to the previous A group (the last element of the accumulator) from the next A group; this is the array_remove in the third input of aggregate
- calculate the min of the remaining values in that A group; this is the array_min in the third input of aggregate
data_sdf. \
groupBy('id', 'c1'). \
agg(func.collect_list('c2').alias('c2_arr')). \
withColumn('c1_c2arr_struct', func.struct('c1', 'c2_arr')). \
groupBy('id'). \
agg(func.array_sort(func.collect_list('c1_c2arr_struct')).alias('c1_c2arr_struct_arr')). \
withColumn('c1_c2arr_struct_arr_new',
func.expr('''
aggregate(slice(c1_c2arr_struct_arr, 2, size(c1_c2arr_struct_arr)),
array(struct(c1_c2arr_struct_arr[0].c1 as c1, array_min(c1_c2arr_struct_arr[0].c2_arr) as c2_arr_custom_min)),
(x, y) -> array_union(x, array(struct(y.c1 as c1, array_min(array_remove(y.c2_arr, x[size(x)-1].c2_arr_custom_min)) as c2_arr_custom_min)))
)
''')
). \
show(truncate=False)
# +---+------------------------------------------------------------------------------------+--------------------------------------------------+
# |id |c1_c2arr_struct_arr                                                                 |c1_c2arr_struct_arr_new                           |
# +---+------------------------------------------------------------------------------------+--------------------------------------------------+
# |a1 |[{10, [15, 18]}, {12, [15, 18]}, {18, [18, 24, 27]}, {55, [57, 59, 61]}, {67, [75]}]|[{10, 15}, {12, 18}, {18, 24}, {55, 57}, {67, 75}]|
# +---+------------------------------------------------------------------------------------+--------------------------------------------------+
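A plain-Python mirror of the aggregate expression above, offered as a sketch of the same fold rather than Spark code. It assumes the sorted (c1, c2_arr) pairs from the array_sort step. slice(..., 2, ...) becomes groups[1:], element_at(x, -1) becomes acc[-1], array_remove becomes a comprehension that drops every occurrence of the previous minimum, and array_min becomes a helper that returns None for an empty list (Spark's array_min returns null for an empty array). List concatenation stands in for array_union; array_union also drops duplicate elements, which does not matter here because every struct has a distinct c1. The helper names are hypothetical.

```python
from functools import reduce

# Sorted (c1, c2_arr) pairs, as produced by array_sort(collect_list(...))
groups = [(10, [15, 18]), (12, [15, 18]), (18, [18, 24, 27]), (55, [57, 59, 61]), (67, [75])]

def array_min(values):
    # Return None instead of raising on an empty list
    return min(values) if values else None

def merge(acc, group):
    c1, c2_arr = group
    prev_min = acc[-1][1]                             # element_at(x, -1).c2_arr_custom_min
    remaining = [v for v in c2_arr if v != prev_min]  # array_remove(y.c2_arr, prev_min)
    return acc + [(c1, array_min(remaining))]         # append the new (c1, min) struct

# Initial value: the first group paired with its plain minimum
first_c1, first_arr = groups[0]
result = reduce(merge, groups[1:], [(first_c1, array_min(first_arr))])
print(result)
```

This prints [(10, 15), (12, 18), (18, 24), (55, 57), (67, 75)], matching c1_c2arr_struct_arr_new.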
The inline SQL function explodes the array of structs and creates new columns using the struct fields.
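Conceptually, inline turns each struct in the array into its own row, with one column per struct field. A rough plain-Python analogue (an illustration only; row_id, structs and out_rows are hypothetical names) for the array computed above:

```python
# One row: the id plus its array of (c1, c2_arr_custom_min) structs
row_id = "a1"
structs = [(10, 15), (12, 18), (18, 24), (55, 57), (67, 75)]

# inline: one output row per struct, with the struct fields as columns
out_rows = [(row_id, c1, c2_min) for c1, c2_min in structs]
for r in out_rows:
    print(r)
```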