Sum only certain values of an array based on a condition pyspark


I would like to create a column based on the sum of the array values. However, if that sum would exceed a target value, the column should instead hold the highest sum of a subset of the values that is less than or equal to the target. Here is an example:

| Target | Array             | Name | Total |
| ------ | ----------------- | ---- | ----- |
| 4.5    | [1.5,2.5,3.0,2.0] | John | 4.5   |
| 3      | [2.5,1.0,0.5,1.0] | Jim  | 3.0   |
| 5      | [1.0,1.0,1.5,1.0] | Jane | 4.5   |

CodePudding user response:

You can try this UDF, assuming `Array` is of string type:

import ast
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

def custom_sum(target_, array_):
    # Array is stored as a string, e.g. "[1.5, 2.5, 3.0, 2.0]"
    target_ = float(target_)
    array_ = ast.literal_eval(array_)
    if sum(array_) > target_:
        return target_
    return float(sum(array_))

sum_udf = udf(custom_sum, FloatType())

df.withColumn("Total", sum_udf("Target", "Array")).show()

Output:

+------+--------------------+-----+
|Target|               Array|Total|
+------+--------------------+-----+
|   4.5|[1.5, 2.5, 3.0, 2.0]|  4.5|
|     3|[2.5, 1.0, 0.5, 1.0]|  3.0|
|     5|[1.0, 1.0, 1.5, 1.0]|  4.5|
+------+--------------------+-----+
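As a quick sanity check outside Spark, the same function can be exercised directly in plain Python on the example rows from the question. Note that this approach simply caps the total at the target; it happens to match the sample rows, but it does not actually search for the best subset of values.

```python
import ast

def custom_sum(target_, array_):
    # Array arrives as a string, e.g. "[1.5, 2.5, 3.0, 2.0]"
    target_ = float(target_)
    array_ = ast.literal_eval(array_)
    # Cap the total at the target instead of selecting a subset
    if sum(array_) > target_:
        return target_
    return float(sum(array_))

print(custom_sum("4.5", "[1.5, 2.5, 3.0, 2.0]"))  # 4.5
print(custom_sum("3", "[2.5, 1.0, 0.5, 1.0]"))    # 3.0
print(custom_sum("5", "[1.0, 1.0, 1.5, 1.0]"))    # 4.5
```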

CodePudding user response:

> However, if the sum would exceed a target value, it would only sum the values that create the highest value under or equal to the target.

To find the highest sum of values that is less than or equal to the target, you have to compute the sums of the different combinations of values and then pick the best qualifying one.

Here's an example:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from itertools import combinations
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructField, StructType


def find_highest(values, target):
    # Return the largest value that does not exceed the target, or None.
    if not values:
        return None
    values.sort()
    max_value = values[0]
    if max_value > target:
        return None
    if max_value == target:
        return max_value
    for i in range(1, len(values)):
        if values[i] <= target and values[i] > max_value:
            max_value = values[i]
    return max_value


def find_closest_sum(numbers, target):
    target = float(target)
    if sum(numbers) <= target:
        return sum(numbers)
    # Try every combination size; this is exponential in the array
    # length, so it is only practical for short arrays.
    results = []
    for n in range(1, len(numbers) + 1):
        sumlist = [sum(c) for c in combinations(numbers, n)]
        highest = find_highest(sumlist, target)
        if highest is not None:
            results.append(highest)
    return find_highest(results, target)


spark = SparkSession.builder.getOrCreate()
data = [
    {"Target": "4.5", "Array": [1.5, 2.5, 3.0, 2.0]},
    {"Target": "3", "Array": [2.5, 1.0, 0.5, 1.0]},
    {"Target": "5", "Array": [1.0, 1.0, 1.5, 1.0]},
    {"Target": "7", "Array": [5.0, 1.0, 4.0]},
]
schema = StructType(
    [StructField("Target", StringType()), StructField("Array", ArrayType(DoubleType()))]
)
df = spark.createDataFrame(data=data, schema=schema)
df = df.withColumn(
    "Total", F.udf(find_closest_sum, DoubleType())(F.col("Array"), F.col("Target"))
)
df.show(truncate=False)

Result:

+------+--------------------+-----+
|Target|Array               |Total|
+------+--------------------+-----+
|4.5   |[1.5, 2.5, 3.0, 2.0]|4.5  |
|3     |[2.5, 1.0, 0.5, 1.0]|3.0  |
|5     |[1.0, 1.0, 1.5, 1.0]|4.5  |
|7     |[5.0, 1.0, 4.0]     |6.0  |
+------+--------------------+-----+
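The two helpers can also be condensed into one brute-force function. This is a sketch, equivalent in behavior on these inputs but not the answer's exact code: it enumerates every non-empty combination, filters out sums above the target, and takes the maximum.

```python
from itertools import combinations

def find_closest_sum(numbers, target):
    # Best subset sum that does not exceed the target (brute force,
    # exponential in len(numbers), so only suitable for short arrays).
    target = float(target)
    candidates = [
        sum(c)
        for n in range(1, len(numbers) + 1)
        for c in combinations(numbers, n)
    ]
    valid = [s for s in candidates if s <= target]
    return max(valid) if valid else None

print(find_closest_sum([5.0, 1.0, 4.0], "7"))       # 6.0
print(find_closest_sum([1.5, 2.5, 3.0, 2.0], "4.5"))  # 4.5
```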