Is there any way to type a NumPy array as generic?
I'm currently working with NumPy 1.23.5 and Python 3.10, and I can't type-hint the following example.
from typing import TypeVar
import numpy as np
import numpy.typing as npt
E = TypeVar("E")  # Should be bounded to a numpy type
def double_arr(arr: npt.NDArray[E]) -> npt.NDArray[E]:
    return arr * 2
What I expect:
arr = np.array([1, 2, 3], dtype=np.int8)
double_arr(arr)  # npt.NDArray[np.int8]
arr = np.array([1, 2.3, 3], dtype=np.float32)
double_arr(arr)  # npt.NDArray[np.float32]
But I end up with the following error:
arr: npt.NDArray[E]
^^^
Could not specialize type "NDArray[ScalarType@NDArray]"
Type "E@double_arr" cannot be assigned to type "generic"
"object*" is incompatible with "generic"
If I bind E to specific NumPy types (np.int8, np.uint8, ...), the type checker fails to evaluate the multiplication because of the multiple data types.
Answer:
Looking at the source, the generic type variable used to parameterize numpy.dtype in numpy.typing.NDArray is bounded by numpy.generic (and declared covariant). Any type argument to NDArray must therefore be a subtype of numpy.generic, whereas your type variable is unbounded (implicitly bounded by object, which is why the error mentions "object*"). This should work:
from typing import TypeVar
import numpy as np
from numpy.typing import NDArray
E = TypeVar("E", bound=np.generic, covariant=True)
def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2
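To make the bound concrete: the scalar types from the question all derive from np.generic at runtime, while plain Python types (and object, the implicit bound of an unbounded TypeVar) do not. A minimal sketch, assuming NumPy is installed; the variable names are illustrative only:

```python
import numpy as np

# Scalar types from the question: all are subclasses of np.generic,
# which is what `bound=np.generic` demands of the type argument.
numpy_scalars = [np.int8, np.uint8, np.float32, np.float64]

# Builtins, including `object` (the implicit bound of an unbounded
# TypeVar), are not np.generic subclasses, so they fail the bound.
python_types = [int, float, object]

print([issubclass(t, np.generic) for t in numpy_scalars])  # [True, True, True, True]
print([issubclass(t, np.generic) for t in python_types])   # [False, False, False]
```

This only mirrors the static rule at runtime; the type checker enforces the bound without executing anything.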
But there is another problem, which I believe lies in incomplete NumPy stubs. An example of it is showcased in this issue. The overloaded operator (dunder) methods such as __mul__ somehow lose the type information. I have only looked at the stubs briefly, so I don't know exactly what is missing, but mypy still complains about the return line of that function:
error: Returning Any from function declared to return "ndarray[Any, dtype[E]]"  [no-any-return]
error: Unsupported operand types for * ("ndarray[Any, dtype[E]]" and "int")  [operator]
The workaround for now is to call the NumPy functions instead of using the operators (which dispatch to the dunder methods). In this case, using numpy.multiply instead of * solves the issue:
from typing import TypeVar
import numpy as np
from numpy.typing import NDArray
E = TypeVar("E", bound=np.generic, covariant=True)
def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return np.multiply(arr, 2)
a = np.array([1, 2, 3], dtype=np.int8)
reveal_type(double_arr(a))  # for mypy; at runtime this needs `from typing import reveal_type` (Python 3.11+)
No more mypy complaints, and the type is revealed as follows:
numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[numpy._typing._8Bit]]]
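The revealed type is what mypy infers statically. As a complementary runtime check that the np.multiply version also preserves the dtypes the question expects (int8 in, int8 out; float32 in, float32 out), here is a small sketch, assuming NumPy is installed. It uses print instead of reveal_type so it runs on Python 3.10:

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

E = TypeVar("E", bound=np.generic, covariant=True)

def double_arr(arr: NDArray[E]) -> NDArray[E]:
    # np.multiply instead of `*`, as in the workaround
    return np.multiply(arr, 2)

a = double_arr(np.array([1, 2, 3], dtype=np.int8))
b = double_arr(np.array([1, 2.3, 3], dtype=np.float32))
print(a.dtype, a.tolist())  # int8 [2, 4, 6]
print(b.dtype)              # float32
```

Multiplying by the Python int 2 does not upcast here, so the runtime dtype matches the statically inferred one.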
It's worth keeping an eye on that operator issue, and perhaps reporting the specific "Unsupported operand types for *" error separately; I haven't found it in the issue tracker yet.
PS: Alternatively, you could keep the * operator and add a specific type: ignore comment. That way you will notice if and when NumPy fixes the annotations, because mypy in strict mode (which enables --warn-unused-ignores) reports ignore directives that are no longer needed.
def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2  # type: ignore[operator,no-any-return]
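Both variants behave identically at runtime; the choice only affects what the type checker sees. A minimal sketch comparing them, assuming NumPy is installed; the function names double_arr_op and double_arr_func are hypothetical, chosen for illustration:

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

E = TypeVar("E", bound=np.generic, covariant=True)

def double_arr_op(arr: NDArray[E]) -> NDArray[E]:
    # Operator version: the ignore silences the stub-related mypy errors
    return arr * 2  # type: ignore[operator,no-any-return]

def double_arr_func(arr: NDArray[E]) -> NDArray[E]:
    # Function version: type-checks without an ignore directive
    return np.multiply(arr, 2)

samples = [
    np.array([1, 2, 3], dtype=np.int8),
    np.array([1, 2.3, 3], dtype=np.float32),
]

for arr in samples:
    x, y = double_arr_op(arr), double_arr_func(arr)
    # Same dtype and same values: only static checking differs
    print(x.dtype == y.dtype, np.array_equal(x, y))  # True True
```

If a later NumPy release fixes the operator stubs, strict mypy would flag the ignore in double_arr_op as unused, which is the signal the PS describes.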