cython to speed up a 3d list manipulation

Time:05-24

Could someone help me write a Cython version of this example?

I created this example because I would like a faster version of it. As a solution I was thinking about and/or , but if there are other ways, please feel free to propose them.

The code generates a 3D list (called list2) from a previous one (list1, randomly generated).

It uses two transformation functions (transform and transform2): one performs only some arbitrary math (transform), and the other uses list1 to generate the final value inserted into list2.

I track the time with

The code is the following:

import math
from random import seed
from random import random

global i_elements
global j_elements
global k_elements

i_elements,j_elements,k_elements =11,11,11

global list1
global list2
seed(1)
list1 = [[[random() for k in range(k_elements)] for j in range(j_elements)] for i in range(i_elements)]
list2 = [[[0 for k in range(k_elements)] for j in range(j_elements)] for i in range(i_elements)]

def transform(x, y, z):
    '''
    no-sense function performing some math
    '''
    a = 0
    b = 0
    sol = 0

    if x > y:
        a = math.sqrt(x ** 2) + math.atan(y ** 2 + 1)
    else:
        b = math.sqrt(z ** 2) + math.atan(y ** 2 + 1)

    if x > z:
        sol = math.sqrt(a*b)
    else:
        sol = math.sqrt(b**2)

    return sol

def transform2(a, b, c):
    '''
    transformation dependent on element in list1
    '''

    global list1, i_elements,j_elements,k_elements
    sol = 0

    for i in range(i_elements):
        for j in range(j_elements):
            for k in range(k_elements):
                temp = transform(i, j, k)
                if list1[i][j][k] > temp:
                    sol = temp*list1[i][j][k]*(a+1)**2
                else:
                    sol = temp + list1[i][j][k]**(b*c + 1)

    return sol

def save_list():
    '''
    function to save my 3d list after the transform2
    '''
    global list2,i_elements, j_elements, k_elements
    for i in range(i_elements):
        for j in range(j_elements):
            for k in range(k_elements):
                list2[i][j][k] = transform2(i,j,k)

    return list2

def main():
    save_list()
    print('finish')

if __name__ == "__main__":
    import cProfile

    cProfile.run('main()')

The output is:

finish
         7087581 function calls in 2.628 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.628    2.628 <string>:1(<module>)
  1771561    1.453    0.000    1.932    0.000 code_to_speed_up.py:17(transform)
     1331    0.696    0.001    2.628    0.002 code_to_speed_up.py:37(transform2)
        1    0.001    0.001    2.628    2.628 code_to_speed_up.py:56(save_list)
        1    0.000    0.000    2.628    2.628 code_to_speed_up.py:68(main)
        1    0.000    0.000    2.628    2.628 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
  1771561    0.178    0.000    0.178    0.000 {built-in method math.atan}
  3543122    0.301    0.000    0.301    0.000 {built-in method math.sqrt}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
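The call counts in this profile follow directly from the nested loops: save_list calls transform2 once per cell, and every transform2 call runs transform over all cells again. A quick arithmetic check, using the question's dimensions:

```python
# Loop extents used in the question (all three are 11).
i_elements = j_elements = k_elements = 11

cells = i_elements * j_elements * k_elements  # save_list calls transform2 once per cell
transform_calls = cells * cells               # each transform2 call loops over all cells
atan_calls = transform_calls                  # transform calls math.atan once per call
sqrt_calls = 2 * transform_calls              # and math.sqrt twice per call

print(cells, transform_calls, atan_calls, sqrt_calls)
# prints: 1331 1771561 1771561 3543122
```

This is why transform dominates the total time: it is called 1,771,561 times through the Python function-call machinery.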

Answer:

NumPy can be used for this problem, but transform2 can hardly be vectorized efficiently. Cython and Numba, however, can handle it efficiently (Numba is a bit like Cython, but it is a just-in-time compiler and is simpler to use here).

Using Cython or Numba alone is not enough, since lists cannot be processed efficiently (because of reference counting, the GIL, an inefficient internal representation causing additional indirections, etc.).

Note that using globals is not a good idea. It is generally slower, and it makes your code harder to optimize and to understand. It is considered bad software-engineering practice, especially when globals are modified, since it creates hidden dependencies (in other words, a kind of spooky action at a distance).
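A small, hypothetical micro-benchmark of the "generally slower" part: the same loop reading a module-level list versus receiving it as a parameter. In CPython, locals are accessed through fast slots while globals are resolved through the module namespace; the actual gap depends on the Python version, so treat this as something to measure rather than a fixed number.

```python
import timeit

# Module-level data, like the globals in the question's code.
data = list(range(100_000))

def sum_global():
    '''Reads the module-level `data` inside the loop.'''
    total = 0
    for i in range(len(data)):
        total += data[i]   # global lookup (cached in CPython 3.11+, so the gap may be small)
    return total

def sum_param(values):
    '''Same loop, but the list arrives as a parameter (a local variable).'''
    total = 0
    for i in range(len(values)):
        total += values[i]   # local lookup
    return total

if __name__ == "__main__":
    assert sum_global() == sum_param(data)
    t_global = min(timeit.repeat(sum_global, number=20, repeat=5))
    t_param = min(timeit.repeat(lambda: sum_param(data), number=20, repeat=5))
    print(f"global: {t_global:.4f} s, parameter: {t_param:.4f} s")
```

Passing the data explicitly also makes the dependency visible in the function signature, which addresses the readability concern as well as the speed one.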

Here is a (barely tested) example using NumPy + Numba:

import math
import numpy as np
import numba as nb

i_elements,j_elements,k_elements = 11,11,11
np.random.seed(1)
arr1 = np.random.rand(i_elements, j_elements, k_elements)
arr2 = np.zeros((i_elements, j_elements, k_elements))


@nb.njit
def transform(x, y, z):
    '''
    no-sense function performing some math
    '''
    a = 0
    b = 0
    sol = 0

    if x > y:
        a = math.sqrt(x ** 2) + math.atan(y ** 2 + 1)
    else:
        b = math.sqrt(z ** 2) + math.atan(y ** 2 + 1)

    if x > z:
        sol = math.sqrt(a*b)
    else:
        sol = math.sqrt(b**2)

    return sol

@nb.njit
def transform2(arr1, a, b, c):
    '''
    transformation dependent on element in arr1
    '''
    sol = 0

    for i in range(arr1.shape[0]):
        for j in range(arr1.shape[1]):
            for k in range(arr1.shape[2]):
                temp = transform(i, j, k)
                if arr1[i,j,k] > temp:
                    sol = temp * arr1[i,j,k] * (a+1)**2
                else:
                    sol = temp + arr1[i,j,k]**(b*c + 1)

    return sol

# Giving Numba a signature makes it compile the function eagerly.
# Read the Numba documentation for more information about this.
@nb.njit('(float64[:,:,::1],float64[:,:,::1])')
def save_list(arr1, arr2):
    '''
    function to save my 3d list after the transform2
    '''
    for i in range(arr2.shape[0]):
        for j in range(arr2.shape[1]):
            for k in range(arr2.shape[2]):
                arr2[i,j,k] = transform2(arr1, i,j,k)

    return arr2

def main():
    global arr1
    global arr2
    save_list(arr1, arr2)
    print('finish')

if __name__ == "__main__":
    import cProfile
    cProfile.run('main()')

Note that np.random.rand will likely produce different values than random.random, even with the same seed, because NumPy uses its own random number generator implementation.
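If you want the Numba version to work on exactly the same input values as the original code, one option is to keep generating list1 with the random module and convert it once with np.array. A minimal sketch, assuming the question's dimensions; the result is a C-contiguous float64 array, which matches the float64[:,:,::1] signature of save_list:

```python
import random
import numpy as np

i_elements, j_elements, k_elements = 11, 11, 11

# Generate list1 exactly as in the question (same seed, same generator).
random.seed(1)
list1 = [[[random.random() for k in range(k_elements)]
          for j in range(j_elements)]
         for i in range(i_elements)]

# One-time conversion: a single contiguous float64 block instead of nested lists.
arr1 = np.array(list1, dtype=np.float64)
arr2 = np.zeros_like(arr1)

print(arr1.shape, arr1.dtype, arr1.flags['C_CONTIGUOUS'])
# prints: (11, 11, 11) float64 True
```

The conversion cost is paid once, outside the hot loops, so it does not affect the speed of save_list itself.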

If you really want to use Cython instead of Numba, you need to use typed memoryviews on NumPy arrays. For more information, please read the "Cython for NumPy users" tutorial in the Cython documentation.

This Numba version is more than 500 times faster on my machine.
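To check a speedup figure like this on your own machine, it is fairer to time both versions with timeit than to compare cProfile totals: cProfile adds overhead to every Python-level function call, which inflates the pure-Python figure (over 7 million calls here), while the compiled version makes only a few Python-level calls. The helper below is hypothetical, not part of NumPy or Numba; it calls the function once first so that any lazy JIT compilation is not counted.

```python
import timeit

def best_time(func, *args, repeat=5, number=1):
    '''Hypothetical helper: best time per call of func(*args) over several runs.'''
    func(*args)  # warm-up call (triggers lazy JIT compilation, if any)
    timings = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    return min(timings) / number

if __name__ == "__main__":
    # Stand-in workload; to compare the two versions, pass the pure-Python
    # save_list and the Numba save_list (with arr1, arr2) instead.
    t = best_time(sum, range(1_000_000))
    print(f"{t:.6f} s per call")
```

Taking the minimum over several repeats reduces the influence of other processes on the measurement.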

Note that the return in save_list is not very useful, since arr2 is passed as a parameter and modified in place (it was not more useful with globals either). Note also that sol is assigned, not accumulated, in transform2, which means only the last iteration matters (the compiler can optimize that). This is very suspicious though: it looks like a bug, and you probably want to perform a reduction instead (e.g. +=), especially given the initial assignment to 0. Please check that the results are correct.
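If a reduction is what was intended, a minimal pure-Python sketch of transform2 with += looks like this; the same two-line change applies to the @nb.njit version. Summing is an assumption here, since only the asker knows the intended semantics. transform is copied from the question with the + signs restored, and list1 is passed as a parameter instead of being read from a global; the hypothetical name transform2_sum just distinguishes it from the original.

```python
import math

def transform(x, y, z):
    '''Same no-sense math as in the question.'''
    a = b = 0.0
    if x > y:
        a = math.sqrt(x ** 2) + math.atan(y ** 2 + 1)
    else:
        b = math.sqrt(z ** 2) + math.atan(y ** 2 + 1)
    if x > z:
        return math.sqrt(a * b)
    return math.sqrt(b ** 2)

def transform2_sum(list1, a, b, c):
    '''Accumulate a contribution from every element of list1 (assumed intent).'''
    sol = 0.0
    for i in range(len(list1)):
        for j in range(len(list1[i])):
            for k in range(len(list1[i][j])):
                temp = transform(i, j, k)
                v = list1[i][j][k]
                if v > temp:
                    sol += temp * v * (a + 1) ** 2   # was: sol = ...
                else:
                    sol += temp + v ** (b * c + 1)   # was: sol = ...
    return sol
```

Because it only uses len and chained indexing, this sketch works on both the nested lists and the NumPy arrays, which makes it easy to compare against the compiled version.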
