Expand expression in Spark Scala aggregation

Time:06-30

I'm trying to convert a simple aggregation code from PySpark to Scala.

The dataframes:

# PySpark
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [([10, 100],),
     ([20, 200],)],
    ['vals'])
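// Scala
import org.apache.spark.sql.functions._  // sum, col, ...
import spark.implicits._                  // toDF and $"..." (auto-imported in spark-shell)
val df = Seq(
    Seq(10, 100),
    Seq(20, 200)
).toDF("vals")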

Aggregation expansion - OK in PySpark:

df2 = df.agg(
    *[F.sum(F.col("vals")[i]).alias(f"col{i}") for i in range(2)]
)
df2.show()
# +----+----+
# |col0|col1|
# +----+----+
# |  30| 300|
# +----+----+

But in Scala...

val df2 = df.agg(
  (0 until 2).map(i => sum($"vals"(i)).alias(s"col$i")): _*
)
         (0 until 2).map(i => sum($"vals"(i)).alias(s"col$i")): _*
                                                              ^
On line 2: error: no `: _*` annotation allowed here
       (such annotations are only allowed in arguments to *-parameters)

The syntax looks almost identical to this select, which works fine:

val df2 = df.select(
  (0 until 2).map(i => $"vals"(i).alias(s"col$i")): _*
)

Does expression expansion work in Scala Spark aggregations? How?

Answer:

If you look at the documentation of Dataset.agg, you'll see that it takes one fixed parameter followed by a variable-length (repeated) parameter:

def agg(expr: Column, exprs: Column*): DataFrame 

So you must pass one aggregation as the first argument; only the second (repeated) argument can take the list expansion. Something like:

val df2 = df.agg(
  first($"vals"), (0 until 2).map(i => sum($"vals"(i)).alias(s"col$i")): _*
)

Any other single aggregation in front of the list should work too, although it also shows up as an extra column in the result. I don't know why the signature is like this; maybe it's a deliberate restriction so you can't pass an empty list and have no aggregation at all?
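The difference between the select and agg calls can be reproduced in plain Scala, without Spark. In the minimal sketch below, selectLike and aggLike are hypothetical stand-ins that only copy the parameter shapes of select(cols: Column*) and agg(expr: Column, exprs: Column*), using String instead of Column. Scala accepts the : _* annotation only on the argument that fills a repeated (*) parameter, so a whole sequence can go to selectLike but not to aggLike; splitting the sequence into head and tail satisfies both of aggLike's parameters:

```scala
// Hypothetical stand-ins that only mimic the parameter shapes of Spark's
// Dataset.select(cols: Column*) and Dataset.agg(expr: Column, exprs: Column*).
object VarargsDemo {
  // Every argument goes to the repeated parameter, like select(cols: Column*)
  def selectLike(cols: String*): Seq[String] = cols.toList

  // One fixed parameter followed by a repeated one, like agg(expr: Column, exprs: Column*)
  def aggLike(expr: String, exprs: String*): Seq[String] = expr :: exprs.toList

  def main(args: Array[String]): Unit = {
    val cols: Seq[String] = (0 until 3).map(i => s"col$i")

    // OK: the whole sequence is expanded into the repeated parameter
    println(selectLike(cols: _*))

    // Does not compile: `expr` is not a repeated parameter, so `: _*` is rejected here
    // aggLike(cols: _*)

    // OK: the first element fills `expr`, the rest is expanded into `exprs`
    println(aggLike(cols.head, cols.tail: _*))
  }
}
```

Both println calls print List(col0, col1, col2). The same head/tail split works with the real agg, with the caveat that head throws on an empty sequence.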

Answer:

I don't fully understand why the compiler behaves this way, but it seems it won't unpack your Seq[Column] into the varargs parameter.

As @RvdV mentioned in his post, the signature of the method is def agg(expr: Column, exprs: Column*): DataFrame

So a temporary workaround is to unpack it manually, like:

val seq = Seq(0, 1).map(i => sum($"vals"(i)).alias(s"col$i"))
val df2 = df.agg(seq(0), seq(1))
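The manual unpacking above is tied to exactly two columns. A minimal sketch that generalizes it to any number of positions, assuming Spark 2.x or 3.x on the classpath (the names AggExpansion and sumPositions are hypothetical), hands the first Column to agg's fixed expr parameter and expands the rest into exprs, so no extra aggregation column is needed:

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, sum}

object AggExpansion {
  // Sums each of the first n array positions of the "vals" column.
  // Assumes n >= 1: aggs.head throws on an empty sequence.
  def sumPositions(df: DataFrame, n: Int): DataFrame = {
    val aggs: Seq[Column] = (0 until n).map(i => sum(col("vals")(i)).alias(s"col$i"))
    // The first Column fills agg's fixed `expr` parameter, the rest fill `exprs: Column*`
    df.agg(aggs.head, aggs.tail: _*)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("agg-expansion").getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(10, 100), Seq(20, 200)).toDF("vals")
    sumPositions(df, 2).show()

    spark.stop()
  }
}
```

Run locally, this should print the same col0/col1 table as the PySpark version in the question.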