Pyspark - Looking to find indexes on the top N largest values in an array column

Time:12-27

I'm looking to replace the functionality of the following numpy command:

top_n_idx = np.argsort(cosine_sim[idx])[::-1][1:11]
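For reference, the numpy line sorts positions by value, reverses them into descending order, and `[1:11]` then skips the single largest entry (presumably the item's similarity to itself) before taking the next ten. Below is a minimal pure-Python sketch of the same selection. The helper name `top_n_idx` and its `skip_first` flag are hypothetical. Python positions are 0-based, whereas Spark's array functions are 1-based, and ties may come out in a different order than in numpy.

```python
def top_n_idx(scores, n=10, skip_first=True):
    """Return the positions of the n largest scores, largest first.

    With skip_first=True this mirrors np.argsort(scores)[::-1][1:n+1]:
    the top entry is dropped and the next n positions are returned.
    Equal scores keep their original relative order (sorted() is stable),
    which may differ from numpy's tie ordering.
    """
    # Positions ordered by their score, highest score first.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    start = 1 if skip_first else 0
    return order[start:start + n]


row = [0.1, 0.5, 0.2, 0.55, 0.9]
print(top_n_idx(row, n=3, skip_first=False))  # 0-based positions of the top 3
print(top_n_idx(row, n=3))                    # same, but skipping the top entry
```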

Sample Data:

array_col

[0.1,0.5,0.2,0.55,0.9]
[0.1,0.9,0.5,0.2,0.35]

Here is the code I have so far:

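from pyspark.sql import functions as F

# Sort each array in descending order and keep the first 3 elements
df.select("array_col", F.slice(F.sort_array(F.col("array_col"), asc=False), 1, 3).alias("top_scores")).show()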

array_col               top_scores

[0.1,0.5,0.2,0.55,0.9]  [0.9, 0.55, 0.5]
[0.1,0.9,0.5,0.2,0.35]  [0.9, 0.5, 0.35]

Now, what I would like to do is find the (1-based) indexes in array_col that correspond to the values in the `top_scores` column.

array_col               top_scores        top_score_idx

[0.1,0.5,0.2,0.55,0.9]  [0.9, 0.55, 0.5]  [5, 4, 2]
[0.1,0.9,0.5,0.2,0.35]  [0.9, 0.5, 0.35]  [2, 3, 5]

I will ultimately use top_score_idx to grab the values at the corresponding indexes in another array column.

Answer:

For Spark 2.4+, use the transform and array_position functions: transform maps over the top_scores array, and array_position looks up the 1-based index of each value in the array_col column.

df \
.select("array_col", F.slice(F.sort_array(F.col("array_col"), asc=False), 1, 3).alias("top_scores")) \
.withColumn("top_score_idx", F.expr("transform(top_scores, v -> array_position(array_col, v))")) \
.show()

# +--------------------------+----------------+-------------+
# |array_col                 |top_scores      |top_score_idx|
# +--------------------------+----------------+-------------+
# |[0.1, 0.5, 0.2, 0.55, 0.9]|[0.9, 0.55, 0.5]|[5, 4, 2]    |
# |[0.1, 0.9, 0.5, 0.2, 0.35]|[0.9, 0.5, 0.35]|[2, 3, 5]    |
# +--------------------------+----------------+-------------+