Why is using list comprehension here so slow?

Time:11-05

I was writing some Python and needed to find the indexes of the largest number in a list. I looked it up and found this:

def find_indexes_1(nums):
    return [index for index, number in enumerate(nums) if number == max(nums)]

After some testing, I isolated this line as the main reason my program was so slow, so I rewrote the function without a list comprehension:

def find_indexes_2(nums):
    largest = float("-inf")  # starting at 0 would return [] for an all-negative list
    indexes = []
    for index, number in enumerate(nums):
        if number > largest:
            indexes.clear()
            indexes.append(index)
            largest = number
        elif number == largest:
            indexes.append(index)
    return indexes

This is much longer and harder to read, but I found it to be significantly faster than the first method.

from timeit import default_timer as timer
from random import randint

numbers = []
for i in range(10000):
    numbers.append(randint(10, 55))

start = timer()
a = find_indexes_1(numbers)
end = timer()
time = str(end-start)
print(f'method 1 {time}') # 2.087865 seconds
start = timer()
b = find_indexes_2(numbers)
end = timer()
time = str(end-start)
print(f'method 2 {time}') # 0.001524 seconds
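A single default_timer measurement can be noisy. The sketch below is a more repeatable way to compare the two versions, using the standard library's timeit.repeat and keeping the fastest trial. It re-declares both functions so it runs on its own, and it starts `largest` at negative infinity so lists of negative numbers also work. The list size (2000) and the call counts are arbitrary choices that keep the quadratic version quick. No timings are claimed here; they depend on the machine.

```python
import timeit
from random import randint


def find_indexes_1(nums):
    # max(nums) is re-evaluated for every element: O(n^2) overall
    return [index for index, number in enumerate(nums) if number == max(nums)]


def find_indexes_2(nums):
    # Single pass that tracks the running maximum: O(n)
    largest = float("-inf")
    indexes = []
    for index, number in enumerate(nums):
        if number > largest:
            indexes.clear()
            indexes.append(index)
            largest = number
        elif number == largest:
            indexes.append(index)
    return indexes


numbers = [randint(10, 55) for _ in range(2000)]
CALLS = 5  # calls per trial

for func in (find_indexes_1, find_indexes_2):
    # repeat() returns one total time per trial; the minimum is usually
    # the least disturbed by other activity on the machine
    best = min(timeit.repeat(lambda: func(numbers), number=CALLS, repeat=3))
    print(f"{func.__name__}: {best / CALLS:.6f} s per call")
```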

My main question is: why is the first method so much slower? I understand that it loops through the list twice, but that doesn't explain such a large difference between the two times. Thanks.

Answer:

Because max(nums), which scans nums linearly, is called len(nums) times. That gives find_indexes_1 O(n^2) complexity, compared to O(n) for find_indexes_2.
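The quadratic cost can be made visible without a stopwatch by counting how many elements max has to look at. The sketch below is an illustration, not a drop-in replacement. It swaps the built-in max for a hypothetical counting_max that returns the same value but tallies every element it visits, then runs the same comprehension as find_indexes_1.

```python
def count_max_scans(nums):
    """Run find_indexes_1's comprehension and count every element max() visits.

    Returns (indexes, elements_visited).
    """
    visited = 0

    def counting_max(seq):
        # Hypothetical stand-in for max(seq) that also counts visited elements
        nonlocal visited
        best = None
        for x in seq:
            visited += 1
            if best is None or x > best:
                best = x
        return best

    # The `if` clause is evaluated once per element, so counting_max runs len(nums) times
    indexes = [i for i, n in enumerate(nums) if n == counting_max(nums)]
    return indexes, visited


for size in (10, 100, 1000):
    _, visited = count_max_scans(list(range(size)))
    print(size, visited)  # prints 10 100, then 100 10000, then 1000 1000000
```

Multiplying the list length by 10 multiplies the work by 100, which is why the 10,000-element list in the question is so slow.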

You could optimize it like this:

def find_indexes_3(nums):
    max_nums = max(nums)  # compute this value once, takes O(n) time
    return [index for index, number in enumerate(nums) if number == max_nums]
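Hoisting max out of the comprehension should not change the result for any non-empty list, because the list isn't modified while the comprehension runs. A quick randomized check is sketched below. check_equivalent is a hypothetical helper name, and the lists are kept short so the quadratic version stays fast.

```python
from random import randint


def find_indexes_1(nums):
    # Original version: max(nums) recomputed per element
    return [index for index, number in enumerate(nums) if number == max(nums)]


def find_indexes_3(nums):
    # Hoisted version: max computed once
    max_nums = max(nums)
    return [index for index, number in enumerate(nums) if number == max_nums]


def check_equivalent(trials=200):
    """Hypothetical helper: compare both versions on random non-empty lists."""
    for _ in range(trials):
        nums = [randint(-5, 5) for _ in range(randint(1, 50))]
        if find_indexes_1(nums) != find_indexes_3(nums):
            return False
    return True


print(check_equivalent())  # True
```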

Answer:

The problem is that you're calling max(nums) on each iteration of the list comprehension, and each call has to iterate over the whole list again. So you have an O(n^2) algorithm. Since the maximum doesn't change, you should compute it once:

def find_indexes_1(nums):
    maxnum = max(nums)
    return [index for index, number in enumerate(nums) if number == maxnum]
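There is one behavioural difference when the call is hoisted. For an empty list, the original comprehension never evaluates its `if` clause, so it returns [] without calling max. In contrast, `max([])` raises ValueError. The sketch below passes `default=None`, a keyword the built-in max accepts (Python 3.4+), so the hoisted version also returns [] for empty input. find_indexes_hoisted is a hypothetical name used only for this illustration.

```python
def find_indexes_1(nums):
    # max(nums) is never reached when nums is empty, so this returns []
    return [index for index, number in enumerate(nums) if number == max(nums)]


def find_indexes_hoisted(nums):
    # max([]) would raise ValueError; default=None avoids that for empty input
    maxnum = max(nums, default=None)
    return [index for index, number in enumerate(nums) if number == maxnum]


print(find_indexes_1([]))                  # []
print(find_indexes_hoisted([]))            # []
print(find_indexes_hoisted([4, 9, 9, 2]))  # [1, 2]
```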