I have a dataframe that looks like the one below:
accuracy
--------
91.0
92.0
73.0
72.0
88.0
I am using aggregate, count and collect to get the column sum and row count, which is taking too much time. Below is my code:
total_count = df.count()
total_sum=df.agg({'accuracy': 'sum'}).collect()
total_sum_val = [i[0] for i in total_sum]
acc_top_k = (total_sum_val[0]/total_count)*100
Is there any alternative method to get the mean accuracy in PySpark?
CodePudding user response:
First, you can aggregate the column values and calculate the average, then extract it into a variable. A single aggregation also avoids the separate pass over the data that the extra count() call triggers.
df = df.agg(F.avg('accuracy'))
acc_top_k = df.head()[0] * 100
Full test:
from pyspark.sql import functions as F
df = spark.createDataFrame([(91.0,), (92.0,), (73.0,), (72.0,), (88.0,)], ['accuracy'])
df = df.agg(F.avg('accuracy'))
acc_top_k = df.head()[0] * 100
print(acc_top_k)
# 8320.0
If you prefer, you can keep the dictionary syntax from your original code too:
df = df.agg({'accuracy': 'avg'})
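For completeness, a minimal sketch of that variant, assuming the same spark session and sample data as above ('mean' also works as an alias for 'avg' in the dictionary form):

df = spark.createDataFrame([(91.0,), (92.0,), (73.0,), (72.0,), (88.0,)], ['accuracy'])
# Dictionary-style aggregation returns a one-row DataFrame with a single
# column named avg(accuracy); head()[0] pulls that value out.
acc_top_k = df.agg({'accuracy': 'avg'}).head()[0] * 100
print(acc_top_k)
# 8320.0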