How to make this function faster: algorithm to create chain based on even and odd numbers

Time:05-25

In Python, I would like to write a function that returns the number, from a range of starting numbers, that leads to the longest sequence generated by the rule "if n is even, divide by two; otherwise, multiply by three and add one". For example, for the number 3 the sequence is 3 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1.
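For reference, the rule itself can be sketched as a generator. This is a minimal illustration, not the asker's code; the name `collatz_sequence` is hypothetical, and it assumes a positive integer start. It uses integer division `//` so every value stays an `int`:

```python
def collatz_sequence(n):
    """Yield the chain starting at n (assumed n >= 1) until it reaches 1."""
    while True:
        yield n
        if n == 1:
            return
        # Even: halve with integer division; odd: multiply by three and add one.
        n = n // 2 if n % 2 == 0 else 3 * n + 1


print(list(collatz_sequence(3)))  # [3, 10, 5, 16, 8, 4, 2, 1]
```

Note that the chain for 3 has 8 terms but 7 steps; the lengths quoted below count steps.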

My attempt at writing a fast, non-naive function:

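def my_fun(n):
    # Map every start value k in range(n) to its current value in the chain.
    dici = {k: v for k, v in zip(range(n), range(n))}
    count = 0
    while (len(dici)>1):
        count += 1
        # Drop chains that already reached 1, then apply one step to the rest.
        dici = {k:(v/2 if v%2 == 0 else v*3+1) for k,v in {key: dici[key] for key in dici if dici[key] > 1}.items() }
    # Only one start value is left: halve it until it reaches 1.
    while all(value != 1 for value in dici.values()):
        dici = {k:(v/2) for k, v in dici.items()}
        count += 1
    return dici, count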

With an input of 10, the function returns 9 as the number leading to the longest chain (19 steps).

As I would like to apply my function to larger numbers, this version is far too computationally expensive. Is there any way to speed up my code?

Answer:

I don't think you need the complexity of a dictionary here. You may find this faster:

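def func(n):
    """Return the number of steps needed for n to reach 1."""
    count = 0
    while n != 1:
        if n % 2 == 0:
            n //= 2
        else:
            n = n * 3 + 1
        count += 1
    return count

def my_fun(n):
    # Pair each start with its step count; max() compares the counts first.
    # (The := operator requires Python 3.8 or later.)
    if m := [(func(n_), n_) for n_ in range(n, 0, -1)]:
        return max(m)
    return -1, -1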
    
print(my_fun(10))

Output:

(19, 9)
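Neither answer uses it, but a common further speedup for larger inputs is memoization: many chains merge (6 -> 3 continues exactly like the chain from 3), so each step count only needs to be computed once. A minimal sketch under stated assumptions: the names `chain_length` and `longest_chain` are hypothetical, starts range over 1 to n-1, the result is ordered (steps, start) like the output above, and ties go to the smaller start. The cache is filled iteratively rather than recursively to stay clear of Python's recursion limit:

```python
def chain_length(start, cache):
    """Steps from start down to 1, reusing lengths already stored in cache."""
    path = []
    n = start
    # Walk forward until we reach a value whose step count is already known.
    while n not in cache:
        path.append(n)
        n = n // 2 if n % 2 == 0 else 3 * n + 1
    steps = cache[n]
    # Unwind the path, storing a step count for every value passed through.
    for value in reversed(path):
        steps += 1
        cache[value] = steps
    return cache[start]


def longest_chain(n):
    """Return (steps, start) for the longest chain among starts 1 .. n-1."""
    cache = {1: 0}  # the chain starting at 1 needs zero steps
    best = (0, 1)
    for start in range(1, n):
        steps = chain_length(start, cache)
        if steps > best[0]:  # strict '>' keeps the smaller start on ties
            best = (steps, start)
    return best


print(longest_chain(10))  # (19, 9)
```

The trade-off is memory: the dictionary also stores every intermediate value visited, which can grow large for big n.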

Answer:

  • Floating-point arithmetic and comparison for equality are dangerous: use //
    (I got a non-terminating loop for 7)
  • Just dividing by two once only one start value is left may be incorrect:
    the remaining value need not be a power of two
  • For every start s smaller than n//2, there is one twice as large (2s < n)
    taking exactly one step more, so it suffices to start from n//2:
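def ulam(n):
    """ return start, length of the longest sequence starting lower than n. """
    # Starts below n//2 can be skipped: doubling them gives a longer chain.
    dici = {k: v for k, v in zip(range(n//2, n), range(n//2, n))}
    count = 0
    while dici:  # len(dici)>1):
        count += 1
        # On the final pass, this is a start with the longest chain.
        start = next(iter(dici))
        # Keep only values > 2 (a 2 would reach 1 in this step), advance them.
        dici = {k:(v//2 if v%2 == 0 else v*3 + 1)
                    for k,v in dici.items() if v > 2 }  # > 1 }
        # print(dici, start)
    return start, count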
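The third point rests on one observation: 2s is even, so its first step is 2s -> s, and its chain is exactly one step longer than that of s. Since s < n//2 implies 2s < n, no start in the lower half can hold the maximum. A quick sketch that checks this numerically; `steps`, `longest_below` and `longest_upper_half` are hypothetical helpers, and the result is ordered (steps, start) rather than (start, count) as in `ulam`:

```python
def steps(n):
    """Number of steps for n (assumed n >= 1) to reach 1, integers only."""
    count = 0
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        count += 1
    return count


def longest_below(n):
    """(steps, start) for the longest chain among all starts 1 .. n-1 (n >= 2)."""
    return max((steps(s), s) for s in range(1, n))


def longest_upper_half(n):
    """Same search, restricted to starts n//2 .. n-1 as the answer suggests."""
    return max((steps(s), s) for s in range(max(n // 2, 1), n))


# Doubling a start adds exactly one step, because 2s -> s is the first move.
assert all(steps(2 * s) == steps(s) + 1 for s in range(1, 1000))
# Restricting the search to the upper half gives the same result.
for n in range(2, 200):
    assert longest_below(n) == longest_upper_half(n)

print(longest_upper_half(10))  # (19, 9)
```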
