SparkSQL - How to make scalar subquery work without FIRST/MIN/MAX/AVG

Time:02-21

As explained in What does "Correlated scalar subqueries must be Aggregated" mean?, SparkSQL rejects a correlated scalar subquery unless it is aggregated.

So when Catalyst cannot be 100% sure, just by looking at the SQL statement (without looking at your data), that the subquery returns only a single row, this exception is thrown. If you are sure that your subquery returns only a single row, you can wrap the column in one of the following standard aggregate functions so that the Spark analyzer accepts the query:

  • first
  • avg
  • max
  • min

How can I make the SQL query below work?

SELECT
    prerequisite AS prerequisite,
    (SELECT e.description FROM course e WHERE e.course_no = c.prerequisite) as course_name,
    COUNT(*) as cnt
FROM
    course c
WHERE 
    c.prerequisite IS NOT NULL
GROUP BY 
    c.prerequisite
ORDER BY
    prerequisite;

It works in Oracle and returns this result:

PREREQUISITE COURSE_NAME                                               CNT
------------ -------------------------------------------------- ----------
          10 Technology Concepts                                         1
          20 Intro to Information Systems                                5
          25 Intro to Programming                                        2
          80 Programming Techniques                                      2
         120 Intro to Java Programming                                   1
         122 Intermediate Java Programming                               2
         125 Java Developer I                                            1
         130 Intro to Unix                                               2
         132 Basics of Unix Admin                                        1
         134 Advanced Unix Admin                                         1
         140 Systems Analysis                                            1
         204 Intro to SQL                                                1
         220 PL/SQL Programming                                          1
         310 Operating Systems                                           2
         350 Java Developer II                                           2
         420 Database System Principles                                  1

However, it fails in SparkSQL with the well-known error:

AnalysisException: Correlated scalar subqueries must be aggregated

If I add FIRST or MAX, it throws a different exception, Couldn't find first(description), which seems to be because the scalar subquery returns only one row and therefore cannot find the first one in it.

SELECT
    prerequisite AS prerequisite,
    (SELECT FIRST(e.description) FROM course e WHERE e.course_no = c.prerequisite) as course_name,
    COUNT(*) as cnt
FROM
    course c
WHERE 
    c.prerequisite IS NOT NULL
GROUP BY 
    c.prerequisite
ORDER BY
    prerequisite
----------
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
Input In [214], in <module>
      1 query="""
      2 SELECT
      3     prerequisite AS prerequisite,
   (...)
     13     prerequisite
     14 """
---> 15 spark.sql(query).show(truncate=False)

File /opt/spark/spark-3.1.2/python/lib/pyspark.zip/pyspark/sql/dataframe.py:486, in DataFrame.show(self, n, truncate, vertical)
    484     print(self._jdf.showString(n, 20, vertical))
    485 else:
--> 486     print(self._jdf.showString(n, int(truncate), vertical))

File /opt/spark/spark-3.1.2/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py:1304, in JavaMember.__call__(self, *args)
   1298 command = proto.CALL_COMMAND_NAME  \
   1299     self.command_header  \
   1300     args_command  \
   1301     proto.END_COMMAND_PART
   1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
   1305     answer, self.gateway_client, self.target_id, self.name)
   1307 for temp_arg in temp_args:
   1308     temp_arg._detach()

File /opt/spark/spark-3.1.2/python/lib/pyspark.zip/pyspark/sql/utils.py:111, in capture_sql_exception.<locals>.deco(*a, **kw)
    109 def deco(*a, **kw):
    110     try:
--> 111         return f(*a, **kw)
    112     except py4j.protocol.Py4JJavaError as e:
    113         converted = convert_exception(e.java_exception)

File /opt/spark/spark-3.1.2/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
    324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325 if answer[1] == REFERENCE_TYPE:
--> 326     raise Py4JJavaError(
    327         "An error occurred while calling {0}{1}{2}.\n".
    328         format(target_id, ".", name), value)
    329 else:
    330     raise Py4JError(
    331         "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
    332         format(target_id, ".", name, value))

Py4JJavaError: An error occurred while calling o1058.showString.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: first(description)#13046
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:75)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:318)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:318)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$3(TreeNode.scala:323)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:408)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:244)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:406)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:359)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:323)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:307)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:74)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:96)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike.map(TraversableLike.scala:238)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
    at scala.collection.immutable.List.map(List.scala:298)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:96)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.generateResultFunction(HashAggregateExec.scala:554)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduceWithKeys(HashAggregateExec.scala:741)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduce(HashAggregateExec.scala:148)
    at org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:95)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
    at org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:90)
    at org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:90)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.produce(HashAggregateExec.scala:47)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:655)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:718)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
    at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:187)
    at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3696)
    at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2722)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3687)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3685)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2722)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2929)
    at org.apache.spark.sql.Dataset.getRows(Dataset.scala:301)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:338)
    at sun.reflect.GeneratedMethodAccessor62.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.RuntimeException: Couldn't find first(description)#13046 in [prerequisite#5422,count(1)#13045L]
    at scala.sys.package$.error(package.scala:30)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.$anonfun$applyOrElse$1(BoundAttribute.scala:81)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    ... 61 more

Answer:

Have you tried writing it like this?

SELECT
    prerequisite AS prerequisite,
    FIRST((SELECT e.description FROM course e WHERE e.course_no = c.prerequisite)) as course_name,
    COUNT(*) as cnt
FROM
    course c
WHERE 
    c.prerequisite IS NOT NULL
GROUP BY 
    c.prerequisite
ORDER BY
    prerequisite

I think that should work.
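A related variation on moving the aggregate around is to count per prerequisite in a derived table first, and only then look up the name with a scalar subquery that wraps the description in MAX, as the rule quoted at the top of the question asks for. The sketch below checks the logic of that shape on a tiny in-memory table. Assumptions: Python's stdlib sqlite3 stands in for Spark (no cluster needed), the rows are made up, and the names ROWS, DERIVED_QUERY and run are hypothetical. Whether SparkSQL 3.1.2 accepts this exact query has not been verified here.

```python
import sqlite3

# Made-up rows for illustration only: (course_no, description, prerequisite).
ROWS = [
    (10, "Technology Concepts", None),
    (20, "Intro to Information Systems", None),
    (100, "Hypothetical course A", 20),
    (110, "Hypothetical course B", 20),
    (120, "Hypothetical course C", 10),
]

# Count per prerequisite first, then look up the name with an aggregated
# scalar subquery that is no longer part of a GROUP BY query.
DERIVED_QUERY = """
SELECT
    t.prerequisite,
    (SELECT MAX(e.description) FROM course e WHERE e.course_no = t.prerequisite) AS course_name,
    t.cnt
FROM (
    SELECT prerequisite, COUNT(*) AS cnt
    FROM course
    WHERE prerequisite IS NOT NULL
    GROUP BY prerequisite
) t
ORDER BY t.prerequisite
"""


def run(query):
    """Load ROWS into an in-memory SQLite `course` table and return the query result."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(
            "CREATE TABLE course (course_no INTEGER PRIMARY KEY, "
            "description TEXT, prerequisite INTEGER)"
        )
        conn.executemany("INSERT INTO course VALUES (?, ?, ?)", ROWS)
        return conn.execute(query).fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    for row in run(DERIVED_QUERY):
        print(row)
```

On the made-up rows this prints one line per prerequisite: 10 with "Technology Concepts" and a count of 1, then 20 with "Intro to Information Systems" and a count of 2.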

Answer:

I think you can simply use a join instead of a correlated subquery:

SELECT
  c.prerequisite,
  FIRST(e.description) AS course_name,
  COUNT(*) AS cnt
FROM
  course c
LEFT JOIN
  course e
ON
  e.course_no = c.prerequisite
WHERE
  c.prerequisite IS NOT NULL
GROUP BY
  c.prerequisite
ORDER BY
  1
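To sanity-check the join rewrite, the sketch below runs the same query shape against a tiny in-memory table. Assumptions: Python's stdlib sqlite3 stands in for Spark, SQLite has no FIRST aggregate so MAX is swapped in, the rows are made up, and the names ROWS, JOIN_QUERY and run are hypothetical. The key point is that course_no is unique, so each course row joins at most one description row and COUNT(*) is not inflated by the join.

```python
import sqlite3

# Made-up rows for illustration only: (course_no, description, prerequisite).
ROWS = [
    (10, "Technology Concepts", None),
    (20, "Intro to Information Systems", None),
    (100, "Hypothetical course A", 20),
    (110, "Hypothetical course B", 20),
    (120, "Hypothetical course C", 10),
]

# Same shape as the join answer; MAX replaces FIRST because SQLite lacks FIRST.
JOIN_QUERY = """
SELECT
  c.prerequisite,
  MAX(e.description) AS course_name,
  COUNT(*) AS cnt
FROM
  course c
LEFT JOIN
  course e
ON
  e.course_no = c.prerequisite
WHERE
  c.prerequisite IS NOT NULL
GROUP BY
  c.prerequisite
ORDER BY
  1
"""


def run(query):
    """Load ROWS into an in-memory SQLite `course` table and return the query result."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(
            "CREATE TABLE course (course_no INTEGER PRIMARY KEY, "
            "description TEXT, prerequisite INTEGER)"
        )
        conn.executemany("INSERT INTO course VALUES (?, ?, ?)", ROWS)
        return conn.execute(query).fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    for row in run(JOIN_QUERY):
        print(row)
```

Within each group every joined row carries the same description, so MAX and FIRST pick the same value here; that is why swapping one for the other does not change the result.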