I have the following DataFrame and need to compute the standard deviation of each array in the `salary` column.
dept_name | salary |
---|---|
Sales | [30, 36] |
Finance | [10, 98] |
Marketing | [20, 22] |
IT | [40, 90] |
Answer:
Option 1 - using a UDF

- Create a function that calculates the standard deviation of a Python list.
- Register that function as a PySpark SQL `udf`.
- Create a new `stdev_salary` column that applies the `udf` to the `salary` column using `withColumn`.
```python
# imports required for this solution
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf

# calculate the population standard deviation of a list
def stdev_list(salary_list):
    mean = sum(salary_list) / len(salary_list)
    variance = sum((x - mean) ** 2 for x in salary_list) / len(salary_list)
    return variance ** 0.5

# register the function as a PySpark SQL udf
stdev_udf = udf(stdev_list, FloatType())

# add a new column by applying the udf to the salary column
df = df.withColumn('stdev_salary', stdev_udf('salary'))
```
More about the PySpark SQL `udf` function here: https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.udf.html
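For reference, here is a minimal, self-contained sketch of how the example DataFrame might be built and the UDF applied; the `SparkSession` setup is an assumption, not part of the original question:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# recreate the example DataFrame from the question
df = spark.createDataFrame(
    [('Sales', [30, 36]), ('Finance', [10, 98]),
     ('Marketing', [20, 22]), ('IT', [40, 90])],
    ['dept_name', 'salary'],
)

df.withColumn('stdev_salary', stdev_udf('salary')).show()
# +---------+--------+------------+
# |dept_name|  salary|stdev_salary|
# +---------+--------+------------+
# |    Sales|[30, 36]|         3.0|
# |  Finance|[10, 98]|        44.0|
# |Marketing|[20, 22]|         1.0|
# |       IT|[40, 90]|        25.0|
# +---------+--------+------------+
```

Note that `stdev_list` divides by `len(salary_list)`, so these are population standard deviations.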
Option 2 - without a UDF

- First, `explode` the `salary` column so each salary item is represented on its own row:
```python
from pyspark.sql import functions as F

df_exploded = df.select('dept_name', 'salary', F.explode('salary').alias('salary_item'))
```
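After the explode, each array element sits on its own row. With the example data above, the intermediate result would look like this:

```python
df_exploded.show()
# +---------+--------+-----------+
# |dept_name|  salary|salary_item|
# +---------+--------+-----------+
# |    Sales|[30, 36]|         30|
# |    Sales|[30, 36]|         36|
# |  Finance|[10, 98]|         10|
# |  Finance|[10, 98]|         98|
# |Marketing|[20, 22]|         20|
# |Marketing|[20, 22]|         22|
# |       IT|[40, 90]|         40|
# |       IT|[40, 90]|         90|
# +---------+--------+-----------+
```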
- Then, calculate the standard deviation over the `salary_item` column while grouping by `dept_name` and `salary`:
```python
df_final = df_exploded.groupBy('dept_name', 'salary').agg(F.stddev('salary_item').alias('stddev_salary'))
```
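One caveat when comparing the two options: `F.stddev` is an alias for `stddev_samp`, the sample standard deviation (dividing by n - 1), whereas the UDF in Option 1 computes the population standard deviation (dividing by n). To make Option 2 match Option 1, swap in `F.stddev_pop`:

```python
df_final = df_exploded.groupBy('dept_name', 'salary').agg(
    F.stddev_pop('salary_item').alias('stddev_salary')
)
```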