I have a Spark dataframe that looks something like this:
columns = ["object_type", "object_name"]
data = [("galaxy", "andromeda,milky way,condor,andromeda"),
("planet", "mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth"),
("star", "mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran"),
("natural satellites", "moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa")]
init_df = spark.createDataFrame(data).toDF(*columns)
init_df.show(truncate=False)
+------------------+-----------------------------------------------------------+
|object_type       |object_name                                                |
+------------------+-----------------------------------------------------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |
+------------------+-----------------------------------------------------------+
I need to create a new column with the most frequent words from the object_name column using PySpark.
Conditions:
- if there is one dominant word in the row (mode = 1), choose this word as the most frequent (like "andromeda" in the first row);
- if there are two dominant words that occur an equal number of times (mode = 2), select both of them (like "mars" and "venus" in the second row: they occur 3 times each, while the rest of the words are less common);
- if there are three dominant words that occur an equal number of times, pick all three (like "mira", "sun" and "sirius", which occur 2 times each, while the rest of the words appear only once);
- if there are four or more dominant words that occur an equal number of times (like in the fourth row), set the "many objects" flag.
Expected output:
+------------------+-----------------------------------------------------------+---------------+
|object_type       |object_name                                                |most_frequent  |
+------------------+-----------------------------------------------------------+---------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda      |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|mars,venus     |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |mira,sun,sirius|
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |many objects   |
+------------------+-----------------------------------------------------------+---------------+
I'll be very grateful for any advice!
CodePudding user response:
You can try this,
from pyspark.sql import functions as F

# UDF that keeps every word whose count equals the maximum count
most_frequent = F.udf(lambda x: ', '.join(
    w for w in set(x) if x.count(w) == max(x.count(v) for v in set(x))))

res_df = init_df.withColumn("list_obj", F.split(F.col("object_name"), ",")) \
    .withColumn("most_frequent", most_frequent(F.col("list_obj"))) \
    .drop("list_obj")
res_df.show(truncate=False)
+------------------+-----------------------------------------------------------+---------------------+
|object_type       |object_name                                                |most_frequent        |
+------------------+-----------------------------------------------------------+---------------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda            |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|venus, mars          |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |sirius, mira, sun    |
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |moon, kale, titan, io|
+------------------+-----------------------------------------------------------+---------------------+
EDIT:
According to OP's suggestion, we can achieve their desired output by doing something like this:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

# Same mode-finding UDF, but returning an array so we can check its size
find_modes = F.udf(lambda x: [
    w for w in set(x) if x.count(w) == max(x.count(v) for v in set(x))],
    ArrayType(StringType()))

res_df = init_df.withColumn("list_obj", F.split(F.col("object_name"), ",")) \
    .withColumn("most_frequent", find_modes(F.col("list_obj"))) \
    .withColumn("most_frequent",
                F.when(F.size("most_frequent") >= 4, F.lit("many objects"))
                 .otherwise(F.concat_ws(", ", F.col("most_frequent")))) \
    .drop("list_obj")
res_df.show(truncate=False)
+------------------+-----------------------------------------------------------+-----------------+
|object_type       |object_name                                                |most_frequent    |
+------------------+-----------------------------------------------------------+-----------------+
|galaxy            |andromeda,milky way,condor,andromeda                       |andromeda        |
|planet            |mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth|venus, mars      |
|star              |mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran     |sirius, mira, sun|
|natural satellites|moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa  |many objects     |
+------------------+-----------------------------------------------------------+-----------------+
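As a side note (this is my own sketch, not part of the original answer): the lambda above rescans the list with x.count for every distinct word, which is quadratic in the row length. A collections.Counter-based helper does a single pass, folds in the ">= 4" rule, and can be wrapped with F.udf the same way; the many_at threshold name is my own.

```python
from collections import Counter

def modal_words(words, many_at=4):
    """Return the comma-joined modal words of `words`, or the
    'many objects' flag when there are `many_at` or more of them."""
    counts = Counter(words)                          # one pass over the row
    top = max(counts.values())
    modes = [w for w in counts if counts[w] == top]  # first-seen order
    if len(modes) >= many_at:
        return "many objects"
    return ", ".join(modes)
```

Registering it for Spark is then just F.udf(modal_words). Counter preserves first-seen order, so the output order is deterministic, unlike iterating over set(x).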
CodePudding user response:
Try this:
from pyspark.sql import functions as psf
from pyspark.sql.window import Window
columns = ["object_type", "object_name"]
data = [("galaxy", "andromeda,milky way,condor,andromeda"),
("planet", "mars,jupiter,venus,mars,saturn,venus,earth,mars,venus,earth"),
("star", "mira,sun,altair,sun,sirius,rigel,mira,sirius,aldebaran"),
("natural satellites", "moon,io,io,elara,moon,kale,titan,kale,phobos,titan,europa")]
init_df = spark.createDataFrame(data).toDF(*columns)
# unpivot the object name and count
df_exp = init_df.withColumn('object_name_exp', psf.explode(psf.split('object_name',',')))
df_counts = df_exp.groupBy('object_type', 'object_name_exp').count()
window_spec = Window.partitionBy('object_type').orderBy(psf.col('count').desc())
df_ranked = df_counts.withColumn('rank', psf.dense_rank().over(window_spec))
# rank the counts, keeping the top ranked object names
df_top_ranked = df_ranked.filter(psf.col('rank')==psf.lit(1)).drop('count')
# count the number of top ranked object names
df_top_counts = df_top_ranked.groupBy('object_type', 'rank').count()
# join these back to the original object names
df_with_counts = df_top_ranked.join(df_top_counts, on='object_type', how='inner')
# implement the rule: keep the object name, or state 'many objects'
# when there are four or more top-ranked names
df_most_freq = df_with_counts.withColumn(
    'most_frequent',
    psf.when(psf.col('count') <= psf.lit(3), psf.col('object_name_exp'))
       .otherwise(psf.lit('many objects'))
)
# collect the retained object names back into an array and de-duplicate them
df_results = df_most_freq.groupBy('object_type').agg(
    psf.array_distinct(psf.collect_list('most_frequent')).alias('most_frequent'))
# show output
df_results.show()
+------------------+-------------------+
|       object_type|      most_frequent|
+------------------+-------------------+
|            galaxy|        [andromeda]|
|natural satellites|     [many objects]|
|            planet|      [mars, venus]|
|              star|[sirius, mira, sun]|
+------------------+-------------------+