Calculate weighted average from objects with Java stream

Time:01-07

Given a simple Measurement class:

@Getter
@AllArgsConstructor
public class Measurement {
    private BigDecimal rawValue;
    private BigDecimal weighting;
}

The service processes a List<Measurement> and shall calculate a weighted average by applying the following steps:

  1. Multiply each rawValue by its weighting.
  2. Build the sum of all weighted raw values from (1).
  3. Divide the sum (2) by the sum of all weightings.

Example

((1.0 * 5.0) + (2.0 * 2.0) + (3.0 * 7.0)) / (5.0 + 2.0 + 7.0) = 30.0 / 14.0
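
As a sanity check, the example can be reproduced by hand with plain BigDecimal calls. This is only a sketch: the class name WeightedAverageExample and the method compute are made up for illustration, and the scale of 3 with RoundingMode.DOWN is taken from the implementation in this question.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

// Hypothetical helper class, not part of the service in the question.
final class WeightedAverageExample {

    static BigDecimal compute() {
        // {rawValue, weighting} pairs from the example
        BigDecimal[][] pairs = {
                {new BigDecimal("1.0"), new BigDecimal("5.0")},
                {new BigDecimal("2.0"), new BigDecimal("2.0")},
                {new BigDecimal("3.0"), new BigDecimal("7.0")}
        };
        BigDecimal weightedSum = BigDecimal.ZERO;
        BigDecimal totalWeighting = BigDecimal.ZERO;
        for (BigDecimal[] pair : pairs) {
            // steps 1 and 2: multiply and accumulate the weighted values
            weightedSum = weightedSum.add(pair[0].multiply(pair[1]));
            // running sum of the weightings, needed for step 3
            totalWeighting = totalWeighting.add(pair[1]);
        }
        // step 3: divide, truncating to three decimal places
        return weightedSum.divide(totalWeighting, 3, RoundingMode.DOWN);
    }

    public static void main(String[] args) {
        System.out.println(compute()); // prints 2.142
    }
}
```

So the expected result for the three sample measurements is 2.142 (30.0 / 14.0 = 2.1428..., truncated).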

Here is my current implementation:

public class MeasurementService {

    private List<Measurement> measurements = new ArrayList<>();

    public MeasurementService() {
        Measurement m1 = new Measurement(new BigDecimal("1.0"), new BigDecimal("5.0"));
        Measurement m2 = new Measurement(new BigDecimal("2.0"), new BigDecimal("2.0"));
        Measurement m3 = new Measurement(new BigDecimal("3.0"), new BigDecimal("7.0"));
        measurements.add(m1);
        measurements.add(m2);
        measurements.add(m3);
    }

    public BigDecimal calculateWeightedAverage() {
        final BigDecimal totalWeighting = measurements.stream()
                .map(measurement -> measurement.getWeighting())
                .reduce(BigDecimal.ZERO, BigDecimal::add);
        return measurements.stream()
                .map(measurement -> measurement.getRawValue().multiply(measurement.getWeighting()))
                .reduce(BigDecimal.ZERO, BigDecimal::add)
                .divide(totalWeighting, 3, RoundingMode.DOWN);
    }
}

While this is working, I don't like using the stream of measurements twice and I bet there is a better way. Unfortunately, I couldn't come up with one so far.

Do you have any idea how to simplify the calculateWeightedAverage method?

I'm using Java 17.

CodePudding user response:

Map to a new class and reduce that, then use the reduced value to compute the final result. Something like:

@Data
@RequiredArgsConstructor
public class ResultAccumulator {
    public static final ResultAccumulator ZERO = new ResultAccumulator(BigDecimal.ZERO, BigDecimal.ZERO);

    private final BigDecimal weightedValue;
    private final BigDecimal weightSum;

    public static ResultAccumulator fromMeasurement(final Measurement measurement) {
        return new ResultAccumulator(
                measurement.getRawValue().multiply(measurement.getWeighting()),
                measurement.getWeighting());
    }

    public ResultAccumulator merge(final ResultAccumulator other) {
        return new ResultAccumulator(
                getWeightedValue().add(other.getWeightedValue()),
                getWeightSum().add(other.getWeightSum()));
    }

    public BigDecimal getFinalResult() {
        return getWeightedValue().divide(getWeightSum(), 3, RoundingMode.DOWN);
    }
}

public BigDecimal calculateWeightedAverage() {
    final ResultAccumulator result = measurements.stream()
            .map(ResultAccumulator::fromMeasurement)
            .reduce(
                    ResultAccumulator.ZERO,
                    ResultAccumulator::merge);
    return result.getFinalResult();
}

Or do it quick and dirty and reduce your original Measurement directly, storing the weighted value in rawValue and the sum of weights in weighting. The structure of both classes is the same, but using two classes is less confusing to others, and to yourself in two weeks.
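
A rough sketch of that quick-and-dirty variant could look like the following. To keep it self-contained without Lombok, it assumes Measurement is a Java record (so the accessors are rawValue() and weighting() rather than getRawValue() and getWeighting()); the class name MeasurementReduceSketch is hypothetical.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

// Hypothetical wrapper class, only here to make the sketch compile on its own.
final class MeasurementReduceSketch {

    // Assumption: a record standing in for the Lombok-based Measurement class.
    record Measurement(BigDecimal rawValue, BigDecimal weighting) {}

    static BigDecimal calculateWeightedAverage(List<Measurement> measurements) {
        Measurement total = measurements.stream()
                // rawValue is reused to carry rawValue * weighting
                .map(m -> new Measurement(m.rawValue().multiply(m.weighting()), m.weighting()))
                // add both components pairwise, starting from (0, 0)
                .reduce(new Measurement(BigDecimal.ZERO, BigDecimal.ZERO),
                        (a, b) -> new Measurement(
                                a.rawValue().add(b.rawValue()),
                                a.weighting().add(b.weighting())));
        // throws ArithmeticException if the total weighting is zero (e.g. empty list)
        return total.rawValue().divide(total.weighting(), 3, RoundingMode.DOWN);
    }
}
```

Note that the reduced Measurement no longer means "one measurement" but "running totals", which is exactly the confusion a separate accumulator class avoids.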

CodePudding user response:

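Assuming you are using Java 12 or higher, you can avoid streaming over the list twice by using the teeing collector (Collectors.teeing). It feeds every element to two downstream collectors and then merges their two results. In addition to the docs, here is an article about it: the-teeing-collector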

public BigDecimal calculateWeightedAverage() {
    return measurements.stream()
            .collect(Collectors.teeing(
                    Collectors.reducing(BigDecimal.ZERO, m -> m.getRawValue().multiply(m.getWeighting()), BigDecimal::add),
                    Collectors.reducing(BigDecimal.ZERO, Measurement::getWeighting, BigDecimal::add),
                    (weightedRawValues, totalWeighting) -> weightedRawValues.divide(totalWeighting, 3, RoundingMode.DOWN)));
}
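
One thing to watch out for: with an empty list (or weightings that sum to zero) the division throws an ArithmeticException. A possible way to handle that inside the merge function is sketched below. The class name TeeingSketch, the record-based Measurement, and the choice of returning an Optional are all assumptions made for this example, not something the question asks for.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

// Hypothetical wrapper class, only here to make the sketch compile on its own.
final class TeeingSketch {

    // Assumption: a record standing in for the Lombok-based Measurement class.
    record Measurement(BigDecimal rawValue, BigDecimal weighting) {}

    static Optional<BigDecimal> weightedAverage(List<Measurement> measurements) {
        return measurements.stream()
                .collect(Collectors.teeing(
                        // first downstream collector: sum of rawValue * weighting
                        Collectors.reducing(BigDecimal.ZERO,
                                (Measurement m) -> m.rawValue().multiply(m.weighting()),
                                BigDecimal::add),
                        // second downstream collector: sum of weightings
                        Collectors.reducing(BigDecimal.ZERO, Measurement::weighting, BigDecimal::add),
                        // merge: skip the division when the total weighting is zero
                        (weightedSum, totalWeighting) -> totalWeighting.signum() == 0
                                ? Optional.<BigDecimal>empty()
                                : Optional.of(weightedSum.divide(totalWeighting, 3, RoundingMode.DOWN))));
    }
}
```

For the three sample measurements this yields an Optional containing 2.142, and an empty Optional for an empty list.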