Type hint npt.NDArray with a number of axes

Time:07-05

Given that I know the number of axes, can I specify it in the type hint npt.NDArray (from import numpy.typing as npt)?

For example, if I know it is a 3D array, how can I write something like npt.NDArray[3, np.float64]?

CodePudding user response:

On Python 3.9 and 3.10 the following does the job for me:

from typing import Literal, Tuple
import numpy as np
data = [[1, 2, 3], [4, 5, 6]]
arr: np.ndarray[Tuple[Literal[2], Literal[3]], np.dtype[np.int_]] = np.array(data)

It is a bit cumbersome, but you can follow NumPy issue #16544 for future development of an easier way to specify shapes.
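Until then, a type alias can cut down on the repetition. This is only a sketch: the alias name Int2x3 and the function total are made up for illustration, and it assumes NumPy 1.22 or newer on Python 3.9+, where np.ndarray can be subscripted at runtime.

```python
from typing import Literal, Tuple

import numpy as np

# Hypothetical alias name (not part of NumPy). Assumes NumPy >= 1.22 on
# Python >= 3.9, where np.ndarray[...] is valid at runtime.
Int2x3 = np.ndarray[Tuple[Literal[2], Literal[3]], np.dtype[np.int_]]

data = [[1, 2, 3], [4, 5, 6]]
arr: Int2x3 = np.array(data)


def total(a: Int2x3) -> int:
    # The alias can be reused in signatures instead of the full type.
    return int(a.sum())
```

The alias only shortens the spelling; it does not change what a type checker or NumPy does with the annotation.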

For now, you must declare the full shape; you cannot declare only the rank of the array.
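If the rank is all you care about, one fallback is to keep the plain npt.NDArray hint and check the number of axes when the code runs, since NumPy does not enforce annotations at runtime. A minimal sketch, assuming a hypothetical helper named require_rank (not a NumPy function); it relies only on ndarray.ndim:

```python
import numpy as np
import numpy.typing as npt


def require_rank(arr: npt.NDArray[np.float64], rank: int) -> npt.NDArray[np.float64]:
    """Return arr unchanged if it has exactly `rank` axes, else raise."""
    # Hypothetical helper, not part of NumPy. ndarray.ndim is the number of axes.
    if arr.ndim != rank:
        raise ValueError(f"expected {rank} axes, got {arr.ndim}")
    return arr


# A 3D array passes the check and is returned as-is.
volume = require_rank(np.zeros((2, 3, 4), dtype=np.float64), 3)
```

This is a runtime check rather than a static one, so a type checker will not catch a wrong rank ahead of time.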

In the future something like ndarray[Shape[:, :, :], dtype] should be available.
