Similar to this question (Scala), but I need combinations in PySpark (pair combinations of an array column).
Example input:
df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])
Expected output:
+------------+------------------------------------------------+
|array_col   |out                                             |
+------------+------------------------------------------------+
|[0, 1]      |[[0, 1]]                                        |
|[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
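For reference, the per-row behaviour I want matches Python's itertools.combinations with r=2:

from itertools import combinations

# pair combinations of a single row's array, in plain Python
print(list(combinations([2, 3, 4], 2)))
# [(2, 3), (2, 4), (3, 4)]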
CodePudding user response:
pandas_udf is an efficient and concise approach in PySpark.
from pyspark.sql import functions as F
import pandas as pd
from itertools import combinations

# generate all 2-element combinations of each row's array
@F.pandas_udf('array<array<int>>')
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))

df = df.withColumn('out', pudf('array_col'))
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
Note: in some environments, instead of the DDL string 'array<array<int>>' you may need to provide the return type from pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())).
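For reference, a minimal sketch of the same UDF with an explicit DataType instead of the DDL string (assuming the array elements fit in IntegerType):

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType
import pandas as pd
from itertools import combinations

# same logic as above, with the return type built from pyspark.sql.types
@F.pandas_udf(ArrayType(ArrayType(IntegerType())))
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))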
CodePudding user response:
Native Spark approach. I've translated this answer to PySpark. It needs Spark 3.1+ for F.transform and F.filter, and Python 3.8+ for the walrus operator (:=), which binds "array_col" once because the name is repeated several times in this script:
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            # pair every element with the whole array -> all ordered pairs
            F.flatten(F.transform(
                c := "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
            )),
            # arrays_zip yields structs; turn each struct into a 2-element array
            lambda x: F.array(x["0"], x[c])
        ),
        # keep each unordered pair exactly once
        lambda x: x[0] < x[1]
    )
)
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
Alternative without walrus operator:
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
            )),
            lambda x: F.array(x["0"], x["array_col"])
        ),
        lambda x: x[0] < x[1]
    )
)
Alternative for Spark 2.4, using the same logic as a SQL expression inside F.expr:
from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.expr("""
        filter(
            transform(
                flatten(transform(
                    array_col,
                    x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
                )),
                x -> array(x["0"], x["array_col"])
            ),
            x -> x[0] < x[1]
        )
    """)
)
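As a quick check (assuming df was built from the example input above), showing the result prints the same table as the other approaches:

df.show(truncate=0)  # same output table as shown above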