Mean of column in PySpark without using collect


I have a dataframe that looks like below

accuracy
--------
91.0
92.0
73.0
72.0
88.0

I am using aggregate, count and collect to get the column sum, which is taking too much time. Below is my code:

total_count = df.count()
total_sum = df.agg({'accuracy': 'sum'}).collect()
total_sum_val = [i[0] for i in total_sum]
acc_top_k = (total_sum_val[0] / total_count) * 100

Is there any alternative method to get the mean accuracy in PySpark?

CodePudding user response:

First, aggregate the column values and calculate the average, then extract the result into a variable. Because the aggregation yields a single-row DataFrame, head() only brings that one row back to the driver rather than collecting the whole DataFrame.

df = df.agg(F.avg('accuracy'))
acc_top_k = df.head()[0] * 100

Full test:

from pyspark.sql import functions as F
df = spark.createDataFrame([(91.0,), (92.0,), (73.0,), (72.0,), (88.0,)], ['accuracy'])

df = df.agg(F.avg('accuracy'))
acc_top_k = df.head()[0] * 100

print(acc_top_k)
# 8320.0

If you prefer, you can keep the dictionary-style aggregation from your original code:
df = df.agg({'accuracy': 'avg'})
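
For completeness, here is a minimal sketch of that variant carried through to the final value, using the same sample data as above and keeping the * 100 scaling from the question; it assumes a SparkSession obtained via SparkSession.builder.getOrCreate():

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(91.0,), (92.0,), (73.0,), (72.0,), (88.0,)], ['accuracy'])

# 'avg' via the dictionary syntax returns a single-row DataFrame
avg_df = df.agg({'accuracy': 'avg'})

# head() fetches only that one aggregated row, never the full DataFrame
acc_top_k = avg_df.head()[0] * 100
print(acc_top_k)  # 8320.0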
