I'll try to keep it as brief as possible.
I have this map: Map<Neuron, Float> connections. It contains a Neuron object as the key and the weight of the connection as the value.
The Neuron class has a method getOutput() that returns the output value of the neuron.
What I want to do is iterate over every neuron in the map, calculate neuron.getOutput() * connections.get(neuron), and add all of these products up into one variable.
Is there an elegant way to do this with Java streams? Maybe with reduce? I tried it, but couldn't get it to work properly:
inputConnections.keySet().stream().reduce(
        0f,
        (accumulatedFloat, inputNeuron) -> accumulatedFloat + inputConnections.get(inputNeuron),
        Float::sum);
I guess the 0f causes everything to be multiplied by 0.
This code seems to work, but I'd like a more elegant solution.
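import java.util.concurrent.atomic.AtomicReference;

AtomicReference<Float> tmp = new AtomicReference<>(0f);
inputConnections.keySet().forEach(inputNeuron -> {
    // add output * weight of this neuron to the running total
    tmp.updateAndGet(v -> (float) (v + inputNeuron.getOutput() * inputConnections.get(inputNeuron)));
});
float sum = tmp.get();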
Answer:
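Your approach using reduce is (almost) correct. As in your second code snippet, the accumulator has to multiply the neuron's output by the value from the map (inputConnections.get(..)). Streaming over entrySet() gives you each neuron together with its weight, so the extra map lookup is not needed: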
float sum = inputConnections.entrySet()
        .stream()
        .reduce(0f,
                (result, entry) -> result + entry.getKey().getOutput() * entry.getValue(),
                Float::sum);
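A self-contained sketch of the reduce approach, for experimenting outside the original project. The Neuron class below is a hypothetical stand-in that assumes getOutput() returns a float; the real class is not shown in the question. It also shows why the 0f is not the problem: it is the identity of the sum, the value the accumulator starts adding to, so nothing is multiplied by it. The combiner (Float::sum) only merges partial sums when the stream runs in parallel.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class WeightedSumReduce {

    // Hypothetical stand-in for the question's Neuron class: only getOutput() matters here.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Sums output * weight over all entries with the three-argument reduce.
    static float weightedSum(Map<Neuron, Float> connections) {
        return connections.entrySet()
                .stream()
                .reduce(0f,                                                                    // identity: starting value of the sum
                        (result, entry) -> result + entry.getKey().getOutput() * entry.getValue(), // accumulator: add one product
                        Float::sum);                                                           // combiner: merge partial sums (parallel only)
    }

    public static void main(String[] args) {
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 2.0f);
        connections.put(new Neuron(2.0f), 3.0f);
        connections.put(new Neuron(0.5f), 4.0f);

        System.out.println(weightedSum(connections)); // 1*2 + 2*3 + 0.5*4 = 10.0
    }
}
```

If getOutput() returns a double in your code, result + output * weight is a double and cannot be boxed to Float, so the accumulator would need a (float) cast.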
Answer:
You can also achieve the same with mapToDouble and sum (note that the result is a double):
double sum = inputConnections.entrySet().stream()
        .mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue())
        .sum();
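A runnable sketch of the mapToDouble variant, again with a hypothetical Neuron stand-in whose getOutput() is assumed to return a float. Mapping to a primitive DoubleStream avoids boxing every partial sum into a Float, and DoubleStream.sum() needs no identity or combiner.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class WeightedSumMapToDouble {

    // Hypothetical stand-in for the question's Neuron class.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Maps each entry to output * weight as a primitive double, then sums the DoubleStream.
    static double weightedSum(Map<Neuron, Float> connections) {
        return connections.entrySet()
                .stream()
                .mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue())
                .sum();
    }

    public static void main(String[] args) {
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 2.0f);
        connections.put(new Neuron(2.0f), 3.0f);
        connections.put(new Neuron(0.5f), 4.0f);

        System.out.println(weightedSum(connections)); // 10.0
    }
}
```

If the rest of the code works with float, narrow the result with (float) weightedSum(connections).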
Answer:
You can also use parallel streams (be careful: this only pays off with huge datasets). If you need other statistics in addition to the sum, the summarizingDouble collector is helpful:
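import java.util.DoubleSummaryStatistics;
import java.util.stream.Collectors;

// summarizingDouble maps each entry to a double directly, so no boxing to Double is needed
DoubleSummaryStatistics sumStat = connections.entrySet().parallelStream()
        .collect(Collectors.summarizingDouble(entry -> entry.getKey().getOutput() * entry.getValue()));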
System.out.println(sumStat);
Example output: DoubleSummaryStatistics{count=3, sum=18.000000, min=4.000000, average=6.000000, max=8.000000}
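A self-contained sketch of the statistics approach. The Neuron class is a hypothetical stand-in (getOutput() assumed to return a float), and the map values are made up so that the three products are 4, 6 and 8, which gives the same count, sum, min, average and max as the example output. It uses a sequential stream; parallelStream() yields the same statistics here, but only pays off for large maps.

```java
import java.util.DoubleSummaryStatistics;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class WeightedSumStatistics {

    // Hypothetical stand-in for the question's Neuron class.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Collects count, sum, min, average and max of output * weight in a single pass.
    static DoubleSummaryStatistics weightedStats(Map<Neuron, Float> connections) {
        return connections.entrySet()
                .stream()
                .collect(Collectors.summarizingDouble(
                        entry -> entry.getKey().getOutput() * entry.getValue()));
    }

    public static void main(String[] args) {
        // Hypothetical values: the products are 1*4 = 4, 2*3 = 6 and 4*2 = 8.
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 4.0f);
        connections.put(new Neuron(2.0f), 3.0f);
        connections.put(new Neuron(4.0f), 2.0f);

        DoubleSummaryStatistics stats = weightedStats(connections);
        System.out.println(stats.getSum());     // 18.0
        System.out.println(stats.getAverage()); // 6.0
    }
}
```

When only the sum and a few statistics are needed, mapToDouble(...).summaryStatistics() returns the same DoubleSummaryStatistics without going through a collector.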