Is it possible to improve python performance for this code?

Time:12-01

I have a simple piece of code that reads a trajectory file, which can be seen as a list of 2D arrays (a list of positions in space), and stores it in Y.

I then want to compute the RMSD for each pair (scipy.pdist style).
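Written out, with each snapshot Y_i a p × q array (p positions with q coordinates each; this notation is introduced here and is not in the original code), the quantity computed for each pair i < j is:

$$\mathrm{RMSD}_{ij} = \sqrt{\frac{3}{pq}\sum_{k=1}^{p}\sum_{l=1}^{q}\bigl(Y_{i,kl} - Y_{j,kl}\bigr)^2}$$

When q = 3 (x, y, z coordinates), the factor 3 turns the per-coordinate mean into a mean over positions of the squared displacement, which is presumably why it appears in the code.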

My code works fine:

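import numpy as np
from ase.io import read  # assumption: read() is ASE's ase.io.read (import not shown in the original)

trajectory = read("test.lammpstrj", index="::")
m = len(trajectory)
# .get_positions() returns a 2D numpy array
Y = np.array([snapshot.get_positions() for snapshot in trajectory])

# RMSD for every pair i < j, in condensed (pdist-style) order
b = [np.sqrt((3 * (Y[i] - Y[j]) ** 2).mean()) for i in range(m) for j in range(i + 1, m)]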

This code executes in 0.86 seconds with Python 3.10; the same kind of code in Julia 1.8 executes in 0.46 seconds.

I plan to use much larger trajectories (~200,000 elements). Would it be possible to get a speed-up in Python, or should I stick to Julia?

Answer:

You've mentioned that snapshot.get_positions() returns a 2D array, say of shape (p, q). So I expect Y to be a 3D array of shape (m, p, q), where m is the number of snapshots in the trajectory. You also expect m to grow quite large.
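For a sense of scale at the m ≈ 200,000 mentioned in the question, the number of unique pairs, and hence the length of the condensed output, is:

$$N_{\text{pairs}} = \frac{m(m-1)}{2} = \frac{200\,000 \times 199\,999}{2} = 19\,999\,900\,000 \approx 2 \times 10^{10}$$

Stored as float64 (8 bytes each), that is about 1.6 × 10^11 bytes (≈ 160 GB) for the result alone, and any approach that materializes all pairwise differences at once needs roughly p·q times more (≈ 3.2 × 10^12 bytes for the 4×5 dummy snapshots used below). This is back-of-the-envelope arithmetic, not a measurement.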

Here is a basic way to speed up the distance calculation, with m=1000:

import numpy as np

# dummy inputs
m = 1000
p, q = 4, 5
Y = np.random.randn(m, p, q)

# your current method
def foo():
    return [np.sqrt(((((Y[i] - Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]

# vectorized approach -> compute the upper triangle of the pairwise distance matrix
def bar():
    u, v = np.triu_indices(Y.shape[0], 1)
    return np.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))

# Check for correctness

out_1 = foo()
out_2 = bar()
print(np.allclose(out_1, out_2))
# True

If we test the time required:

%timeit -n 10 -r 3 foo()
# 3.16 s ± 50.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

The first method is really slow: it takes over 3 seconds for this calculation. Let's check the second method:

%timeit -n 10 -r 3 bar()
# 97.5 ms ± 405 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

So we have a ~30x speedup here, which makes your large calculation much more feasible in Python than with the original code. Feel free to test other sizes of Y to see how it scales compared to the original.
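One caveat when scaling bar() up: Y[u] - Y[v] materializes an array of shape (m(m-1)/2, p, q), which grows quadratically in m. A minimal sketch of a middle ground (the helper names rmsd_rows and rmsd_condensed are hypothetical, not from any library) keeps a single Python loop over i and vectorizes over all j > i, so only one (m - i - 1, p, q) difference array exists at a time:

```python
import numpy as np


def rmsd_rows(Y):
    """Yield (i, row) where row[k] is the RMSD between Y[i] and Y[i + 1 + k].

    Only one difference array of shape (m - i - 1, p, q) exists at a time,
    so the temporary memory is O(m * p * q) rather than O(m**2 * p * q)
    as in bar().
    """
    m = Y.shape[0]
    for i in range(m - 1):
        diff = Y[i + 1:] - Y[i]  # broadcast snapshot i against all later ones
        yield i, np.sqrt(3 * (diff ** 2).mean(axis=(-1, -2)))


def rmsd_condensed(Y):
    """Concatenate all rows; same values and order as foo() and bar()."""
    rows = [row for _, row in rmsd_rows(Y)]
    if not rows:  # fewer than two snapshots: no pairs
        return np.empty(0)
    return np.concatenate(rows)
```

Because rmsd_rows is a generator, a loop such as `for i, row in rmsd_rows(Y): ...` can reduce each row or write it to disk as it is produced, which is what you would want at m ≈ 200,000, where even the full condensed result is very large. The design trades some Python-loop overhead (m iterations) for much lower peak memory than bar().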


JIT

In addition, you can try JIT compilation, for example with jax or numba. It is fairly simple to port the function bar to jax.numpy:

import jax
import jax.numpy as jnp

@jax.jit
def jit_bar(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)
    return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))

# check for correctness

print(np.allclose(bar(), jit_bar(Y)))
# True

If we time the jitted jnp op (the correctness check above already triggered compilation, so this measures the compiled function):

%timeit -n 10 -r 3 jit_bar(Y)
# 10.6 ms ± 678 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

So compared to the original, this is roughly a 300x speedup. (JAX dispatches work asynchronously, so a stricter benchmark would time jit_bar(Y).block_until_ready().)

Note that not every operation can be ported to jax/jit this easily (this particular problem happens to be well suited), so the general advice is simply to avoid Python loops and use numpy's broadcasting/vectorization capabilities, as in bar().
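In the same spirit, the per-pair quantity is a Euclidean distance between flattened snapshots scaled by sqrt(3 / (p*q)), since 3 * mean(d**2) over p*q entries equals (3 / (p*q)) * ||d||**2. So scipy.spatial.distance.pdist, the function the question alludes to, can compute it in compiled code without building a (pairs, p, q) intermediate. A minimal sketch (rmsd_pdist is a hypothetical helper name):

```python
import numpy as np
from scipy.spatial.distance import pdist


def rmsd_pdist(Y):
    """Pairwise RMSD via scipy's pdist on flattened snapshots.

    sqrt(3 * mean((Yi - Yj)**2)) == sqrt(3 / (p*q)) * ||Yi - Yj||_2,
    so a Euclidean pdist on the (m, p*q) reshaped array gives the same values.
    """
    m = Y.shape[0]
    flat = Y.reshape(m, -1)  # (m, p, q) -> (m, p*q)
    return np.sqrt(3.0 / flat.shape[1]) * pdist(flat, metric="euclidean")
```

pdist returns the condensed vector in the same i < j, row-major order as the list comprehension in foo() and np.triu_indices in bar(), so the results can be compared elementwise.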

Answer:

Stick to Julia.

If you have already written it in a language that runs faster, why use Python in the first place?

Answer:

Your question is about speeding up Python, relative to Julia, so I'd like to offer some Julia code for comparison.

Since your data is most naturally expressed as a list of 4x5 arrays, I suggest expressing it as a vector of SMatrix values (static matrices from StaticArrays.jl):

sumdiff2(A, B) = sum((A[i] - B[i])^2 for i in eachindex(A, B))
function dists(Y)
    M = length(Y)
    V = Vector{float(eltype(eltype(Y)))}(undef, sum(1:M-1))
    Threads.@threads for i in eachindex(Y)
        ii = sum(M-i+1:M-1)  # offset of row i in V: number of pairs (k, j) with k < i
        for j in i+1:lastindex(Y)
            ind = ii + (j-i)
            V[ind] = sqrt(3 * sumdiff2(Y[i], Y[j])/length(Y[i]))
        end
    end
    return V
end

using Random: randn
using StaticArrays: SMatrix
Ys = [randn(SMatrix{4,5,Float64}) for _ in 1:1000];
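The `ii` offset in dists counts the pairs whose first index is smaller than i, which places pair (i, j) in the same condensed order that np.triu_indices and scipy's pdist use. A small Python sketch of the same bookkeeping in 0-based indexing (condensed_index is a hypothetical helper name, not part of the Julia code or any library):

```python
import numpy as np


def condensed_index(i, j, m):
    """0-based position of pair (i, j), i < j, in a condensed distance vector.

    i * (2*m - i - 1) // 2 is the number of pairs whose first index is
    smaller than i; it plays the role of `ii` in the (1-based) Julia code.
    """
    if not 0 <= i < j < m:
        raise ValueError("need 0 <= i < j < m")
    return i * (2 * m - i - 1) // 2 + (j - i - 1)


# Sanity check: the k-th pair returned by np.triu_indices maps back to k.
m = 6
u, v = np.triu_indices(m, 1)
assert all(condensed_index(i, j, m) == k for k, (i, j) in enumerate(zip(u, v)))
```

For m = 4, condensed_index(1, 2, 4) returns 3, since (0, 1), (0, 2) and (0, 3) come first.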

Benchmarks:

# single-threaded
julia> using BenchmarkTools
julia> @btime dists($Ys);
  6.561 ms (2 allocations: 3.81 MiB)

# multi-threaded with 6 cores
julia> @btime dists($Ys);
  1.606 ms (75 allocations: 3.82 MiB)

I was not able to install jax on my computer, but comparing with @Mercury's numpy code I got:

foo: 5.5 seconds
bar: 179 ms

i.e. roughly a 3400x speedup over foo for the multi-threaded Julia version (5.5 s / 1.606 ms).

It is possible to write this as a one-liner at a ~2-3x performance cost.
