Consider this function:

    using Distributions
    using StatsBase

    test_vector = sample([1,2,3], 100000000)

    function test_1(test_vector)
        rand_vector = randn(3333)
        sum = 0.0
        for t in 1:1000
            if t in test_vector
                sum = sum + rand_vector[t]
            else
                sum = sum - rand_vector[t]
            end
        end
        return sum
    end

I applied `@profview` to understand the performance, and it turns out most of the time is spent on `if t in test_vector`. Is there a way to speed up this part of the program? I thought about excluding `test_vector` from `1:1000` and running two loops, but this creates memory allocation. Can I get a hint?

P.S. I intend to let the user pass in any `test_vector`. I'm using `sample` to create a `test_vector` just for illustration.
CodePudding user response:
If the vector is large, the membership check is much faster if you build a `Set` from the vector once and test against that instead:

    using Distributions
    using StatsBase
    using BenchmarkTools

    test_vector = sample([1,2,3], 1000000)

    # Original version: `t in test_vector` is a linear scan of the vector.
    function test_1(test_vector)
        rand_vector = randn(3333)
        sum1 = 0.0
        for t in 1:1000
            if t in test_vector
                sum1 = sum1 + rand_vector[t]
            else
                sum1 = sum1 - rand_vector[t]
            end
        end
        return sum1
    end

    # Set version: build the Set once, then each membership check is a hash lookup.
    function test_1_set(test_vector)
        rand_vector = randn(3333)
        test_set = Set(test_vector)
        sum2 = 0.0
        for t in 1:1000
            if t in test_set
                sum2 += rand_vector[t]
            else
                sum2 -= rand_vector[t]
            end
        end
        return sum2
    end

    @btime test_1(test_vector)
    @btime test_1_set(test_vector)

    677.818 ms (3 allocations: 26.12 KiB)
    8.795 ms (10 allocations: 18.03 MiB)

CodePudding user response:
Use a `Set`. They have O(1) lookup.
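
As a rough illustration of the point above, here is a minimal sketch showing that `in` returns the same answers for a `Vector` and for a `Set` built from it, so the two are interchangeable in the loop. The data and the names `vector_hits` and `set_hits` are made up for this example; only the cost per lookup differs (a linear scan for the vector, a hash lookup for the set).

```julia
# Illustrative data, not taken from the question.
test_vector = [1, 2, 3, 2, 1, 3]

# Build the Set once; duplicates collapse, so it holds only the distinct values.
test_set = Set(test_vector)

# `in` on a Vector compares against elements one by one;
# `in` on a Set hashes the value and checks the matching slot.
vector_hits = [t in test_vector for t in 1:5]
set_hits = [t in test_set for t in 1:5]

# Both containers agree on every lookup.
println(vector_hits == set_hits)
```

Because the answers are identical, swapping the vector for a set changes only the running time of the loop, not its result.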
CodePudding user response:
Bill's version is the fastest if you don't count the construction of the set:

    julia> test_vector = rand(1:3, 10000);

    julia> rand_vector = randn(3333);

    julia> range = 1:1000;

    julia> @btime test_1($(Set(test_vector)), $rand_vector, $range)
      3.692 μs (0 allocations: 0 bytes)
    14.82533505498519

However, the set creation itself is more costly, especially in terms of memory:

    julia> @btime test_1(Set($test_vector), $rand_vector, $range)
      52.731 μs (7 allocations: 144.59 KiB)
    14.82533505498519

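
The benchmarks above call `test_1` with three arguments, but that definition is not shown in this answer. A minimal sketch of what such a method might look like, assuming it is the same loop with the lookup container, the values, and the index range passed in as parameters (the argument names `xs`, `ys`, `range` are borrowed from `test_2`; the local name `total` and the sample data are hypothetical):

```julia
# Assumed three-argument form of test_1; the answer does not show its definition.
# xs:    any container supporting `in` (a Vector or a Set)
# ys:    the values to add or subtract
# range: the indices to visit
function test_1(xs, ys, range)
    total = 0.0
    for t in range
        if t in xs
            total += ys[t]   # index is present in xs: add
        else
            total -= ys[t]   # index is absent: subtract
        end
    end
    return total
end

# Illustrative data, not from the thread.
ys = [10.0, 20.0, 30.0, 40.0]
result_vec = test_1([1, 3], ys, 1:4)        # lookup in a Vector
result_set = test_1(Set([1, 3]), ys, 1:4)   # lookup in a Set
println(result_vec, " ", result_set)
```

Written this way, the caller decides whether to pay for building the set, which is what lets the two benchmarks measure the loop with and without the construction cost.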
Here's a variant that allocates less memory, at the cost of some speed:

    julia> function test_2(xs, ys, range)
               range = Set(range)
               positive = intersect(range, xs)
               negative = setdiff!(range, positive)
               return sum(ys[i] for i in positive) - sum(ys[i] for i in negative)
           end
    test_2 (generic function with 1 method)

    julia> @btime test_2($test_vector, $rand_vector, $range)
      96.020 μs (11 allocations: 18.98 KiB)
    14.825335054985187
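
The result of `test_2` differs from the loop result in the last digits (14.825335054985187 versus 14.82533505498519). The most likely reason is that `test_2` adds the terms in a different order, and floating-point addition is not associative. A small sketch, assuming a loop version named `loop_sum` (a hypothetical name) and made-up data, that compares the two with `≈` rather than `==`:

```julia
# Hypothetical loop version, mirroring the structure used earlier in the thread.
function loop_sum(xs, ys, range)
    total = 0.0
    for t in range
        total += (t in xs) ? ys[t] : -ys[t]
    end
    return total
end

# test_2 as given in the answer: split the indices into two sets, sum each, subtract.
function test_2(xs, ys, range)
    range = Set(range)
    positive = intersect(range, xs)        # indices that occur in xs
    negative = setdiff!(range, positive)   # the remaining indices
    return sum(ys[i] for i in positive) - sum(ys[i] for i in negative)
end

# Illustrative data, not from the thread.
xs = [1, 3, 3, 1]
ys = [0.1, 0.2, 0.3, 0.4, 0.5]

a = loop_sum(xs, ys, 1:5)
b = test_2(xs, ys, 1:5)

# Compare with a tolerance; the summation order differs between the two.
println(a ≈ b)
```

When checking one variant against another, a tolerance-based comparison such as `≈` (`isapprox`) avoids false mismatches from rounding.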