I have a dataframe like this one:
------- -----------
| grupos| valores|
------- -----------
|grupo_1| [1, 2, 3]|
|grupo_1| [1, 2, 5]|
|grupo_1| [1, 2, 6]|
|grupo_2| [1, 2, 7]|
|grupo_2| [12, 2, 7]|
|grupo_2| [32, 2, 7]|
------- -----------
I need something to groupby and show only the intersection of the lists, like:
------- -----------
| grupos| valores|
------- -----------
|grupo_1| [1, 2]|
|grupo_2| [2, 7]|
------- -----------
Anyone can help me?
CodePudding user response:
Group by grupos
column and collect list of valores
. Then using aggregate
with array_intersect
functions, you find the intersection of all sub arrays:
from pyspark.sql import functions as F
df = spark.createDataFrame([
("grupo_1", [1, 2, 3]), ("grupo_1", [1, 2, 5]),
("grupo_1", [1, 2, 6]), ("grupo_2", [1, 2, 7]),
("grupo_2", [12, 2, 7]), ("grupo_2", [32, 2, 7])
], ["grupos", "valores"])
df1 = df.groupBy("grupos").agg(
F.collect_set("valores").alias("valores")
).withColumn(
"valores",
F.expr("aggregate(valores, valores[0], (acc, x) -> array_intersect(acc, x))")
)
df1.show()
# ------- -------
#|grupos |valores|
# ------- -------
#|grupo_1|[1, 2] |
#|grupo_2|[2, 7] |
# ------- -------