I am trying to average the elements of a column of arrays by index within a group by. I start with a dataframe like this:
-----------------------
id | weights
-----------------------
1 | [ 34, 23, 56 ]
1 | [ 5, 45, 10 ]
1 | [ 38, 30, 50 ]
2 | [ 45, 5, 20 ]
2 | [ 40, 11, 23 ]
Then group by "id" and end up with an array holding the average of the weights at each index:
-----------------------
id | weights
-----------------------
1 | [ 25.667, 32.667, 38.667 ]
2 | [ 42.5, 8, 21.5 ]
I know that I could do this by splitting "weights" into separate columns and then doing the groupby. The only issue is that the above is just an example, and the array in my real weights column has 300 elements. I've found that I can split out a few elements with this code:
sample_output.select($"weights".getItem(0).as("First"),
$"weights".getItem(1).as("Second"),
$"weights".getItem(2).as("Third"))
.show()
However, I don't want to have to write that out for 300 elements. I next tried creating a list of strings for all of my columns and selecting that from my df:
val dimNums = (0 until typedConfig.embeddingDims).toList
val all_columns = dimNums.map(x => "$\"weights\".getItem(%d).as(\"dim%d\")".format(x,x))
sample_output.select(all_columns.head, all_columns.tail: _*)
.show(false)
This gave me the error:
org.apache.spark.sql.AnalysisException: cannot resolve '
$"weights".getItem(0).as("dim0")
' given input columns: [id, WordToken, weights];;
I also tried:
sample_output.select(all_columns.map(col): _*).show(false)
But I got the exact same error. Does anyone know a way to split up a huge array like this into several columns? Or else is there another way that I can average these lists by index while using groupby?
CodePudding user response:
Spark doesn't have array aggregate functions that you can use with the agg method of relationally grouped datasets, so you either need to write your own user-defined aggregator or get creative with the available array functions.
Provided you do not have too many entries with the same id, you can first aggregate the groups by collecting their weights arrays into one big nested array:
val grouped = df.groupBy("id").agg(collect_list($"weights") as "collected")
grouped.show(false)
// +---+-----------------------------------------+
// |id |collected                                |
// +---+-----------------------------------------+
// |1  |[[34, 23, 56], [5, 45, 10], [38, 30, 50]]|
// |2  |[[45, 5, 20], [40, 11, 23]]              |
// +---+-----------------------------------------+
You can now sum the arrays inside collected element-wise by using aggregate, the badly named equivalent of foldLeft for arrays in Spark. It takes an array column, an initial value column, and a merge function. The initial value has to be an array of 0s with the same length as the sub-arrays. In order not to hardcode anything, you can use size on the first element of the gathered array:
val initialCol = array_repeat(lit(0), size(element_at($"collected", 1)))
Summing up the arrays boils down to applying + element-wise using zip_with:
val summed = grouped.withColumn(
  "sums",
  aggregate($"collected", initialCol, (acc, x) => zip_with(acc, x, _ + _))
)
summed.show(false)
// +---+-----------------------------------------+-------------+
// |id |collected                                |sums         |
// +---+-----------------------------------------+-------------+
// |1  |[[34, 23, 56], [5, 45, 10], [38, 30, 50]]|[77, 98, 116]|
// |2  |[[45, 5, 20], [40, 11, 23]]              |[85, 16, 43] |
// +---+-----------------------------------------+-------------+
All you have to do now is divide each element in sums by the size of collected using transform, the equivalent of map for arrays:
val avged = summed.select(
  $"id",
  transform($"sums", _ / size($"collected")) as "weights"
)
avged.show(false)
// +---+------------------------------------------------------------+
// |id |weights                                                     |
// +---+------------------------------------------------------------+
// |1  |[25.666666666666668, 32.666666666666664, 38.666666666666664]|
// |2  |[42.5, 8.0, 21.5]                                           |
// +---+------------------------------------------------------------+
Or, if you prefer, an optimised single-expression version of the above logic which uses aggregate functions:
val avged = df
  .groupBy("id")
  .agg(
    transform(
      aggregate(
        collect_list($"weights"),
        array_repeat(lit(0), size(first($"weights"))),
        (acc, x) => zip_with(acc, x, _ + _)
      ),
      _ / count($"weights")
    ) as "weights"
  )
avged.show(false)
// +---+------------------------------------------------------------+
// |id |weights                                                     |
// +---+------------------------------------------------------------+
// |1  |[25.666666666666668, 32.666666666666664, 38.666666666666664]|
// |2  |[42.5, 8.0, 21.5]                                           |
// +---+------------------------------------------------------------+
All array functions are in the org.apache.spark.sql.functions object and most of them are only available in Spark 3.x. If you are using Spark 2.x, you are out of luck and need to write your own aggregator.
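To give an idea of what such an aggregator could look like, below is a rough sketch of a typed Aggregator that keeps element-wise running sums plus a row count in its buffer and divides at the end. WeightRow and ElementWiseAvg are illustrative names of my own, and the encoder choices are one possible option (ExpressionEncoder is an internal but commonly used helper); treat it as a starting point rather than a drop-in solution.
// Sketch only: sums weights element-wise per group, then divides by the row count.
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import spark.implicits._

case class WeightRow(id: Int, weights: Seq[Int])

object ElementWiseAvg extends Aggregator[WeightRow, (Seq[Long], Long), Seq[Double]] {
  // Buffer: (running element-wise sums, number of rows seen so far)
  def zero: (Seq[Long], Long) = (Seq.empty, 0L)

  def reduce(b: (Seq[Long], Long), row: WeightRow): (Seq[Long], Long) = {
    val sums =
      if (b._1.isEmpty) row.weights.map(_.toLong)
      else b._1.zip(row.weights).map { case (s, w) => s + w }
    (sums, b._2 + 1)
  }

  def merge(b1: (Seq[Long], Long), b2: (Seq[Long], Long)): (Seq[Long], Long) = {
    val sums =
      if (b1._1.isEmpty) b2._1
      else if (b2._1.isEmpty) b1._1
      else b1._1.zip(b2._1).map { case (x, y) => x + y }
    (sums, b1._2 + b2._2)
  }

  def finish(b: (Seq[Long], Long)): Seq[Double] = b._1.map(_.toDouble / b._2)

  def bufferEncoder: Encoder[(Seq[Long], Long)] = Encoders.product[(Seq[Long], Long)]
  def outputEncoder: Encoder[Seq[Double]] = ExpressionEncoder[Seq[Double]]()
}

val avged = df.as[WeightRow]
  .groupByKey(_.id)
  .agg(ElementWiseAvg.toColumn.name("weights"))
  .toDF("id", "weights")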
CodePudding user response:
import org.apache.spark.sql.functions.col
import spark.implicits._
val source = Seq(
  (1, List(34, 23, 56)),
  (1, List(5, 45, 10)),
  (1, List(38, 30, 50)),
  (2, List(45, 5, 20)),
  (2, List(40, 11, 23))
).toDF("id", "weights")
source.printSchema
// root
// |-- id: integer (nullable = false)
// |-- weights: array (nullable = true)
// | |-- element: integer (containsNull = false)
def getColAtIndex(id: Int) = col(s"weights")(id).as(s"weights_${id + 1}")
val columns = (0 to 2).map(getColAtIndex) :+ col("id") //Here, instead of 2, use n - 1 (299 for your 300 elements)
source.select(columns: _*).show(false)
// +---------+---------+---------+---+
// |weights_1|weights_2|weights_3|id |
// +---------+---------+---------+---+
// |34       |23       |56       |1  |
// |5        |45       |10       |1  |
// |38       |30       |50       |1  |
// |45       |5        |20       |2  |
// |40       |11       |23       |2  |
// +---------+---------+---------+---+
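If you still need the per-index averages as a single array afterwards, one possible continuation of this approach (not part of the original answer; n, avgCols and the reassembly step are my own sketch) is to group the split columns and rebuild the array:
import org.apache.spark.sql.functions.{array, avg}

val n = 3  // 300 in the real data
val avgCols = (1 to n).map(i => avg(col(s"weights_$i")).as(s"weights_$i"))
val averaged = source
  .select(columns: _*)
  .groupBy("id")
  .agg(avgCols.head, avgCols.tail: _*)
  .select(col("id"), array((1 to n).map(i => col(s"weights_$i")): _*).as("weights"))
averaged.show(false)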
CodePudding user response:
Input dataframe:
val df = Seq(
  (1, Seq(34, 23, 56)),
  (1, Seq( 5, 45, 10)),
  (1, Seq(38, 30, 50)),
  (2, Seq(45, 5, 20)),
  (2, Seq(40, 11, 23))
).toDF("id", "weights")
Script:
val df2 = df.groupBy("id").agg(
  first($"id"), (0 to 2).map(i => avg($"weights"(i)).as(s"w$i")): _*
)
val df3 = df2.select($"id", array("w0", "w1", "w2").as("weights"))
df3.show(truncate=false)
// +---+------------------------------------------------------------+
// |id |weights                                                     |
// +---+------------------------------------------------------------+
// |1  |[25.666666666666668, 32.666666666666664, 38.666666666666664]|
// |2  |[42.5, 8.0, 21.5]                                           |
// +---+------------------------------------------------------------+
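The indices are hard-coded to three above; for the real 300-element arrays the same pattern can be driven off an index range instead. A sketch, with n standing in for the actual dimension count and assuming the same imports as the script above:
val n = 3  // 300 for the real data
val df2 = df.groupBy("id").agg(
  first($"id"),
  (0 until n).map(i => avg($"weights"(i)).as(s"w$i")): _*
)
val df3 = df2.select($"id", array((0 until n).map(i => col(s"w$i")): _*).as("weights"))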