How to agg a pyspark dataframe and show the intersection of the lists in a column?


I have a dataframe like this one:

+-------+-----------+
| grupos|    valores|
+-------+-----------+
|grupo_1|  [1, 2, 3]|
|grupo_1|  [1, 2, 5]|
|grupo_1|  [1, 2, 6]|
|grupo_2|  [1, 2, 7]|
|grupo_2| [12, 2, 7]|
|grupo_2| [32, 2, 7]|
+-------+-----------+

I need to group by grupos and keep only the intersection of the lists in each group, like:

+-------+-----------+
| grupos|    valores|
+-------+-----------+
|grupo_1|     [1, 2]|
|grupo_2|     [2, 7]|
+-------+-----------+

Can anyone help me?

CodePudding user response:

Group by the grupos column and collect the valores arrays into an array of arrays. Then, using the aggregate higher-order function together with array_intersect, reduce the collected arrays to the elements they all share:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    ("grupo_1", [1, 2, 3]), ("grupo_1", [1, 2, 5]),
    ("grupo_1", [1, 2, 6]), ("grupo_2", [1, 2, 7]),
    ("grupo_2", [12, 2, 7]), ("grupo_2", [32, 2, 7])
], ["grupos", "valores"])

df1 = df.groupBy("grupos").agg(
    # collect the distinct valores arrays of each group into an array of arrays
    F.collect_set("valores").alias("valores")
).withColumn(
    "valores",
    # fold over the collected arrays: start from the first one and
    # intersect the accumulator with each subsequent array
    F.expr("aggregate(valores, valores[0], (acc, x) -> array_intersect(acc, x))")
)

df1.show()
#+-------+-------+
#|grupos |valores|
#+-------+-------+
#|grupo_1|[1, 2] |
#|grupo_2|[2, 7] |
#+-------+-------+
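
Note that collect_set deduplicates the collected arrays before the fold; collect_list would give the same result here, since intersecting an array with itself changes nothing.

On Spark 3.1+, the same reduction can also be written with the native PySpark aggregate function instead of a SQL expression string. A minimal sketch of that variant (df2 is just an illustrative name; it assumes the same df as above):

df2 = df.groupBy("grupos").agg(
    F.collect_set("valores").alias("valores")
).withColumn(
    "valores",
    # same fold as the expr version, expressed with the
    # higher-order aggregate function available since Spark 3.1
    F.aggregate(
        "valores",
        F.col("valores")[0],
        lambda acc, x: F.array_intersect(acc, x)
    )
)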