How to make this faster with numpy?

Time:09-28

Is there a way to make this code faster with numpy (or some other library)?

# prediction.shape = (144, 192, 256, 4)

tt = np.zeros((144, 192, 256, 1), dtype = int)
for i in range(144):
    for j in range(192):
        for k in range(256):
            d = np.argmax(prediction[i,j,k])
            if d == 0:
                tt[i,j,k] = 0
            if d == 1:
                tt[i,j,k] = 30
            if d == 2:
                tt[i,j,k] = 150
            if d == 3:
                tt[i,j,k] = 250

CodePudding user response:

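import numpy as np

# Random stand-in for the real prediction array
prediction = np.random.uniform(0, 1, (144, 192, 256, 4))
# Index of the largest value along the last axis, for every (i, j, k) at once
pred_max = np.argmax(prediction, axis=-1)
# Map each class index to its output value
tt = np.select([pred_max == 0, pred_max == 1, pred_max == 2, pred_max == 3], [0, 30, 150, 250])
print(tt.shape)
print(tt)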

This seems to do what you want, except that the result has shape (144, 192, 256) rather than (144, 192, 256, 1):

(144, 192, 256)
[[[  0 150   0 ... 250   0   0]
  [ 30 250 250 ...  30  30 150]
  [  0 250  30 ... 150 150  30]
  ...
  [150 150  30 ...  30 150   0]
  [150 250   0 ... 250   0   0]
  [ 30   0   0 ...   0   0 250]]

 [[ 30   0 250 ...  30 250 250]
  [ 30 150  30 ... 250 150 150]
  [ 30 250 250 ...  30   0 250]
...

A quick benchmark shows it takes about 0.277 seconds, whereas the original takes about 13 seconds, so that's roughly a 50x speed-up.
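
An alternative that may be worth trying is a lookup table indexed by the argmax result, which avoids building four boolean masks. This is only a sketch: the helper name argmax_to_values and the small array shape are made up for illustration, and it was not part of the benchmark above.

```python
import numpy as np

def argmax_to_values(prediction, values=(0, 30, 150, 250)):
    """Map the argmax along the last axis to the matching entry of `values`."""
    lut = np.asarray(values)                   # lookup table: class index -> output value
    classes = np.argmax(prediction, axis=-1)   # integer class per (i, j, k)
    return lut[classes]                        # integer-array indexing does the mapping

# Small random stand-in for the real (144, 192, 256, 4) prediction array
rng = np.random.default_rng(0)
prediction = rng.uniform(0, 1, (6, 8, 10, 4))

tt = argmax_to_values(prediction)              # shape (6, 8, 10)

# The loop in the question fills an array with a trailing axis of length 1;
# add that axis back only if downstream code expects it.
tt_4d = tt[..., np.newaxis]                    # shape (6, 8, 10, 1)
```

It should give the same values as the np.select version. Whether it is actually faster on the full-size array is something to time on your own machine.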
