Pyspark higher order functions - sum 2 values in array of structs at once?

Time:09-17

I have a spark dataframe where one column is an array of structs.

df = spark.createDataFrame([(1, [{"col1": 10, "col2": 0.8, "col3": 0.1}, {"col1": 9, "col2": 1.8, "col3":0.0}, {"col1": 8, "col2": 1.9, "col3": None}])], ['rowNum', 'vals'])

I am trying to create an aggregate function that sums all of the col2 values and divides the result by the sum of all of the col1 values (col3 can be ignored). I know how to create an aggregate function that sums just one of the columns (using higher-order functions in PySpark 2.4), but is it possible to sum two fields at the same time, or do I have to do it in two separate steps?
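For the sample row, the target value works out by hand as follows. This is a worked check computed from the data above, not code from the question:

```latex
\text{output} = \frac{\sum_i \text{col2}_i}{\sum_i \text{col1}_i}
             = \frac{0.8 + 1.8 + 1.9}{10 + 9 + 8}
             = \frac{4.5}{27}
             = \frac{1}{6} \approx 0.1667
```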

If I have to do it as 2 separate steps, then I can do:

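import pyspark.sql.functions as F

# Two separate passes over `vals`: one sums col2, the other sums col1.
df = df.withColumn("sum", F.aggregate("vals", F.lit(0.0), lambda x, y: x + y.col2)) \
    .withColumn("denom", F.aggregate("vals", F.lit(0.0), lambda x, y: x + y.col1)) \
    .withColumn("output", F.col('sum')/F.col('denom'))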

I was wondering if there is a higher-order function that does this more cleanly, in one step. Many thanks.

Answer:

You can use a two-element array as the accumulator, so that sum (of col2) and denom (of col1) are aggregated together in a single pass. Also, AGGREGATE (or REDUCE) has this signature: reduce(array<T>, B, function<B, T, B>, function<B, R>): R. The optional last argument, function<B, R>, is applied to the final accumulator, and that is exactly what I use at the end to divide sum by denom.
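To make that signature concrete, here is a plain-Python model of the same fold. It uses `functools.reduce`, not Spark; the helper name `aggregate` is hypothetical, and Spark's null handling is not modeled. The two-element accumulator is updated once per struct, and the finish function divides once at the end:

```python
from functools import reduce
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
B = TypeVar("B")
R = TypeVar("R")


def aggregate(arr: Iterable[T], zero: B,
              merge: Callable[[B, T], B],
              finish: Callable[[B], R]) -> R:
    """Hypothetical plain-Python model of AGGREGATE(array, zero, merge, finish)."""
    # Left fold: merge(acc, element) for each element, starting from zero.
    return finish(reduce(merge, arr, zero))


vals = [
    {"col1": 10.0, "col2": 0.8, "col3": 0.1},
    {"col1": 9.0, "col2": 1.8, "col3": 0.0},
    {"col1": 8.0, "col2": 1.9, "col3": None},
]

output = aggregate(
    vals,
    (0.0, 0.0),  # accumulator: (running sum of col2, running sum of col1)
    lambda acc, el: (acc[0] + el["col2"], acc[1] + el["col1"]),
    lambda acc: acc[0] / acc[1],  # finish: divide once, after the fold
)
print(output)
```

The tuple plays the role of the SQL `ARRAY(...)` accumulator; only one pass over `vals` is needed.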

import pyspark.sql.functions as F

df = spark.createDataFrame([(1, [
  {"col1": 10.0, "col2": 0.8, "col3": 0.1}, 
  {"col1": 9.0, "col2": 1.8, "col3":0.0}, 
  {"col1": 8.0, "col2": 1.9, "col3": None}])], schema='rowNum int, vals array<struct<`col1`:double, `col2`:double, `col3`:double>>')

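# The accumulator is a 2-element array: acc[0] sums col2, acc[1] sums col1.
# The last lambda (the finish function) divides the two sums once at the end.
expr = ('AGGREGATE(vals, ARRAY(CAST(0.0 AS DOUBLE), CAST(0.0 AS DOUBLE)), (acc, el) -> '
        'ARRAY(acc[0] + el.col2, acc[1] + el.col1), acc -> acc[0] / acc[1])')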
df.withColumn('output', F.expr(expr)).show()

+------+--------------------+-------------------+
|rowNum|                vals|             output|
+------+--------------------+-------------------+
|     1|[{10.0, 0.8, 0.1}...|0.16666666666666666|
+------+--------------------+-------------------+