Python numba returned data types when calculating MSE

Time:08-16

I am using numba to calculate MSE. The inputs are images which are read as numpy arrays of uint8, so each element is in the range 0-255. When calculating the squared difference between two images, the plain Python function returns (as expected) a uint8 result, but the same function compiled with numba returns int64.

import numba
import numpy as np

@numba.jit(nopython=True)
def test1(var_a: np.ndarray, var_b: np.ndarray) -> np.ndarray:
    return var_a - var_b

@numba.jit(nopython=True)
def test2(var_a: np.ndarray, var_b: np.ndarray) -> np.ndarray:
    return (var_a - var_b) ** 2

# Same expression as test2, but plain NumPy (no JIT)
def test3(var_a: np.ndarray, var_b: np.ndarray) -> np.ndarray:
    return (var_a - var_b) ** 2

a = np.array([2, 2]).astype(np.uint8).reshape(2, 1)
b = np.array([255, 255]).astype(np.uint8).reshape(2, 1)

test1(a, b)  # output: array([[3], [3]], dtype=uint8)
test2(a, b)  # output: array([[64009], [64009]], dtype=int64)
test3(a, b)  # output: array([[9], [9]], dtype=uint8)

What's unclear to me is why the Python-only code preserves the data type while the numba code changes the returned type to int64. For my purpose, the numba result is ideal, but I don't understand why. I'm trying to avoid having to .astype(int) all of my images, since this would eat a lot of RAM, when all I need is for the result of the subtraction to be a signed int (i.e., not unsigned).

So, why does numba "fix" the datatype in test2()?

Answer:

Numba is a JIT compiler: it first uses static type inference to deduce the type of every variable, and then compiles the function before it can be called. This means all literals, such as integers, are typed before anything runs. Numba types integer literals as int64 to avoid overflows on 64-bit machines (and as int32 on 32-bit machines). On its own, as in test1, var_a - var_b is evaluated as an array of uint8, as expected. (var_a - var_b) ** 2, however, is like var_tmp ** np.int64(2), where var_tmp is a uint8 array. In this case, the Numba type inference system needs to do a type promotion, like in any statically typed language (e.g. C/C++). Like most languages, Numba chooses a relatively safe type promotion: it casts the array to int64, because int64 includes all the possible values of uint8. Note that the reported value 64009 is (2 - 255)², not 3² = 9, so in test2 the subtraction itself is also carried out in the promoted signed type and nothing wraps around. In practice, type promotion can be quite unsafe in pathological cases: for example, when you mix uint64 values with int64 ones, the result can be a float64, which has a large range but more limited precision, and no warning is raised. If you use (var_a - var_b) ** np.uint8(2), then the output type is the one you expect (i.e. uint8) because there is no type promotion.
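The promotion rules above can be checked with NumPy alone. The sketch below does not run Numba; it only reproduces the arithmetic, assuming NumPy is installed. np.result_type applies the type-level promotion rules (uint8 with int64, and uint64 with int64), staying in uint8 wraps 2 - 255 around to 3, and casting to int64 before subtracting reproduces the 64009 reported in the question.

```python
import numpy as np

a = np.array([2, 2], dtype=np.uint8).reshape(2, 1)
b = np.array([255, 255], dtype=np.uint8).reshape(2, 1)

# Type-level promotion (dtypes only, no values involved):
print(np.result_type(np.uint8, np.int64))   # int64: holds every uint8 value
print(np.result_type(np.uint64, np.int64))  # float64: no integer type holds both

# Staying in uint8: 2 - 255 wraps around modulo 256 to 3
diff_u8 = a - b
# Squaring with a uint8 exponent keeps everything in uint8 (no promotion)
sq_u8 = diff_u8 ** np.uint8(2)

# Promoting to int64 *before* subtracting keeps the sign: 2 - 255 == -253
diff_i64 = a.astype(np.int64) - b.astype(np.int64)

print(diff_u8.ravel(), diff_u8.dtype)           # [3 3] uint8
print(sq_u8.ravel(), sq_u8.dtype)               # [9 9] uint8
print((diff_i64 ** 2).ravel(), diff_i64.dtype)  # [64009 64009] int64
```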

NumPy uses dynamic type inference. Moreover, integers have a variable length in Python, so their type has to be defined by NumPy at runtime (not by CPython, which only defines the generic variable-sized int type). NumPy can thus adapt the type of integer literals based on their runtime value. For example, (np.uint8(7) * 1_000_000).dtype is int32 on my machine, while (np.uint8(7) * 100_000_000_000).dtype is int64 (because the right-most integer literal is typed as int64, since it is too big for a 32-bit integer). This is the value-based casting of NumPy 1.x; NumPy 2.0 adopted NEP 50, under which a Python int takes the dtype of the NumPy operand instead, so both of these expressions raise an OverflowError there. This is something Numba cannot do because of JIT compilation [1]. Thus, the semantics differ a bit between Numba and NumPy. The type promotion rules should be the same, though (so that Numba gives results as close as possible to NumPy's).
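To make the value-based idea concrete: NumPy exposes "the smallest type that holds this value" through np.min_scalar_type, which is the building block NumPy 1.x's value-based casting used for scalars. The sketch below is NumPy-only and does not model Numba; it simply shows that the chosen type depends on the value, whereas explicitly typing both operands (here as np.int64, an arbitrary choice for illustration) gives the same result type regardless of the value, which is closer to what a statically typed compiler has to do.

```python
import numpy as np

# The smallest dtype depends on the runtime value of the Python int
print(np.min_scalar_type(1_000_000))        # uint32: fits in 32 bits
print(np.min_scalar_type(100_000_000_000))  # uint64: too big for 32 bits

# With both operands explicitly typed, the result type is value-independent
small = np.int64(7) * np.int64(1_000_000)
big = np.int64(7) * np.int64(100_000_000_000)
print(small.dtype, big.dtype)  # int64 int64
```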

A good practice is to explicitly type arrays to avoid sneaky overflows in both NumPy and Numba. Casting integers to specific types is generally not needed, but it is also a good practice when the types are small and performance matters (e.g. integer arithmetic with intended overflow, as in hash computations).
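Applied to the question's use case, explicit typing can be done per operation instead of converting every image with .astype(int). The following is a minimal NumPy sketch, not the answerer's code: mse_uint8 is a hypothetical helper name, and choosing int32 for the difference is an assumption (it is wide enough, since |diff| <= 255 and 255² = 65025). The dtype argument of np.subtract makes it compute in int32, so the uint8 inputs are left untouched and no int64 copies are made.

```python
import numpy as np

def mse_uint8(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """Hypothetical helper: MSE of two uint8 images without int64 copies.

    Assumes int32 is wide enough: |diff| <= 255 and 255**2 == 65025.
    """
    # Compute the difference in int32 (uint8 -> int32 is a safe cast),
    # so 2 - 255 gives -253 instead of wrapping around to 3.
    diff = np.subtract(img_a, img_b, dtype=np.int32)
    # Square in place to avoid allocating a second temporary array.
    np.multiply(diff, diff, out=diff)
    # Accumulate the mean in float64 so the sum cannot overflow int32.
    return float(np.mean(diff, dtype=np.float64))

a = np.array([2, 2], dtype=np.uint8).reshape(2, 1)
b = np.array([255, 255], dtype=np.uint8).reshape(2, 1)
print(mse_uint8(a, b))  # 64009.0
```

The design choice here is to pay for one int32 temporary per call rather than keeping signed copies of every image in memory.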

Note that you can call your_function.inspect_types() to get additional information about the type inference (though the output is not easy to read).

[1] In fact, Numba could type integer literals based on their value, but not variables. The thing is, it would be very unexpected for users to get different output types (and different behaviour due to possible overflows) when they change literals into runtime variables.
