I'm trying to calculate the standard deviation of a column in a DataFrame, but when I do I get the failure message below:
[info] - should return the standard deviation for all columns in a DataFrame *** FAILED *** (51 milliseconds)
[info] org.apache.spark.sql.AnalysisException: cannot resolve '`value_6`' given input columns: [stddev_samp(value_6)];
[info] 'Project ['value_6]
[info] +- Aggregate [stddev_samp(value_6#131) AS stddev_samp(value_6)#151]
[info]    +- Project [coalesce(nanvl(value_6#60, cast(null as double)), cast(0 as double)) AS value_6#131]
[info]       +- Project [value_6#60]
[info]          +- Project [_1#39 AS id#54, _2#40 AS value_1#55, _3#41 AS value_2#56, _4#42 AS value_3#57, _5#43 AS value_4#58, _6#44 AS value_5#59, _7#45 AS value_6#60]
[info]             +- LocalRelation [_1#39, _2#40, _3#41, _4#42, _5#43, _6#44, _7#45]
[info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
[info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:155)
[info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:152)
[info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:341)
[info] at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:73)
[info] at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:341)
[info] at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsUp$1(QueryPlan.scala:104)
[info] at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:116)
[info] at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:73)
[info] at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:116)
Here is what I have:
def standardDeviationForColumn(df: DataFrame, columnName: String): DataFrame =
  df.select(columnName).na.fill(0).agg(stddev(columnName))
Here is how I call it:
assert(DataFrameUtils.standardDeviationForColumn(randomNumericTestData, "value_6").select("value_6").first().getDouble(0) === 1, "TODO")
What is it that I'm doing wrong here? Here is my DataFrame:
val randomNumericTestData: DataFrame = Seq(
  (1, 1, 10.0, 10.0, 10.0, 10.0, 10.0),
  (2, 0, 12.0, 12.0, 12.0, 12.0, 12.0),
  (3, 1, 13.0, 13.0, 13.0, 13.0, 13.0),
  (4, 1, 14.0, 14.0, 14.0, 14.0, 14.0),
  (5, 0, 12.5, 12.5, 12.5, 12.5, 12.5),
  (6, 1, 11.5, 11.5, 11.5, 11.5, 11.5),
  (7, 0, 17.5, 17.5, 17.5, 17.5, 17.5),
  (8, 0, 13.6, 13.6, 13.6, 13.6, 13.6),
  (9, 1, 14.2, 14.2, 14.2, 14.2, 14.2)
).toDF("id", "value_1", "value_2", "value_3", "value_4", "value_5", "value_6")
CodePudding user response:
The clue is in the error message: org.apache.spark.sql.AnalysisException: cannot resolve '`value_6`' given input columns: [stddev_samp(value_6)]. Spark can't find a column named value_6 in the aggregated result: when you call .agg(stddev(columnName)), the output DataFrame contains a single column named stddev_samp(columnName), not columnName, so the subsequent select("value_6") has nothing to resolve.
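You can see the name Spark generates by printing the schema of the aggregated result. A minimal sketch against the randomNumericTestData from the question (the stddev_samp(...) name matches the Spark version that produced the trace above; newer versions may render it differently):

import org.apache.spark.sql.functions.stddev

// Inspect the column name Spark assigns to the aggregate
randomNumericTestData
  .select("value_6")
  .na.fill(0)
  .agg(stddev("value_6"))
  .printSchema()
// root
//  |-- stddev_samp(value_6): double (nullable = true)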
You need to alias the aggregate column back to the original name:
import org.apache.spark.sql.functions.stddev

def standardDeviationForColumn(df: DataFrame, columnName: String): DataFrame =
  // Alias the aggregate so the result keeps the original column name
  df.select(columnName).na.fill(0).agg(stddev(columnName) as columnName)
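With the alias in place, select("value_6") on the result resolves again. A minimal usage sketch, reusing DataFrameUtils and randomNumericTestData from the question:

// Retrieve the aliased aggregate by its original column name
val sd: Double = DataFrameUtils
  .standardDeviationForColumn(randomNumericTestData, "value_6")
  .select("value_6")
  .first()
  .getDouble(0)

Note that the sample standard deviation of value_6 in the test data is roughly 2.1, so the assertion's expected value of 1 would still need adjusting for the test to pass.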