Numba Rotation Matrix with Multidimensional Arrays

Time:01-07

I'm trying to use Numba to accelerate some functions, in particular a function that performs a 3D rotation given three angles, as shown below:

import numpy as np
from numba import jit

@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
    # Convert to radians. To ensure counter-clockwise (ccw) rotations, take
    # negative of angles.
    theta_x_rad = -np.radians(theta_x)
    theta_y_rad = -np.radians(theta_y)
    theta_z_rad = -np.radians(theta_z)
    # Define rotation matrices (yaw, pitch, roll)
    Rx = np.array([[1, 0,0],
                    [0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
                    [0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
                    ])


    Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
                    [ 0,1,0],
                    [-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
                    ])


    Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
                    [np.sin(theta_z_rad),np.cos(theta_z_rad),0],
                    [0,0,1]
                    ])


    # Compute total rotation matrix
    R  = np.dot(Rz, np.dot( Ry, Rx ))
    #
    return R

The function is relatively simple, but when Numba compiles it, it throws an error at the line that defines Rx. It seems that Numba has a problem with multidimensional arrays (?). I'm not sure how to modify this so that Numba can compile it. Any help would be appreciated.

Answer:

The problem comes from mixing integer and floating-point values. Numba tries to infer a single type for the array: [1, 0, 0] is a list of integers, while the other two rows contain the results of np.cos and np.sin and are therefore lists of floats. The outer list thus mixes lists of different types, which Numba's type inference cannot unify, so it raises a typing error. Ry and Rz have the same problem ([0, 1, 0] and [0, 0, 1] are integer rows). Writing 1.0 and 0.0 instead of 1 and 0 fixes the issue. More generally, specifying the dtype of arrays is good practice, especially in Numba, since it relies on type inference.
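Here is a sketch of the corrected function with every matrix entry written as a float and an explicit dtype=np.float64 on each array; apart from that it keeps the question's sign convention and multiplication order. The try/except around the import is an assumption added only so the snippet still runs (as plain NumPy) where Numba is not installed. Inside Numba, np.dot on 2D arrays goes through BLAS bindings that Numba takes from SciPy, so SciPy may also need to be installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to plain NumPy so the sketch still runs.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def rotation_matrix(theta_x, theta_y, theta_z):
    # Negate the angles (in radians) to get counter-clockwise rotations,
    # following the convention of the original function.
    tx = -np.radians(theta_x)
    ty = -np.radians(theta_y)
    tz = -np.radians(theta_z)

    # Every entry is a float (1.0 / 0.0, not 1 / 0), so all inner lists have
    # the same element type and Numba can infer one 2D float64 array.
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(tx), -np.sin(tx)],
                   [0.0, np.sin(tx), np.cos(tx)]], dtype=np.float64)
    Ry = np.array([[np.cos(ty), 0.0, np.sin(ty)],
                   [0.0, 1.0, 0.0],
                   [-np.sin(ty), 0.0, np.cos(ty)]], dtype=np.float64)
    Rz = np.array([[np.cos(tz), -np.sin(tz), 0.0],
                   [np.sin(tz), np.cos(tz), 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float64)

    # Same composition order as the question: Rx first, then Ry, then Rz.
    return np.dot(Rz, np.dot(Ry, Rx))
```

With theta_x = theta_y = 0, Rx and Ry are identity matrices, so rotation_matrix(0.0, 0.0, 90.0) should return, up to rounding, [[0, 1, 0], [-1, 0, 0], [0, 0, 1]].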

If you want typing errors to be reported when the function is defined rather than the first time it is called, you can specify the parameter types: Numba then compiles the function eagerly, as soon as the decorator runs. Note also that you can use njit as a shorter equivalent of jit(nopython=True). The resulting decorator would be @njit('(float64, float64, float64)').
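A minimal sketch of eager compilation, using a hypothetical single-angle helper `rotation_z` (not part of the question) to keep the example short. As in the answer, the signature string lists only the argument types and lets Numba infer the return type. The try/except fallback is an assumption added only so the snippet still runs where Numba is not installed; in that case nothing is compiled.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, the decorator does nothing and the code runs as NumPy.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


# With a signature string, Numba compiles when the decorator runs instead of
# on the first call, so a typing error would surface at definition time.
@njit('(float64,)')
def rotation_z(theta_z):
    # Hypothetical helper: rotation about z with the question's sign convention.
    t = -np.radians(theta_z)
    c = np.cos(t)
    s = np.sin(t)
    return np.array([[c, -s, 0.0],
                     [s, c, 0.0],
                     [0.0, 0.0, 1.0]], dtype=np.float64)
```

Since the only compiled signature takes a float64, call it with floats, e.g. rotation_z(90.0).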
