Is there a way to make this code faster with numpy (or other)?
```python
import numpy as np

# prediction.shape = (144, 192, 256, 4)
tt = np.zeros((144, 192, 256, 1), dtype=int)
for i in range(144):
    for j in range(192):
        for k in range(256):
            # Index (0-3) of the largest of the 4 values at this position
            d = np.argmax(prediction[i, j, k])
            if d == 0:
                tt[i, j, k] = 0
            elif d == 1:
                tt[i, j, k] = 30
            elif d == 2:
                tt[i, j, k] = 150
            elif d == 3:
                tt[i, j, k] = 250
```
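```python
import numpy as np

prediction = np.random.uniform(0, 1, (144, 192, 256, 4))
# Index of the maximum along the last axis, for every (i, j, k) at once
pred_max = np.argmax(prediction, axis=-1)
# Map index 0/1/2/3 to value 0/30/150/250 (np.select's default of 0 is never needed here)
tt = np.select([pred_max == 0, pred_max == 1, pred_max == 2, pred_max == 3], [0, 30, 150, 250])
print(tt.shape)
print(tt)
```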
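This seems to do what you want (note that the result has shape (144, 192, 256), without the trailing axis of length 1 that your `tt` has; use `tt[..., np.newaxis]` if you need it):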
```
(144, 192, 256)
[[[  0 150   0 ... 250   0   0]
  [ 30 250 250 ...  30  30 150]
  [  0 250  30 ... 150 150  30]
  ...
  [150 150  30 ...  30 150   0]
  [150 250   0 ... 250   0   0]
  [ 30   0   0 ...   0   0 250]]

 [[ 30   0 250 ...  30 250 250]
  [ 30 150  30 ... 250 150 150]
  [ 30 250 250 ...  30   0 250]
  ...
```
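To double-check that the vectorized version really matches the original triple loop, you can run both on a small random array and compare. The following is a sketch under a few assumptions: the function names `loop_version` and `select_version` are hypothetical, the loop is the question's code generalized to any 4-D input (with the `if` chain written as a list lookup, which gives the same values), and the `np.select` result gets the trailing length-1 axis added back so the shapes match.

```python
import numpy as np

def loop_version(prediction):
    # The question's loop, generalized to any (I, J, K, 4) input.
    values = [0, 30, 150, 250]  # same mapping as the if/elif chain
    tt = np.zeros(prediction.shape[:3] + (1,), dtype=int)
    for i in range(prediction.shape[0]):
        for j in range(prediction.shape[1]):
            for k in range(prediction.shape[2]):
                d = np.argmax(prediction[i, j, k])
                tt[i, j, k] = values[d]
    return tt

def select_version(prediction):
    # The answer's vectorized approach.
    pred_max = np.argmax(prediction, axis=-1)
    tt = np.select([pred_max == 0, pred_max == 1, pred_max == 2, pred_max == 3],
                   [0, 30, 150, 250])
    # Restore the trailing length-1 axis the question's tt has.
    return tt[..., np.newaxis]

rng = np.random.default_rng(0)
small = rng.uniform(0, 1, (4, 5, 6, 4))
a = loop_version(small)
b = select_version(small)
print(a.shape, b.shape)      # (4, 5, 6, 1) (4, 5, 6, 1)
print(np.array_equal(a, b))  # True
```

Both versions pick the first maximum when values tie, because `np.argmax` does so whether it is called per position or along `axis=-1`.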
A quick benchmark shows it takes about 0.277 seconds, versus about 13 seconds for the original loop, so that's roughly a 47x speed-up.
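A possibly simpler alternative, not part of the answer above: because the argmax indices are exactly 0 to 3, you can use them to index a small lookup array directly instead of building four boolean masks with `np.select`. This is a sketch: `LUT` and the function names are hypothetical, the array is smaller than the question's to keep memory modest (switch `SHAPE` to `(144, 192, 256, 4)` to match it), and no timing results are claimed here; measure on your own machine.

```python
import timeit
import numpy as np

# Hypothetical lookup table: position d holds the value for argmax index d.
LUT = np.array([0, 30, 150, 250])

# One eighth of the question's array; use (144, 192, 256, 4) to match it.
SHAPE = (72, 96, 128, 4)

def lookup_version(prediction):
    # Fancy indexing: each argmax index selects the matching LUT entry.
    return LUT[np.argmax(prediction, axis=-1)]

def select_version(prediction):
    # The answer's np.select approach, for comparison.
    pred_max = np.argmax(prediction, axis=-1)
    return np.select([pred_max == 0, pred_max == 1, pred_max == 2, pred_max == 3],
                     [0, 30, 150, 250])

prediction = np.random.uniform(0, 1, SHAPE)
# Both approaches should produce identical results.
assert np.array_equal(lookup_version(prediction), select_version(prediction))

for fn in (select_version, lookup_version):
    t = timeit.timeit(lambda: fn(prediction), number=3) / 3
    print(f"{fn.__name__}: {t:.3f} s per call")
```

Most of the remaining time in either version is likely spent in `np.argmax` itself; the lookup table only replaces the mapping step.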