In pyspark, how to groupBy and collect a list of all distinct structs contained in an array column

Time:09-08

I have a table with the following schema

root
 |-- match_keys: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- key: string (nullable = false)
 |    |    |-- entity1: string (nullable = true)
 |    |    |-- entity2: string (nullable = true)
 |-- src: string (nullable = true)
 |-- dst: string (nullable = true)

Here's an example:

src |dst| match_keys
----------------------------------------------------------------------------
a1  |d1 | [{"key": "name", "entity1": "john", "entity2": "john"}]   
a1  |d1 | [{"key": "name", "entity1": "john", "entity2": "john"},
           {"key": "dob", "entity1": "21/01/1999", "entity2": "21/01/1999"}]
a1  |d1 | [{"key": "name", "entity1": "john", "entity2": "john"},
           {"key": "country", "entity1": "IT", "entity2": "IT"}]

What I am looking for is:

src |dst| match_keys
----------------------------------------------------------------------------
a1  |d1 | [{"key": "name", "entity1": "john", "entity2": "john"}, 
           {"key": "dob", "entity1": "21/01/1999", "entity2": "21/01/1999"}, 
           {"key": "country", "entity1": "IT", "entity2": "IT"}]

I of course tried:

(df
 .groupBy("src", "dst")
 .agg(
      F.flatten(F.collect_set(F.col("match_keys"))).alias("match_keys")
      )
).show(truncate=False)

But that results in the output below, with duplicate structs for the name key.

src |dst| match_keys
----------------------------------------------------------------------------
a1  |d1 | [{"key": "name", "entity1": "john", "entity2": "john"},
           {"key": "name", "entity1": "john", "entity2": "john"},  
           {"key": "dob", "entity1": "21/01/1999", "entity2": "21/01/1999"}, 
           {"key": "name", "entity1": "john", "entity2": "john"}, 
           {"key": "country", "entity1": "IT", "entity2": "IT"}]

CodePudding user response:

What you need is the array_distinct function.

df = (df
      .groupBy("src", "dst")
      .agg(
           F.array_distinct(F.flatten(F.collect_set("match_keys"))).alias("match_keys")
           )
)