Find the k most frequent words in each row from PySpark dataframe

I have a Spark dataframe that looks something like this:

columns = ["object_type", "object_name"]
data = [("galaxy", "andromeda,milky way,condor,andromeda"),
        ("planet", "mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth"), 
        ("star", "mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran"),
        ("natural satellites", "moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa")]
init_df = spark.createDataFrame(data).toDF(*columns)
init_df.show(truncate = False)

+------------------+-----------------------------------------------------------+
|object_type       |object_name                                                |
+------------------+-----------------------------------------------------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |
+------------------+-----------------------------------------------------------+

I need to create a new column with the most frequent words from the object_name column using PySpark.
Conditions:

  • if there is one dominant word in the row (mode = 1), then choose that word as the most frequent (like "andromeda" in the first row)
  • if there are two dominant words in the row that occur an equal number of times (mode = 2), then select both of them (like "mars" and "venus" in the second row - they each occur 3 times, while the rest of the words are less common)
  • if there are three dominant words in the row that occur an equal number of times, then pick all three (like "mira", "sun" and "sirius", which each occur 2 times, while the rest of the words occur only once)
  • if there are four or more dominant words in the row that occur an equal number of times (like in the fourth row), then set the "many objects" flag.
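The four rules above can be prototyped in plain Python before porting to Spark; a minimal sketch (the helper name `pick_most_frequent` is mine, not part of the question):

```python
from collections import Counter

def pick_most_frequent(object_name):
    """Plain-Python prototype of the selection rules: keep all words tied
    for the highest count; 4 or more ties collapse to 'many objects'."""
    counts = Counter(object_name.split(","))
    top = max(counts.values())
    winners = [w for w, c in counts.items() if c == top]
    return "many objects" if len(winners) >= 4 else ",".join(winners)
```

`Counter` preserves first-seen order, so the winners come out in the order they appear in the row.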

Expected output:

+------------------+-----------------------------------------------------------+---------------+
|object_type       |object_name                                                |most_frequent  |
+------------------+-----------------------------------------------------------+---------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda      |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|mars,venus     |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |mira,sun,sirius|
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |many objects   |
+------------------+-----------------------------------------------------------+---------------+

I'll be very grateful for any advice!

CodePudding user response:

You can try this,

from pyspark.sql import functions as F

# UDF: keep every word whose count in the row equals the maximum count
def top_words(words):
    counts = {w: words.count(w) for w in set(words)}
    return ', '.join(w for w in counts if counts[w] == max(counts.values()))

res_df = init_df.withColumn("list_obj", F.split(F.col("object_name"), ",")) \
    .withColumn("most_frequent", F.udf(top_words)(F.col("list_obj"))) \
    .drop("list_obj")

res_df.show(truncate=False)
+------------------+-----------------------------------------------------------+---------------------+
|object_type       |object_name                                                |most_frequent        |
+------------------+-----------------------------------------------------------+---------------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda            |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|venus, mars          |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |sirius, mira, sun    |
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |moon, kale, titan, io|
+------------------+-----------------------------------------------------------+---------------------+
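As an aside, calling `words.count(...)` and recomputing `max(...)` for every distinct word makes the UDF body quadratic in the row length; `collections.Counter` does it in one pass. A sketch of the same row-level logic as a plain function (my own refactor, testable without Spark, then wrappable with `F.udf`):

```python
from collections import Counter

def tied_top_words(words):
    """One-pass variant of the UDF body: count once, then keep every
    word tied with the highest count."""
    counts = Counter(words)
    top = max(counts.values())
    return [w for w, c in counts.items() if c == top]
```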

EDIT:

Following OP's clarification, we can achieve the desired output by doing something like this,

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

# UDF: return every word tied for the maximum count as an array
def top_words(words):
    counts = {w: words.count(w) for w in set(words)}
    return [w for w in counts if counts[w] == max(counts.values())]

res_df = init_df.withColumn("list_obj", F.split(F.col("object_name"), ",")) \
    .withColumn("most_frequent", F.udf(top_words, ArrayType(StringType()))(F.col("list_obj"))) \
    .withColumn("most_frequent", F.when(F.size("most_frequent") >= 4, F.lit("many objects"))
                                  .otherwise(F.concat_ws(", ", "most_frequent"))) \
    .drop("list_obj")

res_df.show(truncate=False)
+------------------+-----------------------------------------------------------+-----------------+
|object_type       |object_name                                                |most_frequent    |
+------------------+-----------------------------------------------------------+-----------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda        |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|venus, mars      |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |sirius, mira, sun|
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |many objects     |
+------------------+-----------------------------------------------------------+-----------------+

CodePudding user response:

Try this:

from pyspark.sql import functions as psf
from pyspark.sql.window import Window

columns = ["object_type", "object_name"]
data = [("galaxy", "andromeda,milky way,condor,andromeda"),
        ("planet", "mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth"), 
        ("star", "mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran"),
        ("natural satellites", "moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa")]
init_df = spark.createDataFrame(data).toDF(*columns)

# unpivot the object name and count   
df_exp = init_df.withColumn('object_name_exp', psf.explode(psf.split('object_name',',')))
df_counts = df_exp.groupBy('object_type', 'object_name_exp').count()

window_spec = Window.partitionBy('object_type').orderBy(psf.col('count').desc())
df_ranked = df_counts.withColumn('rank', psf.dense_rank().over(window_spec))

# rank the counts, keeping the top ranked object names
df_top_ranked = df_ranked.filter(psf.col('rank')==psf.lit(1)).drop('count')

# count the number of top ranked object names
df_top_counts = df_top_ranked.groupBy('object_type',  'rank').count()

# join these back to the original object names
df_with_counts = df_top_ranked.join(df_top_counts, on='object_type', how='inner')

# implement the rules whether to retain the reference to the object name or state 'many objects'
df_most_freq = df_with_counts.withColumn('most_frequent'
    , psf.when(psf.col('count')<=psf.lit(3), psf.col('object_name_exp')).otherwise(psf.lit('many objects'))
    )

# collect the object names retained back into an array and de-duplicate them
df_results = df_most_freq.groupBy('object_type').agg(psf.array_distinct(psf.collect_list('most_frequent')).alias('most_frequent'))

# show output                                                     
df_results.show()

+------------------+-------------------+
|       object_type|      most_frequent|
+------------------+-------------------+
|            galaxy|        [andromeda]|
|natural satellites|     [many objects]|
|            planet|      [mars, venus]|
|              star|[sirius, mira, sun]|
+------------------+-------------------+
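The key to this approach is that `dense_rank` lets every word tied for the top count share rank 1. A toy illustration of dense ranking over descending values (plain Python; the helper name `dense_rank_desc` is mine, mirroring what `psf.dense_rank()` does over the descending window):

```python
def dense_rank_desc(values):
    """Dense ranking, descending: equal values share a rank and the next
    distinct value gets the next consecutive rank (no gaps)."""
    order = sorted(set(values), reverse=True)
    rank_of = {v: i + 1 for i, v in enumerate(order)}
    return [rank_of[v] for v in values]
```

For the planet row the per-word counts are [3, 1, 3, 1, 1, 1], so "mars" and "venus" both get rank 1 and survive the `rank == 1` filter.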