How to type hint a generic numpy array?

Time:12-01

Is there any way to type hint a NumPy array generically?

I'm currently working with NumPy 1.23.5 and Python 3.10, and I can't type hint the following example.

from typing import TypeVar

import numpy as np
import numpy.typing as npt


E = TypeVar("E") # Should be bounded to a numpy type

def double_arr(arr: npt.NDArray[E]) -> npt.NDArray[E]:
    return arr * 2

What I expect

arr = np.array([1, 2, 3], dtype=np.int8)
double_arr(arr) # npt.NDArray[np.int8]

arr = np.array([1, 2.3, 3], dtype=np.float32)
double_arr(arr) # npt.NDArray[np.float32]

But I end up with the following error

arr: npt.NDArray[E]
                ^^^
Could not specialize type "NDArray[ScalarType@NDArray]"
  Type "E@double_arr" cannot be assigned to type "generic"
    "object*" is incompatible with "generic"

If I bind E to specific numpy data types (np.int8, np.uint8, ...), the type checker fails to evaluate the multiplication because of the multiple data types.

CodePudding user response:

Looking at the source, the type variable that parameterizes numpy.dtype in numpy.typing.NDArray is bounded by numpy.generic (and declared covariant). Any type argument to NDArray must therefore be a subtype of numpy.generic, whereas your type variable is unbounded. This should work:

from typing import TypeVar

import numpy as np
from numpy.typing import NDArray


E = TypeVar("E", bound=np.generic, covariant=True)


def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2

But there is another problem, which I believe lies in incomplete numpy stubs. An example of it is showcased in this issue. The overloaded operator (magic) methods like __mul__ somehow mangle the types. I have only given the code a cursory look, so I don't know what is missing. But mypy will still complain about the last line in that code:

error: Returning Any from function declared to return "ndarray[Any, dtype[E]]"  [no-any-return]
error: Unsupported operand types for * ("ndarray[Any, dtype[E]]" and "int")  [operator]

The workaround for now is to call the functions instead of using the operators (which go through the dunder methods). In this case, using numpy.multiply instead of * solves the issue:

from typing import TypeVar

import numpy as np
from numpy.typing import NDArray


E = TypeVar("E", bound=np.generic, covariant=True)


def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return np.multiply(arr, 2)


a = np.array([1, 2, 3], dtype=np.int8)
reveal_type(double_arr(a))

No more mypy complaints and the type is revealed as follows:

numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[numpy._typing._8Bit]]]
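The reveal_type output above is purely static. As a complement, here is a minimal runtime sketch, reusing the same double_arr definition, that checks the dtype really is preserved for the two arrays from the question. The variable names ints and floats are only illustrative.

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

E = TypeVar("E", bound=np.generic, covariant=True)


def double_arr(arr: NDArray[E]) -> NDArray[E]:
    # np.multiply with a Python int keeps the array's dtype
    return np.multiply(arr, 2)


# Illustrative inputs taken from the question
ints = double_arr(np.array([1, 2, 3], dtype=np.int8))
floats = double_arr(np.array([1, 2.3, 3], dtype=np.float32))

print(ints.dtype, floats.dtype)
```

If the annotations are honest, the printed dtypes should match the input dtypes (int8 and float32).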

It's worth keeping an eye on that operator issue and maybe even reporting the specific Unsupported operand types for * error separately. I haven't found it in the issue tracker yet.


PS: Alternatively, you could keep the * operator and add a specific type: ignore comment. That way you'll notice if and when numpy fixes the annotations, because mypy reports unused ignore comments in strict mode (--warn-unused-ignores).

def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2  # type: ignore[operator,no-any-return]
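Either variant should behave the same at runtime, since the difference is only in how the type checker sees the stubs. A small sketch to confirm that, with via_operator and via_function as illustrative names:

```python
import numpy as np

a = np.array([1, 2, 3], dtype=np.int8)

# Operator form (the one that needs the type: ignore comment)
via_operator = a * 2
# Function form (the one the type checker accepts)
via_function = np.multiply(a, 2)

print(np.array_equal(via_operator, via_function))
print(via_operator.dtype == via_function.dtype)
```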