Numba is "only" improving my code by a factor of 4. Can it do better?


I am working on the N-body problem (i.e.: given the positions of N bodies in space, I need to compute their mutual interaction).

For N=10,000 particles, my non-jitted function takes about 84 seconds while my jitted function takes about 22 seconds. However, based on some articles and videos, Numba was advertised to me as a tool that can improve my code by 1 to 2 orders of magnitude!

Therefore, I would like to share the code with you and ask if there is any room for improvement:

import numpy as np
from numba import jit, njit
import time
import timeit

def compute_acc(pos, mass, G, softening):
    """ Computes the acceleration of N bodies

    Args:
        pos (type=np.array, size= Nx3): x, y, z positions of the N particles
        mass (type=np.array, size= Nx1): mass of the particles
        G (float): Newton's Gravitational constant
        softening (float): softening parameter

    Returns:
        acc (type=np.array, size= Nx3): ax, ay, az accelerations of the N particles
    """

    # positions r = [x,y,z] for all particles
    x = pos[:,0:1]
    y = pos[:,1:2]
    z = pos[:,2:3]

    # matrix that stores all pairwise particle separations: r_j - r_i
    dx = x.T - x
    dy = y.T - y
    dz = z.T - z

    # matrix that stores 1/r^3 for all pairwise particle separations
    inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2)**(-1.5)

    ax = G * (dx * inv_r3) @ mass
    ay = G * (dy * inv_r3) @ mass
    az = G * (dz * inv_r3) @ mass

    # pack together the acceleration components
    acc = np.hstack((ax,ay,az))

    return acc


#Define the jitted version of compute_acc
compute_acc_jit = njit(cache=True, fastmath=True)(compute_acc)


#Initialize the parameters to test the functions
np.random.seed(123) 
N=10000
pos=np.random.uniform(low=-10, high=10, size=(N,3)) # Random uniform positions
mass=np.random.uniform(low=1, high=20, size=(N,1)) # Random uniform masses
G=1.0
softening=0.1

# Compute Non-Jitted time:
T1= min(timeit.repeat(stmt='compute_acc(pos, mass, G, softening)',\
                  timer=time.perf_counter,repeat=3, number=1,globals=globals()) )

print("Non-JIT time=",T1,"\n")

# Compute Jitted time:
T2= min(timeit.repeat(stmt='compute_acc_jit(pos, mass, G, softening)',\
                  timer=time.perf_counter,repeat=3, number=1,globals=globals()) )

print("JIT time=",T2,"\n")

PS: I suspect that the only way to reach an improvement of 1 to 2 orders of magnitude would be to use multi-threading or a GPU. Is that right?

Thank you!


Reference: My code is taken from Dr. Philip Mocz's GitHub repository: https://github.com/pmocz/nbody-python

CodePudding user response:

Analysis and optimizations

The main problem is that the code creates several large temporary arrays stored in RAM, and RAM is slow. Moreover, Numba is not able to fuse operations and optimize out temporary arrays, so using Numba does not help much with the provided code. For Numba to be fast, you need to use plain loops and compute values on the fly, so as to avoid reading/writing memory as much as possible.

Additionally, you can optimize/factorize some operations, like the matrix multiplication and the multiplication by G, so as to write a faster implementation. For example, G * ((dx * inv_r3) @ mass) is faster than G * (dx * inv_r3) @ mass, since the array multiplied by G in the first case is far smaller (Nx1) than in the second (NxN).
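Concretely, the difference is just where the parentheses place the scaling by G (* and @ have the same precedence and associate left-to-right):

ax = G * (dx * inv_r3) @ mass    # G scales the full NxN matrix before the matmul
ax = G * ((dx * inv_r3) @ mass)  # G scales only the Nx1 result of the matmul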

Moreover, x, y and z can be copied so as to create contiguous arrays in memory instead of strided views. Operations on contiguous arrays are generally faster since items can be processed in a SIMD fashion (GPUs love this too) and fetched more efficiently from RAM.
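For instance (a small illustrative check, not part of the answer's code):

x_view = pos[:, 0]         # strided view: 24-byte step between consecutive items
x_copy = pos[:, 0].copy()  # contiguous copy: 8-byte unit stride
print(x_view.flags['C_CONTIGUOUS'], x_copy.flags['C_CONTIGUOUS'])  # False True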

Finally, you can use multiple Numba threads with parallel=True and prange. One can also specify the signature of the function so as to avoid compilation during the first run (which would otherwise bias the benchmark a bit).

Here is the resulting code:

import numba as nb  # np and njit are already imported in the question's code

@njit('(float64[:,:], float64[:,:], float64, float64)', cache=True, fastmath=True, parallel=True)
def compute_acc_fast(pos, mass, G, softening):
    n = pos.shape[0]

    # Copy the strided views into contiguous arrays so the loops below run faster
    x = pos[:,0].copy()
    y = pos[:,1].copy()
    z = pos[:,2].copy()

    # Ensure mass is a contiguous 1D array (cheap operation)
    assert mass.shape[1] == 1
    contig_mass = mass[:,0].copy()

    acc = np.empty((n, 3), pos.dtype)

    for i in nb.prange(n):
        ax, ay, az = 0.0, 0.0, 0.0

        for j in range(n):
            dx = x[j] - x[i]
            dy = y[j] - y[i]
            dz = z[j] - z[i]
            tmp = dx**2 + dy**2 + dz**2 + softening**2
            factor = contig_mass[j] / (tmp * np.sqrt(tmp))
            ax += dx * factor
            ay += dy * factor
            az += dz * factor

        acc[i, 0] = G * ax
        acc[i, 1] = G * ay
        acc[i, 2] = G * az

    return acc

Results and discussion

Here are results on my i5-9600KF (6-core) processor:

Non-JIT time:           5.906 s   (  x1.0)
Initial JIT time:       1.665 s   (  x3.5)
Optimized JIT time:     0.049 s   (x120.6)
Optimal lower-bound:    0.024 s   (x246.1)

The computation is 121 times faster with the optimized implementation. It is still sub-optimal, since Numba failed to generate code using SIMD instructions (certainly due to the += accumulations at the end of the loop). One can write a loop iterating over chunks of 8 items so as to help the compiler do that. Once vectorized, the code should be very close to the optimal time (~0.025 s). On my machine, the current code is bound by the computation of the square root and the division.
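Here is a minimal sketch of that chunking idea (hypothetical and unbenchmarked; it reuses the imports above, and whether LLVM actually vectorizes the inner block still depends on the target machine):

@njit('(float64[:,:], float64[:,:], float64, float64)', cache=True, fastmath=True, parallel=True)
def compute_acc_chunked(pos, mass, G, softening):
    n = pos.shape[0]
    x = pos[:, 0].copy()
    y = pos[:, 1].copy()
    z = pos[:, 2].copy()
    contig_mass = mass[:, 0].copy()
    acc = np.empty((n, 3), pos.dtype)
    chunk = 8
    nc = (n // chunk) * chunk  # largest multiple of 8 not exceeding n

    for i in nb.prange(n):
        ax, ay, az = 0.0, 0.0, 0.0

        # Fixed-size inner blocks: the trip count of the k-loop is a constant,
        # which makes it easier for the compiler to emit SIMD instructions.
        for j in range(0, nc, chunk):
            for k in range(chunk):
                dx = x[j+k] - x[i]
                dy = y[j+k] - y[i]
                dz = z[j+k] - z[i]
                tmp = dx**2 + dy**2 + dz**2 + softening**2
                factor = contig_mass[j+k] / (tmp * np.sqrt(tmp))
                ax += dx * factor
                ay += dy * factor
                az += dz * factor

        # Scalar remainder loop for the last n % 8 items
        for j in range(nc, n):
            dx = x[j] - x[i]
            dy = y[j] - y[i]
            dz = z[j] - z[i]
            tmp = dx**2 + dy**2 + dz**2 + softening**2
            factor = contig_mass[j] / (tmp * np.sqrt(tmp))
            ax += dx * factor
            ay += dy * factor
            az += dz * factor

        acc[i, 0] = G * ax
        acc[i, 1] = G * ay
        acc[i, 2] = G * az

    return acc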

If this is not enough, you can try single-precision floating-point numbers to get code about 2 times faster, though it will be less precise. If this is still not fast enough, you can write C code using the _mm256_rsqrt_ps x86-64 intrinsic, which computes an approximation of 1/sqrt(x) (with a relative precision of 3.7e-4). An alternative solution is to use GPUs, but note that a server-grade GPU is generally required to compute double-precision numbers efficiently.
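For example, a single-precision variant could be obtained by recompiling the same kernel with a float32 signature (a hypothetical sketch; note that the 0.0 accumulator literals inside the kernel should also become np.float32(0.0), otherwise Numba promotes the sums back to float64):

compute_acc_f32 = njit(
    '(float32[:,:], float32[:,:], float32, float32)',
    cache=True, fastmath=True, parallel=True,
)(compute_acc_fast.py_func)  # .py_func is the original, undecorated Python function

# Convert the inputs once; float32 halves memory traffic and doubles SIMD lanes
pos32 = pos.astype(np.float32)
mass32 = mass.astype(np.float32)
acc32 = compute_acc_f32(pos32, mass32, np.float32(G), np.float32(softening))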
