Groupby and get list with sorted value from other column pyspark


Hey, I have a dataframe like this:

+----------+----------+------------------+
|      id_A|      id_B|          Distance|
+----------+----------+------------------+
| 120745612| 122913167|0.6142857142857143|
|1243257970| 370926553|0.8061224489795918|
|1305652409| 253051944|0.8252427184466019|
|1350805455| 311286173|0.5789473684210527|
|1544864070| 390580289|0.7894736842105263|
| 164533143| 763751752|0.8153846153846154|
|1683553267| 787287056|0.9117647058823529|
| 175951349| 175951349|               0.0|
+----------+----------+------------------+

Now I want to group by id_A and get a list of id_B values in ascending order of Distance. That is, the id_B with the smallest Distance should come first in the list, and so on.

Expected_Out:

|      id_A|                       id_B|
| 175951349|[175951349, 390580289, ...]|

CodePudding user response:

First use collect_list to build an array of structs, then use array_sort with a comparator on the Distance field to order the array, and finally use transform to extract just the id_B field.

from pyspark.sql import functions as F

df = df.groupBy('id_A').agg(
    F.expr("""
        transform(
            array_sort(
                collect_list(struct(id_B, Distance)),
                (l, r) -> case when l.Distance < r.Distance then -1 when l.Distance > r.Distance then 1 else 0 end
            ),
            x -> x.id_B
        )
    """).alias('id_B')
)
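The grouping-and-sorting logic itself can be checked with a quick plain-Python sketch, using a few of the sample rows from the question (the variable names here are illustrative, not part of the Spark API):

```python
from collections import defaultdict

# A few (id_A, id_B, Distance) rows from the question's sample data
rows = [
    (120745612, 122913167, 0.6142857142857143),
    (1243257970, 370926553, 0.8061224489795918),
    (175951349, 175951349, 0.0),
]

# Group by id_A, then sort each group's (Distance, id_B) pairs; tuples
# compare element by element, so Distance drives the order, just like
# the struct comparator in the expression above.
groups = defaultdict(list)
for id_a, id_b, dist in rows:
    groups[id_a].append((dist, id_b))

result = {id_a: [id_b for _, id_b in sorted(pairs)]
          for id_a, pairs in groups.items()}
print(result[175951349])  # [175951349]
```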

CodePudding user response:

You can do this by first combining id_B and Distance into a struct, sorting on Distance using sort_array, and finally extracting the required field.

Data Preparation

from io import StringIO

import pandas as pd
from pyspark.sql import functions as F

s = StringIO("""
id_A,id_B,Distance
120745612,122913167,0.6142857142857143
1243257970,370926553,0.8061224489795918
1305652409,253051944,0.8252427184466019
1350805455,311286173,0.5789473684210527
1544864070,390580289,0.7894736842105263
164533143,763751752,0.8153846153846154
1683553267,787287056,0.9117647058823529
175951349,175951349,0.0
1683553267,787287056,0.67217647058823529
1683553267,787287056,0.51236647058823529
1683553267,787287056,0.98176470588235291
""")

# I have manually added the last 3 records to demonstrate the behaviour with duplicate id_B values

df = pd.read_csv(s, delimiter=',')

sparkDF = sql.createDataFrame(df).orderBy('id_A')  # `sql` is an existing SparkSession

sparkDF.show()

+----------+---------+------------------+
|      id_A|     id_B|          Distance|
+----------+---------+------------------+
| 120745612|122913167|0.6142857142857143|
| 164533143|763751752|0.8153846153846154|
| 175951349|175951349|               0.0|
|1243257970|370926553|0.8061224489795918|
|1305652409|253051944|0.8252427184466019|
|1350805455|311286173|0.5789473684210527|
|1544864070|390580289|0.7894736842105263|
|1683553267|787287056|0.9117647058823528|
|1683553267|787287056|0.6721764705882352|
|1683553267|787287056|0.9817647058823528|
|1683553267|787287056|0.5123664705882351|
+----------+---------+------------------+

Struct - Array Sort

# collect_set builds an array of (Distance, id_B) structs; sort_array
# compares structs field by field, so Distance drives the sort order.
sparkDF.groupby("id_A") \
       .agg(F.sort_array(F.collect_set(F.struct("Distance", "id_B"))).alias("collected_list")) \
       .withColumn("sorted_list", F.col("collected_list.id_B")) \
       .drop("collected_list") \
       .show(truncate=False)

+----------+--------------------------------------------+
|id_A      |sorted_list                                 |
+----------+--------------------------------------------+
|120745612 |[122913167]                                 |
|164533143 |[763751752]                                 |
|175951349 |[175951349]                                 |
|1243257970|[370926553]                                 |
|1305652409|[253051944]                                 |
|1350805455|[311286173]                                 |
|1544864070|[390580289]                                 |
|1683553267|[787287056, 787287056, 787287056, 787287056]|
+----------+--------------------------------------------+
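The key to this answer is that sort_array compares structs field by field, so putting Distance first in the struct makes it the sort key. The same comparison behaviour can be sketched with Python tuples, using the four Distance values of the 1683553267 group above:

```python
# Structs compare field by field, like Python tuples: with Distance as
# the first struct field, sorting (Distance, id_B) pairs orders them by
# ascending Distance regardless of id_B.
pairs = [
    (0.9817647058823528, 787287056),
    (0.5123664705882351, 787287056),
    (0.9117647058823528, 787287056),
    (0.6721764705882352, 787287056),
]
sorted_ids = [id_b for _, id_b in sorted(pairs)]
# Every id_B is the same here, but the pairs are now in ascending
# Distance order, smallest (0.5123...) first.
```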