I've been trying to remove specific images from the CIFAR-10 train set but have had no luck yet. Here's what I've tried so far:
import numpy as np
import tensorflow_datasets as tfds

data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
inputs = (inputs/127.5)-1
inputs = np.delete(inputs, [0, 4, 3, 2])
labels = np.delete(labels, [0, 4, 3, 2])
I have a list of indices of specific images, [0, 4, 3, 2], that I would like to remove. The shape of inputs is (50000, 32, 32, 3) and the shape of labels is (50000,). The above method is simply not working: for some reason the shape of inputs gets nearly 5x bigger. I'd appreciate any help here.
Answer:
The documentation for np.delete says:
axis: int, optional
The axis along which to delete the subarray defined by obj. If axis is None, obj is applied to the flattened array.
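To see what that means in practice, here is a tiny sketch with a made-up array of 6 "images" of shape 2x2x3 (the sizes are arbitrary and chosen only to keep the numbers small, they are not CIFAR-10's). With the default axis=None the indices pick out individual scalars of the flattened array; with axis=0 they pick out whole images.

```python
import numpy as np

# Hypothetical small stand-in for the image array: 6 images of shape 2x2x3.
inputs = np.zeros((6, 2, 2, 3))
idx = [0, 4, 3, 2]

# axis=None (the default): the array is flattened first, so the four
# indices remove four scalars and the result is 1-D (6*2*2*3 - 4 = 68).
flat = np.delete(inputs, idx)
print(flat.shape)  # (68,)

# axis=0: the indices refer to whole images along the first (sample) axis.
per_image = np.delete(inputs, idx, axis=0)
print(per_image.shape)  # (2, 2, 2, 3)
```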
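So specify the sample axis (axis=0) when deleting the elements. Without it, inputs is flattened to a 1-D array of 50000 × 32 × 32 × 3 = 153,600,000 values and only four individual pixel values are removed, so its shape becomes (153599996,) instead of (49996, 32, 32, 3). The labels array is already 1-D, so that call happens to work, but passing axis=0 to both keeps them consistent:
inputs = np.delete(inputs, [0, 4, 3, 2], axis=0)
labels = np.delete(labels, [0, 4, 3, 2], axis=0)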
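A self-contained sketch that runs without downloading CIFAR-10: it assumes random uint8 images and hypothetical labels with only 10 samples as a stand-in for data['train'] (the real split has 50000). It deletes the same indices from images and labels along axis 0, and also shows an equivalent approach using a boolean "keep" mask, which some find easier to read when the list of indices to drop is built up elsewhere.

```python
import numpy as np

# Stand-ins for data['train']['image'] and data['train']['label'];
# random pixels and a reduced sample count are assumptions for this sketch.
rng = np.random.default_rng(0)
n = 10
inputs = rng.integers(0, 256, size=(n, 32, 32, 3)).astype(np.uint8)
labels = np.arange(n) % 10  # hypothetical labels

# Same normalization as in the question: map [0, 255] to [-1, 1].
inputs = (inputs / 127.5) - 1

to_remove = [0, 4, 3, 2]

# Delete the same sample indices from both arrays along the sample axis.
inputs_kept = np.delete(inputs, to_remove, axis=0)
labels_kept = np.delete(labels, to_remove, axis=0)

# Equivalent alternative: a boolean mask marking which samples to keep.
keep = np.ones(n, dtype=bool)
keep[to_remove] = False
assert np.array_equal(inputs_kept, inputs[keep])
assert np.array_equal(labels_kept, labels[keep])

print(inputs_kept.shape, labels_kept.shape)  # (6, 32, 32, 3) (6,)
```

Because both arrays are indexed along the same axis with the same indices, each remaining image still lines up with its label.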