Count unique values for every row in PySpark


I have PySpark DataFrame:

from pyspark.sql.types import *

schema = StructType([
  StructField("col1", StringType()),
  StructField("col2", StringType()),
  StructField("col3", StringType()),
  StructField("col4", StringType()),
])

data = [("aaa", "aab", "baa", "aba"),
        ("aab", "aab", "abc", "daa"), 
        ("aa", "bb", "cc", "dd"),
        ("1", "bbb", "2", "2")]  # as strings, to match the StringType schema

df = spark.createDataFrame(data=data, schema=schema)

I need to calculate the count of unique values in each row. I understand that it should be something like this:

from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

@udf(ArrayType(df.schema))
def substract_unique(row):
    return len(set(row))

df = df.withColumn("test", substract_unique(row))

But I can't understand how to pass the whole row into the UDF. All the examples I've seen deal either with a single column or a subset of columns, or use lambda functions to return min, mean and max values.

It would be perfect if you could give an example or advice using pandas_udf or udf.

CodePudding user response:

It was simple...

@udf("int")  # declare the return type; a bare @udf() would return the count as a string
def substract_unique(*values):
    # all column values arrive as positional arguments; a set drops duplicates
    return len(set(values))

df = df.withColumn("unique", substract_unique(*df.columns))
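
Because the UDF body is ordinary Python, the counting logic can be sanity-checked without Spark. A minimal sketch (count_unique is a hypothetical stand-in for the substract_unique UDF above, minus the decorator):

```python
# Each row's column values arrive as positional arguments;
# building a set collapses duplicates, and len() counts what remains.
def count_unique(*values):
    return len(set(values))

print(count_unique("aab", "aab", "abc", "daa"))  # 3: "aab" appears twice
print(count_unique("aa", "bb", "cc", "dd"))      # 4: all distinct
```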

CodePudding user response:

Don't go for a udf. It is slow when working with big data. Use native Spark functions as much as possible; if that's not possible, try to create a pandas_udf.

  • Native Spark approach:

    from pyspark.sql import functions as F
    
    df = df.withColumn("unique", F.size(F.array_distinct(F.array(df.columns))))
    
    df.show()
    # +----+----+----+----+------+
    # |col1|col2|col3|col4|unique|
    # +----+----+----+----+------+
    # | aaa| aab| baa| aba|     4|
    # | aab| aab| abc| daa|     3|
    # |  aa|  bb|  cc|  dd|     4|
    # |   1| bbb|   2|   2|     3|
    # +----+----+----+----+------+
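
One caveat with this approach, assuming Spark's usual NULL semantics: array_distinct keeps NULL as one distinct element, so a row containing nulls counts the null toward the total. On Spark 3.1+ the nulls can be dropped first with F.filter(F.array(*df.columns), lambda c: c.isNotNull()). The plain-Python sketch below only emulates the two behaviours; the helper names are made up for illustration:

```python
# Emulates F.size(F.array_distinct(F.array(...))): None stands in for NULL
# and is counted as one distinct value.
def unique_with_nulls(values):
    return len(set(values))

# Emulates dropping nulls first, e.g. via F.filter(..., lambda c: c.isNotNull()).
def unique_without_nulls(values):
    return len({v for v in values if v is not None})

print(unique_with_nulls(["a", None, "a"]))     # 2 (None counted once)
print(unique_without_nulls(["a", None, "a"]))  # 1
```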
    
  • pandas_udf approach:

    import pandas as pd
    from pyspark.sql import functions as F
    
    @F.pandas_udf('long')
    def count_unique(d: pd.DataFrame) -> pd.Series:
        return d.nunique(axis=1)
    
    df = df.withColumn("unique", count_unique(F.struct(*df.columns)))
    
    df.show()
    # +----+----+----+----+------+
    # |col1|col2|col3|col4|unique|
    # +----+----+----+----+------+
    # | aaa| aab| baa| aba|     4|
    # | aab| aab| abc| daa|     3|
    # |  aa|  bb|  cc|  dd|     4|
    # |   1| bbb|   2|   2|     3|
    # +----+----+----+----+------+
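
The per-row counting here is done entirely by pandas, so it can be previewed without Spark. Note that DataFrame.nunique skips NaN/None by default (dropna=True). A standalone sketch using two of the question's rows:

```python
import pandas as pd

# The same kind of rows the question builds, as a plain pandas DataFrame.
d = pd.DataFrame({"col1": ["aaa", "aab"], "col2": ["aab", "aab"],
                  "col3": ["baa", "abc"], "col4": ["aba", "daa"]})

# nunique(axis=1) counts distinct values across each row.
print(d.nunique(axis=1).tolist())  # [4, 3]
```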
    