How to optimize Mean Square Displacement for several particles in two dimensions in python?

Time:10-30

I want to calculate the mean square displacement for several particles, defined as:

$$\mathrm{MSD}_i(\Delta t) = \left\langle \left| \vec{x}_i(t+\Delta t) - \vec{x}_i(t) \right|^2 \right\rangle_t$$

where $i$ is the particle index, $\Delta t$ (Dt in the code) is the time interval, $t$ is the time, and $\vec{x}_i$ is the position of particle $i$ in two dimensions. The average is taken over all possible times $t$.
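As a concrete check of this definition, here is a minimal sketch for one particle and a single lag. The function name `msd_at_lag` is hypothetical (it is not from the question). It uses ballistic motion, x(t) = (v t, 0), for which every displacement over Δt equals vΔt, so the MSD must be exactly (vΔt)².

```python
import numpy as np

def msd_at_lag(traj, dt):
    """MSD of a single particle at one lag dt; traj has shape (time, 2)."""
    disp = traj[dt:] - traj[:-dt]               # x(t+dt) - x(t) for every valid t
    return np.mean(np.sum(disp ** 2, axis=-1))  # average of |displacement|^2 over t

# Ballistic motion with v = 0.5: each displacement over dt = 4 is 2.0,
# so MSD(4) = 2.0**2 = 4.0 exactly.
t = np.arange(100.0)
traj = np.stack([0.5 * t, np.zeros_like(t)], axis=-1)
print(msd_at_lag(traj, 4))  # 4.0
```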

I have managed to implement it with numpy. Note that pos is an np.array with three axes: (particles, time, coordinate).

import numpy as np
import matplotlib.pyplot as plt
import time

#Initialize data
np.random.seed(1)
nTime = 10**4
nParticles = 3
pos = np.zeros((nParticles, nTime, 2)) #Axis: particles, times, coordinates
for t in range(1, nTime):
    pos[:, t, :] = pos[:, t-1, :] + (np.random.random((nParticles, 2)) - 0.5)  # random-walk step in [-0.5, 0.5)

#MSD calculation
def MSD_direct(pos):
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for k, Dt in enumerate(Dt_r):
        SD = np.sum((pos[:, Dt:,:] - pos[:, 0:-Dt,:])**2, axis = -1)
        MSD[:,k] = np.mean( SD , axis = 1)
        dMSD[:,k] = np.std( SD, axis = 1 ) / np.sqrt(SD.shape[1])

    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_d, dMSD_d = MSD_direct(pos)
print("MSD_direct -- Time: %s s\n" % (time.time() - start_time))

#Plots
plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,0])    
plt.xlabel('t')
plt.ylabel('x')
plt.savefig('pos_x.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.plot(pos[i,:,1])    
plt.xlabel('t')
plt.ylabel('y')
plt.savefig('pos_y.png', dpi = 300)

plt.figure()
for i in range(nParticles):
    plt.fill_between(Dt_r, MSD_d[i,:] + dMSD_d[i,:], MSD_d[i,:] - dMSD_d[i,:], alpha = 0.5)  # +/- one standard error
    plt.plot(Dt_r, MSD_d[i,:])
plt.xlabel('Dt')
plt.ylabel('MSD')
plt.savefig('MSD.png', dpi = 300)

Output:

MSD_direct -- Time: 7.793087720870972 s

[Figures: x(t) and y(t) for each particle (pos_x.png, pos_y.png), and MSD versus Dt with ±dMSD bands (MSD.png)]

However, I would like to optimize this code further if possible. There is still a loop over Dt, and I do not know how to remove it and fully vectorize the program with numpy.
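One way to drop the Python loop over Dt is to build every pairwise displacement at once and then average along each lag with `np.bincount`. The sketch below is an illustration only (`msd_broadcast` is a hypothetical name). Memory grows as particles × nTime², so for the question's sizes (3 particles, nTime = 10⁴) the `diff` array alone would need about 3 × 10⁸ × 2 × 8 bytes ≈ 4.8 GB. It is therefore only practical for short trajectories, which is why loop-based or FFT-based approaches are usually preferred.

```python
import numpy as np

def msd_broadcast(pos):
    """MSD with no Python loop over Dt; pos has shape (particles, time, coords).

    Memory grows as particles * time**2, so this only suits short trajectories.
    """
    n_particles, n_time, _ = pos.shape
    # sq[i, t1, t2] = |x_i(t2) - x_i(t1)|^2 for every pair of times
    diff = pos[:, None, :, :] - pos[:, :, None, :]
    sq = np.sum(diff ** 2, axis=-1)
    # lag[t1, t2] = t2 - t1; keep only pairs with a positive lag
    lag = np.arange(n_time)[None, :] - np.arange(n_time)[:, None]
    mask = lag > 0
    lags = lag[mask]
    counts = np.bincount(lags, minlength=n_time)   # number of pairs per lag: n_time - Dt
    # Offset the lag index per particle so a single bincount sums every particle
    idx = lags[None, :] + n_time * np.arange(n_particles)[:, None]
    sums = np.bincount(idx.ravel(), weights=sq[:, mask].ravel(),
                       minlength=n_particles * n_time).reshape(n_particles, n_time)
    Dt_r = np.arange(1, n_time - 1)                # same lags as MSD_direct
    return Dt_r, sums[:, Dt_r] / counts[Dt_r]
```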


I also rewrote the calculation with numba, which gives roughly a factor-of-two speedup over the previous code. I wonder whether it can be improved further.

import numba as nb
@nb.jit(fastmath=True,parallel=True)
def MSD_numba(pos):
    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = np.empty((nParticles, len(Dt_r)))
    dMSD = np.empty((nParticles,len(Dt_r)))
    for i in nb.prange(nParticles):  
        for Dt in Dt_r:
            SD = (pos[i, Dt:, 0] - pos[i, 0:-Dt, 0])**2 + (pos[i, Dt:, 1] - pos[i, 0:-Dt, 1])**2  # |dx|^2 + |dy|^2
            MSD[i, Dt-1] = np.mean( SD )
            dMSD[i, Dt-1] = np.std( SD ) / np.sqrt(len(SD)) 
    return Dt_r, MSD, dMSD

start_time = time.time()
Dt_r, MSD_n, dMSD_n = MSD_numba(pos)
print("MSD_numba -- Time: %s s" % (time.time() - start_time))
print("MSD_numba -- All close to MSD_direct: %r\n" %(np.allclose(MSD_n, MSD_d) )  )

Output:

MSD_numba -- Time: 4.520232915878296 s
MSD_numba -- All close to MSD_direct: True

Note: this type of question has been asked in several posts already, but they use different definitions (Mean square displacement python, Mean squared displacement, Mean square displacement for n-dim matrix python), they do not have an answer (Mean square displacement in Python), they just use one particle (Computing mean square displacement using python and FFT, Mean square displacement of a 1d random walk in python), they use pandas (Vectorized calculation of Mean Square Displacement in Python, Speedup MSD calculation in Python), etc.

Answer:

By adapting the FFT-based answer from Computing mean square displacement using python and FFT, I made this calculation about two orders of magnitude faster:

def MSD_fft(pos):
    nTime=pos.shape[1]        

    # S2: autocorrelation of each coordinate via FFT (zero-padded to 2*nTime), summed over coordinates
    S2 = np.sum ( np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = -2))**2, axis = -2  )[:,:nTime,:].real , axis = -1 ) / (nTime-np.arange(nTime)[None,:] )

    # S1: built from running sums of D(t) = |x(t)|^2
    D=np.square(pos).sum(axis=-1)
    D=np.append(D, np.zeros((pos.shape[0], 1)), axis = -1)
    S1 = ( 2 * np.sum(D, axis = -1)[:,None] - np.cumsum( np.insert(D[:,0:-1], 0, 0, axis = -1) + np.flip(D, axis = -1), axis = -1 ) )[:,:-1] / (nTime - np.arange(nTime)[None,:] )

    MSD = S1-2*S2

    Dt_r = np.arange(1, pos.shape[1]-1)
    MSD = MSD[:,Dt_r]
    return Dt_r, MSD

start_time = time.time()
Dt_r, MSD_f = MSD_fft(pos)
print("MSD_fft -- Time: %s s" % (time.time() - start_time))
print("MSD_fft -- All close to MSD_direct: %r\n" %(np.allclose(MSD_f, MSD_d) )  )

Output:

MSD_direct -- Time: 2.1434285640716553 s

MSD_numba -- Time: 1.532573938369751 s
MSD_numba -- All close to MSD_direct: True

MSD_fft -- Time: 0.007384061813354492 s
MSD_fft -- All close to MSD_direct: True
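As far as I can tell, the FFT version rests on expanding the square. For one particle with N = nTime samples and lag m:

$$
\mathrm{MSD}(m) = \frac{1}{N-m}\sum_{t=0}^{N-1-m}\left|\vec{x}(t+m)-\vec{x}(t)\right|^2
= \underbrace{\frac{1}{N-m}\sum_{t=0}^{N-1-m}\left(\left|\vec{x}(t+m)\right|^2+\left|\vec{x}(t)\right|^2\right)}_{S_1(m)}
\;-\;2\,\underbrace{\frac{1}{N-m}\sum_{t=0}^{N-1-m}\vec{x}(t+m)\cdot\vec{x}(t)}_{S_2(m)}
$$

The sum in S2 is the autocorrelation of each coordinate, which `ifft(|fft(x)|**2)` yields once the signal is zero-padded to length 2N (otherwise the circular correlation would wrap around). S1 only involves D(t) = |x(t)|², so running sums of D are enough, which is what the `cumsum` expression in `MSD_fft` computes.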

However, I have not been able to compute the error (dMSD) with this method. Provided there are enough statistics, the error should remain small; indeed, it cannot be distinguished in the plots.
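A possible way to get the error from FFTs as well is sketched below. This is my extension, not part of the linked answer, and `msd_err_fft` / `_xcorr` are hypothetical names. The idea is that the standard error only needs the mean of SD² in addition to the mean of SD. Writing u = x(t+Δt) and v = x(t), the term |u − v|⁴ = (|u|² + |v|² − 2 u·v)² expands into head/tail sums and cross-correlations of products of coordinates, and each of those can be computed with a zero-padded FFT. Computing the variance as E[SD²] − E[SD]² can lose precision through cancellation. The sketch centers each trajectory to reduce this, but it is worth comparing against MSD_direct on a small input before trusting it.

```python
import numpy as np

def _xcorr(f, g, n):
    """r[:, m] = sum_t f[:, t + m] * g[:, t] for m = 0..n-1 (time on axis 1).

    Zero-padding to 2n removes the circular wrap-around of the FFT.
    """
    F = np.fft.rfft(f, n=2 * n, axis=1)
    G = np.fft.rfft(g, n=2 * n, axis=1)
    return np.fft.irfft(F * np.conj(G), n=2 * n, axis=1)[:, :n]

def msd_err_fft(pos):
    """MSD and its standard error for pos of shape (particles, time, coords)."""
    n_time, n_coord = pos.shape[1], pos.shape[2]
    # Displacements are translation invariant; centring each trajectory keeps
    # the fourth-order sums smaller and reduces floating-point cancellation.
    x = pos - pos.mean(axis=1, keepdims=True)
    m = np.arange(n_time)
    n_m = n_time - m                                   # samples per lag

    def head_plus_tail(f):
        # sum f[:, :N-m] + sum f[:, m:], via a cumulative sum with a leading 0
        c = np.concatenate([np.zeros((f.shape[0], 1)), np.cumsum(f, axis=1)], axis=1)
        return c[:, n_time - m] + (c[:, -1:] - c[:, m])

    D = np.sum(x ** 2, axis=-1)                        # |x(t)|^2
    xs = [x[..., c] for c in range(n_coord)]

    # First moment: |u - v|^2 = |u|^2 + |v|^2 - 2 u.v
    s1 = head_plus_tail(D) - 2 * sum(_xcorr(a, a, n_time) for a in xs)

    # Second moment: (|u|^2 + |v|^2 - 2 u.v)^2 = A^2 + B^2 + 2AB - 4AP - 4BP + 4P^2
    s2 = head_plus_tail(D ** 2) + 2 * _xcorr(D, D, n_time)
    for a in xs:
        s2 -= 4 * (_xcorr(D * a, a, n_time) + _xcorr(a, D * a, n_time))
        for b in xs:
            s2 += 4 * _xcorr(a * b, a * b, n_time)

    mean1, mean2 = s1 / n_m, s2 / n_m
    var = np.clip(mean2 - mean1 ** 2, 0.0, None)       # population variance (ddof=0, like np.std)
    err = np.sqrt(var / n_m)                           # std / sqrt(n), as in MSD_direct

    Dt_r = np.arange(1, n_time - 1)                    # same lags as MSD_direct
    return Dt_r, mean1[:, Dt_r], err[:, Dt_r]
```

The function returns `(Dt_r, MSD, dMSD)` in the same layout as `MSD_direct`, so it can be used in the plotting code unchanged.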


Generalized function for any n-dimensional array

I generalized the previous function so that pos can be an array with any number of dimensions; you only need to specify the time axis and the coordinate axis:

def MSD_fft_ax(pos, axis_time, axis_coord):
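    # Normalize negative axis indices so the axis bookkeeping below stays correct
    axis_time, axis_coord = axis_time % pos.ndim, axis_coord % pos.ndim
    nTime=pos.shape[axis_time]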

    S2 = np.sum (  np.fft.ifft( np.abs(np.fft.fft(pos, n=2*nTime, axis = axis_time))**2, axis = axis_time ).take(range(nTime), axis = axis_time).real, axis = axis_coord )
    
    D=np.square(pos).sum(axis=axis_coord)

    if axis_coord % pos.ndim < axis_time % pos.ndim: axis_time -= 1

    shape_t = [nTime if ax==axis_time % D.ndim else 1 for ax, s in enumerate(D.shape)]
    shape_non_t = [1 if ax==axis_time % D.ndim else s for ax, s in enumerate(D.shape)]

    D=np.append(D, np.zeros( shape_non_t ), axis = axis_time)
    S1 = ( 2 * np.sum(D, axis = axis_time).reshape(shape_non_t) - np.cumsum( np.insert(D.take(np.arange(0,nTime), axis=axis_time), 0, 0, axis = axis_time) + np.flip(D, axis = axis_time), axis = axis_time ) ).take(np.arange(0,nTime), axis = axis_time)

    MSD = ( S1-2*S2 ) / ( nTime-np.arange(nTime).reshape(shape_t) )

    Dt_r = np.arange(1, nTime-1)
    MSD = MSD.take(Dt_r, axis = axis_time)
    return Dt_r, MSD

start_time = time.time()
Dt_r, MSD_fax = MSD_fft_ax(pos, axis_time = 1, axis_coord=-1)
print("MSD_fft_ax -- Time: %s s" % (time.time() - start_time))
print("MSD_fft_ax -- All close to MSD_direct: %r\n" %(np.allclose(MSD_fax, MSD_d) )  )

Output:

MSD_fft_ax -- Time: 0.009054422378540039 s
MSD_fft_ax -- All close to MSD_direct: True
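An alternative way to support arbitrary layouts, not the approach used above, is to move the time and coordinate axes to the end with `np.moveaxis`, flatten the remaining axes into one batch axis, and reuse any (batch, time, coords) routine. In the sketch below, `msd_any_layout` and `msd_3d` are hypothetical names, and `msd_3d` is just a direct-loop stand-in that could be replaced by `MSD_fft`. Unlike `MSD_fft_ax`, this wrapper puts the lag axis last instead of where the time axis was.

```python
import numpy as np

def msd_3d(pos):
    """Direct MSD for pos of shape (batch, time, coords); lags 1..T-2."""
    n_time = pos.shape[1]
    lags = np.arange(1, n_time - 1)
    out = np.empty((pos.shape[0], lags.size))
    for k, dt in enumerate(lags):
        sd = np.sum((pos[:, dt:] - pos[:, :-dt]) ** 2, axis=-1)
        out[:, k] = sd.mean(axis=1)
    return lags, out

def msd_any_layout(pos, axis_time, axis_coord):
    """MSD for an array with arbitrary axis order (hypothetical wrapper)."""
    # Put time second-to-last and coordinates last; other axes keep their order.
    moved = np.moveaxis(pos, (axis_time, axis_coord), (-2, -1))
    batch_shape = moved.shape[:-2]
    flat = moved.reshape((-1,) + moved.shape[-2:])   # (batch, time, coords)
    lags, msd = msd_3d(flat)
    # Restore the batch axes; the lag axis ends up last.
    return lags, msd.reshape(batch_shape + (lags.size,))
```

For example, an array of shape (runs, time, particles, coords) would be passed with `axis_time=1, axis_coord=-1` and return an MSD of shape (runs, particles, nTime-2).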