Find top n results for multiple fields in Spark dataframe


I have a dataframe like this one:

name  field1  field2  field3
a     4       10      8 
b     5       0       11
c     10      7       4
d     0       1       5

I need to find top 3 names for each field.

Expected output:

top3-field1  top3-field2  top3-field3
c            a            b
b            c            a
a            d            d
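For reference, the sample dataframe above can be built like this (a minimal sketch, assuming a SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a", 4, 10, 8), ("b", 5, 0, 11), ("c", 10, 7, 4), ("d", 0, 1, 5)],
    ["name", "field1", "field2", "field3"],
)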

So, I tried to sort the field(n) column values, limit to the top 3 results, and generate a new column using the withColumn method, like this:

import pyspark.sql.functions as f

df1 = df.orderBy(f.col("field1").desc(), "name") \
        .limit(3) \
        .withColumn("top3-field1", df["name"]) \
        .select("top3-field1", "field1")

With this approach I have to create a separate dataframe for each field(n) and then join them to get the result described above (a rough sketch of that is shown below). I feel there must be a better solution to this problem, and I hope someone can give me suggestions.
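For completeness, here is roughly what that join-based approach looks like; the rnk column is only there so the three per-field dataframes can be lined up and joined (a sketch, with the details assumed):

import pyspark.sql.functions as f
from pyspark.sql import Window

def top3(df, c):
    # rank rows by one field and keep the top 3 names
    w = Window.orderBy(f.col(c).desc(), "name")
    return (df.withColumn("rnk", f.row_number().over(w))
              .where("rnk <= 3")
              .select("rnk", f.col("name").alias(f"top3-{c}")))

result = top3(df, "field1")
for c in ["field2", "field3"]:
    result = result.join(top3(df, c), "rnk")
result.orderBy("rnk").drop("rnk").show()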

CodePudding user response:

You can first stack the df, then compute a dense rank over the values in descending order, keep only the rows with rank less than or equal to 3, and finally pivot the names:

Note that I am using a small helper function (stack_multiple_col, defined in the full code at the end) to make typing the stack expression a little easier:


from pyspark.sql import functions as F, Window as W  # imports

w = W.partitionBy("col").orderBy(F.desc("values"))
out = (df.selectExpr("name", stack_multiple_col(df, df.columns[1:]))
         .withColumn("Rnk", F.dense_rank().over(w))
         .where("Rnk<=3")
         .groupBy("Rnk").pivot("col").agg(F.first("name")))

out.show()

+---+------+------+------+
|Rnk|field1|field2|field3|
+---+------+------+------+
|  1|     c|     a|     b|
|  2|     b|     c|     a|
|  3|     a|     d|     d|
+---+------+------+------+

If you do not want to use the helper function, you can write the same thing as:

w = W.partitionBy("col").orderBy(F.desc("values"))
out = (df.selectExpr("name",
                     'stack(3,"field1",field1,"field2",field2,"field3",field3) as (col,values)')
         .withColumn("Rnk", F.dense_rank().over(w))
         .where("Rnk<=3")
         .groupBy("Rnk").pivot("col").agg(F.first("name")))
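To see what the stack expression does: it turns the wide table into a long one, with one row per (name, field) pair. For the sample data, the intermediate result before ranking and pivoting looks like this (row order may vary):

df.selectExpr("name",
    'stack(3,"field1",field1,"field2",field2,"field3",field3) as (col,values)').show()

+----+------+------+
|name|   col|values|
+----+------+------+
|   a|field1|     4|
|   a|field2|    10|
|   a|field3|     8|
|   b|field1|     5|
|   b|field2|     0|
|   b|field3|    11|
|   c|field1|    10|
|   c|field2|     7|
|   c|field3|     4|
|   d|field1|     0|
|   d|field2|     1|
|   d|field3|     5|
+----+------+------+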

Full code:

from pyspark.sql import functions as F, Window as W

def stack_multiple_col(df, cols=None, output_columns=["col", "values"]):
  """Builds a stack() SQL expression over the given columns of df;
     uses all columns by default unless passed a list of column names."""
  if cols is None:
    cols = df.columns
  return (f"""stack({len(cols)},{','.join(map(','.join,
         (zip([f'"{i}"' for i in cols],cols))))}) as ({','.join(output_columns)})""")


w = W.partitionBy("col").orderBy(F.desc("values"))
out = (df.selectExpr("name", stack_multiple_col(df, df.columns[1:]))
         .withColumn("Rnk", F.dense_rank().over(w))
         .where("Rnk<=3")
         .groupBy("Rnk").pivot("col").agg(F.first("name")))

out.show()
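If you want the exact column names from the question and no Rnk column, one more select on top of out will do it (a small addition, not part of the original answer):

final = (out.orderBy("Rnk")
            .select(F.col("field1").alias("top3-field1"),
                    F.col("field2").alias("top3-field2"),
                    F.col("field3").alias("top3-field3")))
final.show()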