Repeated membership tests in a loop

Time:12-02

Consider this function:

using Distributions
using StatsBase
test_vector = sample([1,2,3], 100000000)
function test_1(test_vector)
    rand_vector = randn(3333)
    
    sum = 0.0
    for t in 1:1000
        if t in test_vector
            sum = sum + rand_vector[t]
        else 
            sum = sum - rand_vector[t]
        end
    end
    return sum
end

I profiled this with @profview, and most of the time is spent on the check if t in test_vector. Is there a way to speed up this part of the program? I thought about removing the elements of test_vector from 1:1000 and running two loops, but that allocates memory. Can I get a hint?

P.S. I intend to let the user pass in any test_vector. I'm using sample to create a test_vector just for illustration.

CodePudding user response:

If the vector is large, the membership check is much faster if you build a Set from the vector once and test against the Set inside the loop:

using Distributions
using StatsBase
using BenchmarkTools

test_vector = sample([1,2,3], 1000000)

function test_1(test_vector)
    rand_vector = randn(3333)
    
    sum1 = 0.0
    for t in 1:1000
        if t in test_vector
            sum1 = sum1 + rand_vector[t]
        else 
            sum1 = sum1 - rand_vector[t]
        end
    end
    return sum1
end

function test_1_set(test_vector)
    rand_vector = randn(3333)

    # Build the Set once; each lookup inside the loop is then a hash lookup.
    test_set = Set(test_vector)
    sum2 = 0.0
    for t in 1:1000
        if t in test_set
            sum2 += rand_vector[t]
        else
            sum2 -= rand_vector[t]
        end
    end
    return sum2
end

@btime test_1(test_vector)
@btime test_1_set(test_vector)

677.818 ms (3 allocations: 26.12 KiB)
8.795 ms (10 allocations: 18.03 MiB)
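If the memory used by the Set is a concern, one alternative that does not come from this thread is a Bool lookup table over the index range. The sketch below assumes the loop range is 1:n and that test_vector holds integers. The name test_1_mask is made up, and rand_vector is passed in as an argument rather than created inside the function.

```julia
# Alternative sketch (not from the thread): mark which indices occur in test_vector.
# Assumes the loop range is 1:n and test_vector contains integers.
function test_1_mask(test_vector, rand_vector, n)
    seen = falses(n)                 # BitVector with n entries, all false
    for x in test_vector
        if 1 <= x <= n
            seen[x] = true           # values outside 1:n can never match, so skip them
        end
    end
    total = 0.0
    for t in 1:n
        total += seen[t] ? rand_vector[t] : -rand_vector[t]
    end
    return total
end

test_1_mask([1, 3, 3, 7], [10.0, 20.0, 30.0, 40.0], 4)  # 10 - 20 + 30 - 40
```

This makes one pass over test_vector and allocates only the n-entry table, at the cost of working only for integer keys in a known range.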

CodePudding user response:

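Use a Set. Membership tests on a Set are hash lookups, O(1) on average, whereas in on a Vector is a linear scan that is repeated on every loop iteration.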
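A minimal sketch of the two kinds of lookup, using only Base Julia. The names haystack, haystack_set, and needles are made up for illustration.

```julia
# Hypothetical data for illustration only.
haystack = [3, 1, 2, 3, 1, 2]      # Vector: `in` compares element by element
haystack_set = Set(haystack)       # Set: `in` hashes the key and looks it up

needles = 1:5

# Both containers give the same answers; only the cost per lookup differs.
hits_vector = [t in haystack for t in needles]
hits_set = [t in haystack_set for t in needles]
```

Building the Set costs one pass over the vector, so it pays off when the same collection is queried many times, as in the loop from the question.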

CodePudding user response:

Bill's version is the fastest if you don't count the construction of the set:

julia> test_vector = rand(1:3, 10000);

julia> rand_vector = randn(3333);

julia> range = 1:1000;

julia> @btime test_1($(Set(test_vector)), $rand_vector, $range)
  3.692 μs (0 allocations: 0 bytes)
14.82533505498519
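
The benchmarks in this answer call test_1 with three arguments, a signature the thread never shows. A plausible sketch, assuming it is the original loop with rand_vector and the range passed in as arguments:

```julia
# Assumed signature: this three-argument method is not shown in the thread.
# `xs` is any collection that supports `in` (a Vector or a Set),
# `ys` holds the values to add or subtract, `range` the indices to visit.
function test_1(xs, ys, range)
    total = 0.0
    for t in range
        if t in xs
            total += ys[t]
        else
            total -= ys[t]
        end
    end
    return total
end

ys = [10.0, 20.0, 30.0, 40.0]
test_1(Set([1, 3]), ys, 1:4)   # 10 - 20 + 30 - 40
```

With this shape, passing $(Set(test_vector)) builds the set before timing starts, while Set($test_vector) builds it inside the timed call.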

However, the set creation itself is more costly, especially in terms of memory:

julia> @btime test_1(Set($test_vector), $rand_vector, $range)
  52.731 μs (7 allocations: 144.59 KiB)
14.82533505498519

Here's a variant more optimized for memory:

julia> function test_2(xs, ys, range)
           range = Set(range)
           positive = intersect(range, xs)
           negative = setdiff!(range, positive)
           return sum(ys[i] for i in positive) - sum(ys[i] for i in negative)
       end
test_2 (generic function with 1 method)

julia> @btime test_2($test_vector, $rand_vector, $range)
  96.020 μs (11 allocations: 18.98 KiB)
14.825335054985187
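
The last digits differ from the test_1 results (14.82533505498519 versus 14.825335054985187), most likely because test_2 adds the terms in a different order and floating-point addition is not associative. A small sketch of the effect with made-up values:

```julia
# The same three numbers, added in two different orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)

exactly_equal = (a == b)   # false: the results differ in the last bit
close_enough = (a ≈ b)     # true: isapprox tolerates rounding differences
```

When comparing the two implementations, use ≈ (isapprox) rather than ==.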