Sum of null and duplicate values across multiple columns in a pyspark data frame


I have a data frame like the one below in pyspark:

data = [
("James","CA",None), (None,"AC",None),
("Ram","AC",200.0), ("Ram",None,None)
]
df = spark.createDataFrame(data,["name","state","number"])
df.show()

+-----+-----+------+
| name|state|number|
+-----+-----+------+
|James|   CA|  null|
| null|   AC|  null|
|  Ram|   AC| 200.0|
|  Ram| null|  null|
+-----+-----+------+

Below is what I am trying to achieve

I want to count the number of nulls in each column and then capture the total count of nulls across all the columns as a variable.

I have done it like below:

from pyspark.sql.functions import col,isnan, when, count
df_null = df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])

df_null.show()

I get the below result

+-----+-----+------+
| name|state|number|
+-----+-----+------+
|    1|    1|     3|
+-----+-----+------+

What I want to do is capture 1 + 1 + 3 as a variable.

I have done it like below:

n_1 = df_null.collect()[0][0]
s_1 = df_null.collect()[0][1]
nu_1 = df_null.collect()[0][2]

null_count = n_1 + s_1 + nu_1

Also, I want to find the duplicates in each column and then capture the count of duplicates across all the columns as a variable.

I have done it like below:

list_1 = ['name']

df_1 = df.groupby(list_1).count().where('count > 1')

+-----+-----+
| name|count|
+-----+-----+
|  Ram|    2|
+-----+-----+

list_2 = ['state']

df_2 = df.groupby(list_2).count().where('count > 1')

+------+-----+
| state|count|
+------+-----+
|    AC|    2|
+------+-----+

list_df1 = df_1.collect()[0][1]
list_df2 = df_2.collect()[0][1]

dup_count = list_df1 + list_df2

I am able to achieve what I want, but I am trying to see if there is a better way to achieve it.

CodePudding user response:

You're calling collect() on df_null three times, which can be reduced to a single collect(). Every action on the same dataframe re-triggers its lineage.

from functools import reduce
from pyspark.sql import functions as func

# data_sdf is the question's df
nulls_per_col = data_sdf. \
    select(*[func.sum(func.col(k).isNull().cast('int')).alias(k) for k in data_sdf.columns]). \
    collect()

print(nulls_per_col)
# [Row(name=1, state=1, number=3)]

null_count = reduce(lambda x, y: x + y, [nulls_per_col[0][k] for k in data_sdf.columns])

print(null_count)
# 5
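
The same idea carries over to the duplicate counts. Below is a minimal sketch (my own addition, not from the original question) that keeps your semantics of summing the sizes of every duplicated group per column, assuming the column list ['name', 'state'] from your code. Python's built-in sum() also works in place of reduce for adding the totals.

from pyspark.sql import functions as func

# Sketch: for each column, total the number of rows whose value occurs more
# than once. coalesce() guards against a column with no duplicates at all,
# where sum() over zero rows would otherwise come back as null.
dup_per_col = [
    data_sdf.groupBy(c).count()
        .where('count > 1')
        .agg(func.coalesce(func.sum('count'), func.lit(0)).alias('dups'))
        .collect()[0]['dups']
    for c in ['name', 'state']
]

dup_count = sum(dup_per_col)

print(dup_count)
# 4  (2 for name + 2 for state on the sample data)

This also avoids the IndexError your collect()[0][1] would raise when a column has no duplicated values, and it counts every duplicated value in a column rather than only the first row returned by collect().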