Applying a Filter to a Multi-Dimensional NumPy Array, e.g. CIFAR-10 Data

Time:07-12

from keras.datasets import cifar10
# load dataset
(trainX, trainy), (testX, testy) = cifar10.load_data()
# summarize loaded dataset
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))

Train: X=(50000, 32, 32, 3), y=(50000, 1)

Test: X=(10000, 32, 32, 3), y=(10000, 1)

trainMask = (trainy == 1) | (trainy == 8) | (trainy == 9)
testMask  = (testy == 1)  | (testy == 8)  | (testy == 9)

How can I filter the train and test sets using these masks? For example:

trainX = trainX[trainMask]

testX = testX[testMask]

Filtering the labels with trainy = trainy[trainMask] works, but applying the same mask to the images with trainX = trainX[trainMask] does not.

CodePudding user response:

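What you're looking for is np.where(). Because the mask has shape (N, 1), np.where() returns a pair of index arrays (rows, columns), so index the images with the row indices only:

import numpy as np
trainX = trainX[np.where(trainMask)[0]]
testX = testX[np.where(testMask)[0]]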
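
To see what np.where() returns for a column-shaped mask, here is a small sketch. It uses synthetic arrays in place of the real CIFAR-10 data (the names images, labels and mask, and the sample count of 20, are assumptions made for illustration only):

```python
import numpy as np

# Synthetic stand-ins for the CIFAR-10 arrays (assumed data, not the real dataset)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(20, 32, 32, 3), dtype=np.uint8)
labels = (np.arange(20) % 10).reshape(-1, 1)  # shape (20, 1), like trainy

# Same kind of mask as in the question; it keeps the (20, 1) shape of the labels
mask = (labels == 1) | (labels == 8) | (labels == 9)

# For a 2-D mask, np.where returns one index array per dimension
rows, cols = np.where(mask)

# Indexing with the row indices alone selects whole images
subset = images[rows]
print(rows)
print(subset.shape)

# Passing the full tuple indexes the first TWO axes, which drops an image axis
print(images[np.where(mask)].shape)
```

The column indices are all zero here because the label array has a single column, so only the row indices carry information.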

CodePudding user response:

This should do the trick:


trainX = trainX[trainMask.flatten()]
testX = testX[testMask.flatten()]
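
The reason flattening helps is that a boolean mask of shape (N, 1) is matched against the first two axes of the (N, 32, 32, 3) image array, and 1 does not match 32; a 1-D mask of shape (N,) is matched against the first axis only. Below is a minimal sketch of that idea wrapped in a helper. The function name filter_by_classes is hypothetical, np.isin is used here as a shorthand for the chained == comparisons, and the arrays are synthetic placeholders rather than the real dataset:

```python
import numpy as np

def filter_by_classes(X, y, classes):
    """Keep the samples whose label is in `classes` (hypothetical helper)."""
    # np.isin gives a mask shaped like y, i.e. (N, 1); ravel makes it (N,)
    mask = np.isin(y, classes).ravel()
    # A 1-D mask selects along the first axis of both arrays
    return X[mask], y[mask]

# Synthetic placeholders with CIFAR-10-like shapes (assumed data)
X = np.zeros((10, 32, 32, 3), dtype=np.uint8)
y = np.arange(10).reshape(-1, 1)

X_sub, y_sub = filter_by_classes(X, y, [1, 8, 9])
print(X_sub.shape, y_sub.shape)
```

Using the same 1-D mask for both arrays keeps images and labels aligned, and the labels keep their (N, 1) layout.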
