How can I expand a huge array into several columns in spark?

Time:06-30

I am trying to average each element of a column of arrays by index on a group by, so that I can start with a dataframe like this:

-----------------------
id     | weights
-----------------------
1      | [ 34, 23, 56 ]
1      | [  5, 45, 10 ]
1      | [ 38, 30, 50 ]
2      | [ 45,  5, 20 ]
2      | [ 40, 11, 23 ]

Then groupby "id" and somehow have an array with an average of the weights per index:

-----------------------
id     | weights
-----------------------
1      | [ 25.667, 32.667, 38.667 ]
2      | [   42.5,      8,   21.5 ]

I know that I could do this by splitting "weights" into separate columns and then doing groupby. The only issue is that the above is just an example, and the array in my real weights column has 300 elements. I've been able to see that I can get a few elements split up with this code:

sample_output.select($"weights".getItem(0).as("First"),
    $"weights".getItem(1).as("Second"),
    $"weights".getItem(2).as("Third"))
    .show()

However, I don't want to have to write that out for 300 elements. I next tried creating a list of strings for all of my columns and selecting that from my df:

val dimNums = (0 until typedConfig.embeddingDims).toList
val all_columns = dimNums.map(x => "$\"weights\".getItem(%d).as(\"dim%d\")".format(x,x))
sample_output.select(all_columns.head, all_columns.tail: _*)
        .show(false)

This gave me the error:

org.apache.spark.sql.AnalysisException: cannot resolve '`$"weights".getItem(0).as("dim0")`' given input columns: [id, WordToken, weights];;

I also tried:

sample_output.select(all_columns.map(col): _*).show(false)

But I got the exact same error. Does anyone know a way to split a huge array like this into several columns? Or is there another way to average these lists by index while using groupBy? Please let me know if I've left anything out; I'm still learning Spark and Scala.

CodePudding user response:

Spark doesn't have array aggregate functions that you can use with the agg method of relationally grouped datasets, so you either need to write your own user-defined aggregator or get creative with the available array functions.

Provided you do not have too many entries with the same id, you can first aggregate the groups by collecting their weights arrays into one big nested array:

import org.apache.spark.sql.functions._
import spark.implicits._ // for the $"..." column syntax

val grouped = df.groupBy("id").agg(collect_list($"weights") as "collected")

grouped.show(false)
//+---+-----------------------------------------+
//|id |collected                                |
//+---+-----------------------------------------+
//|1  |[[34, 23, 56], [5, 45, 10], [38, 30, 50]]|
//|2  |[[45, 5, 20], [40, 11, 23]]              |
//+---+-----------------------------------------+

You can now sum element-wise the arrays inside collected by using aggregate, the badly named equivalent of foldLeft for arrays in Spark. It takes an array column, an initial value column, and a merge function. The initial value has to be an array of 0s with the same length as the sub-arrays. In order not to hardcode anything, you can use size on the first element of the gathered array:

val initialCol = array_repeat(lit(0), size(element_at($"collected", 1)))

Summing the arrays then boils down to applying element-wise addition with zip_with:

val summed = grouped.withColumn(
  "sums",
  aggregate($"collected", initialCol, (acc, x) => zip_with(acc, x, _ + _))
)

summed.show(false)
//+---+-----------------------------------------+-------------+
//|id |collected                                |sums         |
//+---+-----------------------------------------+-------------+
//|1  |[[34, 23, 56], [5, 45, 10], [38, 30, 50]]|[77, 98, 116]|
//|2  |[[45, 5, 20], [40, 11, 23]]              |[85, 16, 43] |
//+---+-----------------------------------------+-------------+

All you have to do now is to divide each element in sums by the size of collected using transform, the equivalent of map for arrays:

val avged = summed.select(
  $"id",
  transform($"sums", _ / size($"collected")) as "weights"
)

avged.show(false)
//+---+------------------------------------------------------------+
//|id |weights                                                     |
//+---+------------------------------------------------------------+
//|1  |[25.666666666666668, 32.666666666666664, 38.666666666666664]|
//|2  |[42.5, 8.0, 21.5]                                           |
//+---+------------------------------------------------------------+

Or, if you prefer, an optimised single-expression version of the above logic which uses aggregate functions:

val avged = df
  .groupBy("id")
  .agg(
    transform(
      aggregate(
        collect_list($"weights"),
        array_repeat(lit(0), size(first($"weights"))),
        (acc, x) => zip_with(acc, x, _ + _)
      ),
      _ / count($"weights")
    ) as "weights"
  )

avged.show(false)
//+---+------------------------------------------------------------+
//|id |weights                                                     |
//+---+------------------------------------------------------------+
//|1  |[25.666666666666668, 32.666666666666664, 38.666666666666664]|
//|2  |[42.5, 8.0, 21.5]                                           |
//+---+------------------------------------------------------------+
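
For intuition, the aggregate / zip_with / transform pipeline above corresponds to foldLeft, zip and map on ordinary Scala collections. The sketch below is only an analogy in plain Scala, not Spark code; the object and method names (ElementwiseAverage, averageByIndex) are made up for illustration, and it assumes every row has the same length.

```scala
// Plain Scala collections only; no Spark required.
// ElementwiseAverage and averageByIndex are hypothetical names for this sketch.
object ElementwiseAverage {
  /** Averages equally sized rows index by index. */
  def averageByIndex(rows: Seq[Seq[Int]]): Seq[Double] = {
    require(rows.nonEmpty, "need at least one row")
    // Plays the role of array_repeat(lit(0), size(...)): a zero vector of the right length.
    val zeros = Seq.fill(rows.head.length)(0)
    // Plays the role of aggregate + zip_with: fold over the rows, adding position by position.
    val sums = rows.foldLeft(zeros) { (acc, x) =>
      acc.zip(x).map { case (a, b) => a + b }
    }
    // Plays the role of transform: divide every sum by the number of rows.
    sums.map(_.toDouble / rows.length)
  }
}
```

For the id 2 rows, ElementwiseAverage.averageByIndex(Seq(Seq(45, 5, 20), Seq(40, 11, 23))) returns List(42.5, 8.0, 21.5), which matches the Spark output above. Note that zip silently truncates to the shorter sequence, so rows of unequal length would need separate handling.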

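All array functions are in the org.apache.spark.sql.functions object. The higher-order ones used here (aggregate, zip_with and transform) were only added to the Scala API in Spark 3.0. In Spark 2.4 they exist as SQL functions only, so you would have to call them through expr; on earlier 2.x versions you are out of luck and need to write your own aggregator.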

CodePudding user response:

import org.apache.spark.sql.functions.col
import spark.implicits._

val source = Seq(
  (1, List(34, 23, 56)),
  (1, List(5, 45, 10)),
  (1, List(38, 30, 50)),
  (2, List(45, 5, 20)),
  (2, List(40, 11, 23))
).toDF("id", "weights")

source.printSchema
//    root
//    |-- id: integer (nullable = false)
//    |-- weights: array (nullable = true)
//    |    |-- element: integer (containsNull = false)

def getColAtIndex(id: Int) = col("weights")(id).as(s"weights_${id + 1}")

val columns = (0 to 2).map(getColAtIndex) :+ col("id")  // Here, instead of 2, use n - 1 for an array of n elements (299 in your case)

source.select(columns: _*).show(false)
//    +---------+---------+---------+---+
//    |weights_1|weights_2|weights_3|id |
//    +---------+---------+---------+---+
//    |34       |23       |56       |1  |
//    |5        |45       |10       |1  |
//    |38       |30       |50       |1  |
//    |45       |5        |20       |2  |
//    |40       |11       |23       |2  |
//    +---------+---------+---------+---+
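
To get from the split columns to the per-index averages the question asks for, one possible follow-up is to generate one avg aggregate per array position and then pack the results back into an array. This is a sketch, not part of the answer above: the names n, avgCols and averaged are made up here, a local SparkSession is created only so the snippet is self-contained (in spark-shell, spark already exists), and n is assumed to be the known array length (3 in the sample, 300 in the question).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, avg, col}

// Assumption: a local session for demonstration purposes only.
val spark = SparkSession.builder().master("local[*]").appName("avg-by-index").getOrCreate()
import spark.implicits._

val source = Seq(
  (1, List(34, 23, 56)),
  (1, List(5, 45, 10)),
  (1, List(38, 30, 50)),
  (2, List(45, 5, 20)),
  (2, List(40, 11, 23))
).toDF("id", "weights")

val n = 3 // hypothetical: the array length (300 in the question)

// Build one avg(...) aggregate per array position as Column objects, not strings.
val avgCols = (0 until n).map(i => avg(col("weights")(i)).as(s"avg_$i"))

val averaged = source
  .groupBy("id")
  .agg(avgCols.head, avgCols.tail: _*)
  // Pack the n averaged columns back into a single array column.
  .select(col("id"), array((0 until n).map(i => col(s"avg_$i")): _*).as("weights"))

averaged.orderBy("id").show(false)
```

Because it relies only on avg, array and index access on a column, this variant should not need the higher-order array functions, at the cost of a wide intermediate aggregation with n columns.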