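I want to calculate the mean square displacement (MSD) of several particles, defined as

$$\mathrm{MSD}_i(\Delta t) = \left\langle \left| \vec{x}_i(t+\Delta t) - \vec{x}_i(t) \right|^2 \right\rangle_t ,$$

where $i$ is the index of the particle, $\Delta t$ is the time interval, $t$ is the time, and $\vec{x}_i$ is the position of particle $i$ in two dimensions. The average $\langle\cdot\rangle_t$ is taken over all possible times $t$.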
I have managed to implement it with NumPy. Note that `pos` is an `np.array` with three axes: `(particles, time, coordinate)`.
import numpy as np
import matplotlib.pyplot as plt
import time
# Initialize data: nParticles independent 2-D random walks of nTime steps,
# all starting at the origin.
np.random.seed(1)
nTime = 10**4
nParticles = 3
pos = np.zeros((nParticles, nTime, 2))  # Axis: particles, times, coordinates
for t in range(1, nTime):
    # Each step is drawn uniformly from [-0.5, 0.5) in every coordinate and
    # added to the previous position.
    pos[:, t, :] = pos[:, t - 1, :] + (np.random.random((nParticles, 2)) - 0.5)
#MSD calculation
#MSD calculation
def MSD_direct(pos):
    """Mean square displacement of each particle by explicit averaging over t.

    Parameters
    ----------
    pos : ndarray, shape (nParticles, nTime, nCoord)
        Positions; axes are particles, times, coordinates.

    Returns
    -------
    Dt_r : ndarray
        Time lags ``1 .. nTime-2``.
    MSD : ndarray, shape (nParticles, nTime-2)
        Mean over start times t of |pos(t+Dt) - pos(t)|^2 for each lag.
    dMSD : ndarray, shape (nParticles, nTime-2)
        Standard deviation of the squared displacements divided by the square
        root of the number of samples.
        NOTE(review): samples at different t overlap and are correlated, so
        this may underestimate the true uncertainty.
    """
    # Take the particle count from the input rather than a global, so the
    # function works for any number of particles.
    nPart, nT = pos.shape[0], pos.shape[1]
    Dt_r = np.arange(1, nT - 1)
    MSD = np.empty((nPart, len(Dt_r)))
    dMSD = np.empty((nPart, len(Dt_r)))
    for k, Dt in enumerate(Dt_r):
        # Squared displacement over lag Dt for every start time, summed over coordinates.
        SD = np.sum((pos[:, Dt:, :] - pos[:, :-Dt, :]) ** 2, axis=-1)
        MSD[:, k] = np.mean(SD, axis=1)
        dMSD[:, k] = np.std(SD, axis=1) / np.sqrt(SD.shape[1])
    return Dt_r, MSD, dMSD
# Time the direct implementation. perf_counter is a monotonic, high-resolution
# clock meant for measuring intervals, unlike the wall-clock time.time().
start_time = time.perf_counter()
Dt_r, MSD_d, dMSD_d = MSD_direct(pos)
print("MSD_direct -- Time: %s s\n" % (time.perf_counter() - start_time))
#Plots
# Trajectories: one figure per coordinate, one curve per particle.
for coord, label in enumerate(('x', 'y')):
    plt.figure()
    for i in range(nParticles):
        plt.plot(pos[i, :, coord])
    plt.xlabel('t')
    plt.ylabel(label)
    plt.savefig('pos_%s.png' % label, dpi=300)
# MSD versus lag, with a shaded band of +/- dMSD_d around each particle's curve.
plt.figure()
for i in range(nParticles):
    plt.fill_between(Dt_r, MSD_d[i, :] + dMSD_d[i, :], MSD_d[i, :] - dMSD_d[i, :], alpha=0.5)
    plt.plot(Dt_r, MSD_d[i, :])
plt.xlabel('Dt')
plt.ylabel('MSD')
plt.savefig('MSD.png', dpi=300)
Output:
MSD_direct -- Time: 7.793087720870972 s
However, I would like to optimize this code if possible. There is still a loop over $\Delta t$, and I do not know how to remove it and vectorize the program fully with NumPy.
I also rewrote the calculation with Numba, which gives roughly a factor of two improvement over the previous code. I wonder whether it is possible to improve it further.
import numba as nb
@nb.jit(fastmath=True, parallel=True)
def MSD_numba(pos):
    """Numba version of the direct MSD calculation, parallel over particles.

    Parameters
    ----------
    pos : ndarray, shape (nParticles, nTime, 2)
        Positions; axes are particles, times, coordinates. Only the two
        coordinates at indices 0 and 1 are used.

    Returns
    -------
    Dt_r : ndarray
        Time lags ``1 .. nTime-2``.
    MSD : ndarray, shape (nParticles, nTime-2)
        Mean squared displacement for each particle and lag.
    dMSD : ndarray, shape (nParticles, nTime-2)
        Standard deviation of the squared displacements divided by the square
        root of the number of samples.
    """
    # Take the particle count from the input rather than a global.
    nPart = pos.shape[0]
    Dt_r = np.arange(1, pos.shape[1] - 1)
    MSD = np.empty((nPart, len(Dt_r)))
    dMSD = np.empty((nPart, len(Dt_r)))
    for i in nb.prange(nPart):
        for Dt in Dt_r:
            # Squared 2-D displacement over lag Dt for every start time.
            SD = (pos[i, Dt:, 0] - pos[i, 0:-Dt, 0])**2 + (pos[i, Dt:, 1] - pos[i, 0:-Dt, 1])**2
            # Dt_r starts at 1, so column Dt-1 holds lag Dt.
            MSD[i, Dt - 1] = np.mean(SD)
            dMSD[i, Dt - 1] = np.std(SD) / np.sqrt(len(SD))
    return Dt_r, MSD, dMSD
# Numba compiles a function on its first call for each argument type, so timing
# that call would include compilation. Warm up first on a small C-contiguous
# slice; assuming pos is a C-contiguous float64 3-D array (as np.zeros creates),
# the slice has the same type signature, so the timed call reuses the compiled code.
MSD_numba(np.ascontiguousarray(pos[:, :3, :]))
start_time = time.perf_counter()
Dt_r, MSD_n, dMSD_n = MSD_numba(pos)
print("MSD_numba -- Time: %s s" % (time.perf_counter() - start_time))
print("MSD_numba -- All close to MSD_direct: %r\n" % (np.allclose(MSD_n, MSD_d)))
Output:
MSD_numba -- Time: 4.520232915878296 s
MSD_numba -- All close to MSD_direct: True
Note: this type of question has been asked in several posts already, but they use different definitions (Mean square displacement python, Mean squared displacement, Mean square displacement for n-dim matrix python), they do not have an answer (Mean square displacement in Python), they just use one particle (Computing mean square displacement using python and FFT, Mean square displacement of a 1d random walk in python), they use pandas (Vectorized calculation of Mean Square Displacement in Python, Speedup MSD calculation in Python), etc.
Answer:
Adapting the answer from Computing mean square displacement using python and FFT that uses FFT transforms, I managed to do this calculation faster by two orders of magnitude:
def MSD_fft(pos):
    """Mean square displacement of each particle, computed with FFTs.

    Uses the expansion |x(t+m) - x(t)|^2 = |x(t+m)|^2 + |x(t)|^2 - 2 x(t)·x(t+m),
    averaged over t. The cross term is an autocorrelation, obtained by FFT.

    Parameters
    ----------
    pos : ndarray, shape (nParticles, nTime, nCoord)
        Positions; axes are particles, times, coordinates.

    Returns
    -------
    Dt_r : ndarray
        Time lags ``1 .. nTime-2``.
    MSD : ndarray, shape (nParticles, nTime-2)
        Mean squared displacement for each particle and lag.
    """
    nTime = pos.shape[1]
    # Number of (t, t+m) pairs contributing to lag m = 0..nTime-1.
    n_pairs = nTime - np.arange(nTime)[None, :]

    # S2[i, m] = mean_t x_i(t)·x_i(t+m). Zero-padding the FFT to 2*nTime keeps
    # the circular correlation from wrapping around.
    F = np.fft.fft(pos, n=2 * nTime, axis=-2)
    S2 = np.sum(np.fft.ifft(np.abs(F) ** 2, axis=-2)[:, :nTime, :].real, axis=-1) / n_pairs

    # D[i, t] = |x_i(t)|^2, with a trailing zero so that the flipped array
    # starts with 0 (no term to remove for lag 0).
    D = np.square(pos).sum(axis=-1)
    D = np.append(D, np.zeros((pos.shape[0], 1)), axis=-1)
    # S1[i, m] = mean over t of |x(t+m)|^2 + |x(t)|^2. This equals 2*sum(D) minus
    # the running sum of D[k-1] + D[nTime-k] for k = 1..m, divided by n_pairs.
    head = np.insert(D[:, 0:-1], 0, 0, axis=-1)
    S1 = (2 * np.sum(D, axis=-1)[:, None]
          - np.cumsum(head + np.flip(D, axis=-1), axis=-1))[:, :-1] / n_pairs

    MSD = S1 - 2 * S2
    Dt_r = np.arange(1, nTime - 1)
    return Dt_r, MSD[:, Dt_r]
# Time the FFT implementation with a monotonic, high-resolution interval clock.
start_time = time.perf_counter()
Dt_r, MSD_f = MSD_fft(pos)
print("MSD_fft -- Time: %s s" % (time.perf_counter() - start_time))
print("MSD_fft -- All close to MSD_direct: %r\n" % (np.allclose(MSD_f, MSD_d)))
Output:
MSD_direct -- Time: 2.1434285640716553 s
MSD_numba -- Time: 1.532573938369751 s
MSD_numba -- All close to MSD_direct: True
MSD_fft -- Time: 0.007384061813354492 s
MSD_fft -- All close to MSD_direct: True
I have not been able to calculate the error with this method. However, provided there are enough statistics, the error should stay small; in fact, it cannot be distinguished in the plots.
Generalized function for any n-dimensional array

I generalized the previous function to a `pos` given by any n-dimensional array; you only need to specify the time axis and the coordinate axis:
def MSD_fft_ax(pos, axis_time, axis_coord):
    """Mean square displacement via FFT for an array of any dimensionality.

    Parameters
    ----------
    pos : array_like
        Positions. One axis indexes time and one indexes the spatial
        coordinates; every other axis is kept as an independent batch axis
        (e.g. particles).
    axis_time : int
        Axis of ``pos`` holding time. Negative values count from the end.
    axis_coord : int
        Axis of ``pos`` holding the coordinates; it is summed over.
        Negative values count from the end.

    Returns
    -------
    Dt_r : ndarray
        Time lags ``1 .. nTime-2``.
    MSD : ndarray
        ``pos`` with ``axis_coord`` removed and the time axis replaced by a lag
        axis of length ``nTime-2``.

    Raises
    ------
    ValueError
        If ``axis_time`` and ``axis_coord`` refer to the same axis.
    """
    pos = np.asarray(pos)
    nTime = pos.shape[axis_time]

    # Autocorrelation sum_t x(t)·x(t+m) for m = 0..nTime-1. The FFT is
    # zero-padded to 2*nTime so the circular correlation does not wrap around.
    F = np.fft.fft(pos, n=2 * nTime, axis=axis_time)
    S2 = np.sum(
        np.fft.ifft(np.abs(F) ** 2, axis=axis_time).take(range(nTime), axis=axis_time).real,
        axis=axis_coord,
    )
    # D = |x(t)|^2 summed over coordinates; it has the same shape as S2.
    D = np.square(pos).sum(axis=axis_coord)

    # Locate the time axis in the reduced arrays (S2, D). Normalize to
    # non-negative indices first: decrementing a negative index does not
    # account for the removed coordinate axis.
    t_ax = axis_time % pos.ndim
    c_ax = axis_coord % pos.ndim
    if t_ax == c_ax:
        raise ValueError("axis_time and axis_coord must refer to different axes")
    if c_ax < t_ax:
        t_ax -= 1

    shape_t = [nTime if ax == t_ax else 1 for ax in range(D.ndim)]
    shape_non_t = [1 if ax == t_ax else s for ax, s in enumerate(D.shape)]
    # Trailing zero along time, so that the flipped array starts with 0 (no
    # term to remove for lag 0).
    D = np.append(D, np.zeros(shape_non_t), axis=t_ax)
    # S1[m] = sum_{t=0}^{nTime-1-m} (|x(t+m)|^2 + |x(t)|^2). This equals
    # 2*sum(D) minus the running sum of D[k-1] + D[nTime-k] for k = 1..m.
    head = np.insert(D.take(np.arange(nTime), axis=t_ax), 0, 0, axis=t_ax)
    S1 = (
        2 * np.sum(D, axis=t_ax, keepdims=True)
        - np.cumsum(head + np.flip(D, axis=t_ax), axis=t_ax)
    ).take(np.arange(nTime), axis=t_ax)

    # Divide by the number of (t, t+m) pairs for each lag m.
    MSD = (S1 - 2 * S2) / (nTime - np.arange(nTime).reshape(shape_t))
    Dt_r = np.arange(1, nTime - 1)
    MSD = MSD.take(Dt_r, axis=t_ax)
    return Dt_r, MSD
# Time the n-dimensional FFT implementation with a monotonic, high-resolution interval clock.
start_time = time.perf_counter()
Dt_r, MSD_fax = MSD_fft_ax(pos, axis_time=1, axis_coord=-1)
print("MSD_fft_ax -- Time: %s s" % (time.perf_counter() - start_time))
print("MSD_fft_ax -- All close to MSD_direct: %r\n" % (np.allclose(MSD_fax, MSD_d)))
Output:
MSD_fft_ax -- Time: 0.009054422378540039 s
MSD_fft_ax -- All close to MSD_direct: True