Why is this binary search optimization much slower?

Time:07-09

A supposed optimization made the code over twice as slow.

I counted how often a value x occurs in a sorted list a by finding the range where it occurs:

from bisect import bisect_left, bisect_right

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

But hey, it can't stop before it starts, so we can optimize the second search by leaving out the part before start, passing start as bisect_right's lo argument (see the bisect documentation):

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start
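The rewrite works because every element before start is less than x, so bisect_right can begin its search at lo=start without changing its result. Here is a quick sanity check, a sketch on made-up random data, that both versions agree with a brute-force count. The names count_original and count_optimized are hypothetical stand-ins for the two count functions above.

```python
import random
from bisect import bisect_left, bisect_right

def count_original(a, x):
    # Two independent searches over the whole list
    return bisect_right(a, x) - bisect_left(a, x)

def count_optimized(a, x):
    start = bisect_left(a, x)
    # Everything in a[:start] is < x, so the second search may skip it
    return bisect_right(a, x, start) - start

rng = random.Random(42)
a = sorted(rng.choices(range(50), k=1_000))
# Include values below, inside, and above the stored range
ok = all(count_original(a, x) == count_optimized(a, x) == a.count(x)
         for x in range(-1, 51))
print('versions agree with list.count:', ok)
```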

But when I benchmarked, the optimized version took over twice as long:

 254 ms ±  1 ms  original
 525 ms ±  2 ms  optimized

Why?

The benchmark builds a sorted list of ten million random ints from 0 to 99999 and then counts the occurrences of every distinct int (this is just for benchmarking, so there's no need to point out Counter):

import random
from bisect import bisect_left, bisect_right
from timeit import repeat
from statistics import mean, stdev

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

def count_all():
    for x in unique:
        count(a, x)
for count in original, optimized:
    times = repeat(count_all, number=1)
    ts = [t * 1e3 for t in sorted(times)[:3]]
    print(f'{round(mean(ts)):4} ms ± {round(stdev(ts)):2} ms ', count.__name__)

Answer:

There are a couple of respects in which the benchmark triggers adverse cache effects.

First, I bet this assert will pass for you (as it does for me):

assert list(unique) == sorted(unique)

There's no guarantee that will pass, but given the implementations of CPython's set type and integer hashing to date, it's likely to pass.
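The reason it usually passes: in CPython, small non-negative ints hash to themselves, and a set starts probing for an element at table slot hash & mask. When every value is smaller than the table size, each int lands in the slot equal to its own value, so iteration (which walks the table in slot order) comes out sorted. A small sketch of that; iterates_sorted is a hypothetical helper, and the behavior is an implementation detail, not a language guarantee.

```python
def iterates_sorted(values):
    """True if iterating set(values) yields ascending order.

    Relies on CPython details (int hashing and set table layout);
    other implementations or future versions may differ.
    """
    s = set(values)
    return list(s) == sorted(s)

# Small non-negative ints hash to themselves in CPython
print(all(hash(n) == n for n in range(100_000)))
print(iterates_sorted([7, 3, 3, 0, 5, 1, 7]))
```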

That implies your for x in unique loop tries x in strictly increasing order. That makes the potential probe sequences inside bisect_left() much the same from one x to the next, so many of the values being compared are likely sitting in cache. The same is true of bisect_right() in the original, but in the optimized version the potential probe sequences for bisect_right() differ across tries because the start index differs across tries.

To make both versions slow down "a lot", add this after the assert:

unique = list(unique)
random.shuffle(unique)

Now there's no regularity in the input x across tries, so no systemic correlation either in potential probe sequences across tries.
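One rough way to see the difference without hardware counters: trace which indices a bisect_left-style loop reads and measure how many of them the previous query also read. This is a sketch under simplifying assumptions: probes_left and reuse_rate are hypothetical helpers, the loop mirrors the midpoint rule of CPython's bisect module, and "read by the previous query" is only a crude proxy for "still in cache".

```python
import random

def probes_left(a, x):
    """Indices a bisect_left-style binary search reads, in order."""
    lo, hi, seen = 0, len(a), []
    while lo < hi:
        mid = (lo + hi) // 2
        seen.append(mid)
        if a[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return seen

def reuse_rate(a, queries):
    """Fraction of probed indices that the previous query also probed."""
    reused = total = 0
    previous = set()
    for x in queries:
        current = probes_left(a, x)
        reused += sum(i in previous for i in current)
        total += len(current)
        previous = set(current)
    return reused / total

rng = random.Random(0)
a = sorted(rng.choices(range(1_000), k=100_000))
in_order = sorted(set(a))      # increasing x, like iterating the set
shuffled = in_order[:]
rng.shuffle(shuffled)          # no regularity across tries

print(f'increasing x: {reuse_rate(a, in_order):.0%} of probes reused')
print(f'shuffled x:   {reuse_rate(a, shuffled):.0%} of probes reused')
```

With increasing x, consecutive searches follow the same path until they reach the run of entries equal to x, so most probes repeat; with shuffled x, typically only the first probe or two are shared.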

The other cache effects come within a single try. In the original, the potential probe sequences are exactly the same for bisect_left() and bisect_right(). Entries read to resolve bisect_left() are very likely still sitting in cache for bisect_right() to reuse.

But in the optimized version, the potential probe sequences differ because the search bounds differ. For example, bisect_left() will always start by comparing x to a[5000000]. In the original, bisect_right() will also always start with that same comparison, but in the optimized version it will almost always pick a different index of a to start with, and that entry will be in cache only by luck.
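To make this concrete, the sketch below traces the indices read by bisect_left and by both variants of bisect_right for one value. It assumes a hypothetical probes helper that mirrors the midpoint rule of CPython's bisect module, and a deterministic list in which each value 0..999 occurs 100 times (100,000 entries, so the first probe is a[50000] rather than a[5000000]).

```python
from bisect import bisect_left

def probes(a, x, lo=0, hi=None, right=False):
    """Indices read by a bisect_left-style loop, or bisect_right-style if right=True."""
    if hi is None:
        hi = len(a)
    seen = []
    while lo < hi:
        mid = (lo + hi) // 2
        seen.append(mid)
        # bisect_left moves right when a[mid] < x; bisect_right when not x < a[mid]
        go_right = (not x < a[mid]) if right else (a[mid] < x)
        if go_right:
            lo = mid + 1
        else:
            hi = mid
    return seen

a = [i // 100 for i in range(100_000)]  # each value 0..999 occurs 100 times
x = 123
start = bisect_left(a, x)

left = probes(a, x)                                   # bisect_left(a, x)
right_original = probes(a, x, right=True)             # bisect_right(a, x)
right_optimized = probes(a, x, lo=start, right=True)  # bisect_right(a, x, start)

def shared(p):
    """How many of p's indices bisect_left already read."""
    return len(set(p) & set(left))

print('first probe:', left[0], right_original[0], right_optimized[0])
print('reused from bisect_left:', shared(right_original), 'vs', shared(right_optimized))
```

Here the original bisect_right retraces bisect_left's path until it meets an entry equal to x, while the optimized one starts at (start + len(a)) // 2 and never touches an index bisect_left read.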

All that said, I usually use your optimization in my own code. But that's because I typically have comparison operations far more expensive than integer compares, and saving a compare is worth a lot more than saving some cache misses. Comparison of small ints is very cheap, so saving some of those is worth little.
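The trade-off can be seen by counting comparisons. The sketch below wraps values in a hypothetical Counted class whose __lt__ increments a counter (both bisect functions compare with <), then runs both versions over the same made-up data. It shows only that the optimized version performs fewer comparisons in total; whether that wins depends on how expensive each comparison is relative to the cache misses discussed above.

```python
from bisect import bisect_left, bisect_right

class Counted:
    """Hypothetical wrapper that tallies every < comparison."""
    compares = 0

    def __init__(self, v):
        self.v = v

    def __lt__(self, other):
        Counted.compares += 1
        return self.v < other.v

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

def total_compares(count, a, queries):
    """Run count(a, x) for every query; return (comparisons, results)."""
    Counted.compares = 0
    results = [count(a, x) for x in queries]
    return Counted.compares, results

a = [Counted(i // 100) for i in range(100_000)]  # each value 0..999 occurs 100 times
queries = [Counted(v) for v in range(1_000)]

orig_n, orig_res = total_compares(original, a, queries)
opt_n, opt_res = total_compares(optimized, a, queries)
print('original: ', orig_n, 'comparisons')
print('optimized:', opt_n, 'comparisons')
```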

Answer:

I tried simulating caching in Python and measured cache misses for various cache sizes:

        ORIGINAL:              OPTIMIZED:
cache |         cache-misses |         cache-misses |
 size |  time   line   item  |  time   line   item  |  
------ ---------------------- ---------------------- 
 1024 | 1.98 s  59.1%  16.4% | 4.90 s  74.4%  57.1% |
 2048 | 2.30 s  59.1%  16.4% | 5.28 s  72.5%  56.8% |
 4096 | 2.16 s  59.0%  16.4% | 5.30 s  70.4%  56.4% |
 8192 | 2.33 s  59.0%  16.4% | 6.09 s  68.2%  56.0% |
16384 | 2.80 s  59.0%  16.4% | 6.30 s  65.8%  55.6% |

I used a proxy object for the list. Getting a list item goes through the getitem function, which has an LRU cache of the size shown in the leftmost column. And getitem doesn't access the list directly, either. It goes through the getline function, which fetches a "cache line", a block of 8 consecutive list elements. getline has its own LRU cache, sized at the cache size divided by 8.

It's far from perfect, i.e., far from the real thing (measuring actual hardware cache misses), especially since it only simulates caching the references stored in the list, not the int objects they point to. But I find it interesting nonetheless. The original version of my function shows fewer cache misses, and its miss rates stay nearly constant across cache sizes. The optimized version shows more cache misses, and a larger cache reduces its miss rates.

My code:

import random
from bisect import bisect_left, bisect_right
from timeit import timeit
from statistics import mean, stdev
from functools import lru_cache

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

class Proxy:
    __len__ = a.__len__
    def __getitem__(self, index):
        return getitem(index)
p = Proxy()

def count_all():
    for x in unique:
        count(p, x)

linesize = 8

print('''        ORIGINAL:              OPTIMIZED:
cache |         cache misses |         cache misses |
 size |  time   line   item  |  time   line   item  |  
------ ---------------------- ---------------------- ''')

for cachesize in 1024, 2048, 4096, 8192, 16384:
    print(f'{cachesize:5} |', end='')

    @lru_cache(cachesize // linesize)
    def getline(i):
        i *= linesize
        return a[i : i + linesize]  # one "cache line" of linesize references
    
    @lru_cache(cachesize)
    def getitem(index):
        q, r = divmod(index, linesize)
        return getline(q)[r]
    
    for count in original, optimized:
        getline.cache_clear()
        getitem.cache_clear()
        time = timeit(count_all, number=1)
        def misses(func):
            ci = func.cache_info()
            misses = ci.misses / (ci.misses + ci.hits)
            return f'{misses:.1%}'
        print(f'{time:5.2f} s  {misses(getline)}  {misses(getitem)}', end=' |')
    print()