Trying to compare different sized one-hot-encoded lists

Time:11-18

I have run an autoencoder model on FashionMNIST and returned a dictionary containing each output and its label. My goal is to plot 10 images from only the dress and coat classes (class labels 3 and 4). I have one-hot-encoded the labels, so the dress class appears as [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]. My dictionary output is:

print(pa)  # the dictionary is called pa
{'output': array([[1.5346111e-04, 2.3307074e-04, 2.8705355e-04, ..., 1.9890528e-04,
         1.8257453e-04, 2.0764180e-04],
        [1.9767908e-03, 1.5839143e-03, 1.7811939e-03, ..., 1.7838757e-03,
         1.4038634e-03, 2.3405524e-03],
        [5.8998094e-06, 6.9388111e-06, 5.8752844e-06, ..., 5.1715115e-06,
         4.4670110e-06, 1.2018012e-05],
        ...,
        [2.1034568e-05, 3.0344427e-05, 7.0048365e-05, ..., 9.4724113e-05,
         8.9003828e-05, 4.1828611e-05],
        [2.7930623e-06, 3.0393956e-06, 4.5835086e-06, ..., 3.8765144e-04,
         3.6324131e-05, 5.6411723e-06],
        [1.2453397e-04, 1.1948447e-04, 2.0121646e-04, ..., 1.0773790e-03,
         2.9582143e-04, 1.7229551e-04]], dtype=float32),
 'label': array([[1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}

I am trying to write a for loop that, whenever pa['label'][i] equals a particular one-hot-encoded array, plots the corresponding pa['output'][i].

for i in range(len(pa['label'])):
    if pa['label'][i] == np.array([0.,0.,0.,1.,0.,0.,0.,0.,0.]):
        print(pa['label'][i])
#         plt.imshow(pa['output'][i].reshape(28,28))
#         plt.show()

However, I get this warning:

/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:2: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
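A likely cause of this warning: each row of pa['label'] has 10 entries (one per FashionMNIST class, as in the labels quoted in this question), while the array in the if statement has only 9, so NumPy cannot compare them element by element; as the warning says, future versions raise an error instead. Even with matching sizes, == returns an array of booleans, and using that array in an if raises a ValueError about an ambiguous truth value. A minimal sketch, assuming six of the quoted label rows as a stand-in for pa['label'], uses np.array_equal, which returns a single bool and returns False when the shapes differ:

```python
import numpy as np

# Stand-in for pa['label']: the first six label rows quoted in the question.
labels = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                   [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
                   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=np.float32)

# The dress vector needs 10 entries, one per class; class 3 is the dress.
dress = np.zeros(10, dtype=np.float32)
dress[3] = 1.0

# np.array_equal returns one bool, so it is safe inside an `if`.
dress_rows = [i for i in range(len(labels)) if np.array_equal(labels[i], dress)]
print(dress_rows)  # [0, 2]

# A 9-entry vector (like the one in the loop) has a different shape -> False.
print(np.array_equal(labels[0], dress[:9]))  # False
```

Inside the question's loop, the condition would then read np.array_equal(pa['label'][i], dress).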

I have also tried building an array of the one-hot-encoded labels I want to plot and comparing each dictionary label against it (the two arrays have different sizes):

clothing_array = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                           [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])


for i in range(len(pa['label'])):
    if (pa['label'][i] == clothing_array[i]).any():
        plt.imshow(pa['output'][i].reshape(28,28))
        plt.show()

However, it plots a T-shirt and a bag, and then I get the error:

IndexError: index 2 is out of bounds for axis 0 with size 2

I understand this error, since clothing_array only has two rows. But the code is clearly wrong anyway, because I want to plot ONLY dresses and coats. I don't know why it plots these images, and I don't know how to fix it. Any help or clarifying questions are more than welcome.
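The second attempt has two separate problems. Indexing clothing_array[i] pairs label i with target i, so it runs out of rows at i = 2. And (a == b).any() is True as soon as any single position matches; two different one-hot rows still agree on most of their zeros, so almost every image passes the test, which is why images of other classes get plotted. A minimal sketch, again assuming six of the quoted label rows as a stand-in for pa['label'], compares every label row with every row of clothing_array through broadcasting and requires a whole-row match:

```python
import numpy as np

# Stand-in for pa['label']: the first six label rows quoted in the question.
labels = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                   [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
                   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=np.float32)

clothing_array = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],   # dress
                           [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])  # coat

# Row 1 is a bag, yet the elementwise test passes because the zeros line up.
print((labels[1] == clothing_array[0]).any())  # True

# (6, 1, 10) == (1, 2, 10) broadcasts to (6, 2, 10).
# .all(axis=2): is the whole row equal to that target?
# .any(axis=1): is it equal to at least one of the targets?
mask = (labels[:, None, :] == clothing_array[None, :, :]).all(axis=2).any(axis=1)
print(np.flatnonzero(mask))  # [0 2 5]
```

Taking np.flatnonzero(mask)[:10] would give up to ten indices into pa['output'] to plot.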

Here are the first ten rows of my dictionary's labels (pa['label']):

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
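A different technique from comparing whole vectors (and not the approach used in the answer): because each row is one-hot, argmax recovers the integer class, and np.isin then selects classes 3 and 4. A minimal sketch using the ten label rows quoted above:

```python
import numpy as np

# The first ten rows of pa['label'], as quoted in the question.
labels = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                   [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
                   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=np.float32)

# argmax turns each one-hot row back into its integer class (0-9).
classes = labels.argmax(axis=1)
print(classes)  # [3 8 3 9 0 4 5 5 4 0]

# Boolean mask "class is 3 (dress) or 4 (coat)", then at most 10 indices.
wanted = np.flatnonzero(np.isin(classes, [3, 4]))[:10]
print(wanted)  # [0 2 5 8]
```

The indices in wanted can be used directly as pa['output'][wanted], which also avoids any shape mismatch between one-hot vectors.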

Answer:

Here is an example with two arrays: x plays the role of the label array and y the clothing array. Multiplying them element-wise gives z, which is non-zero only in rows where the label's 1 sits in the same position as y's 1, so np.where(z.any(axis=1)) returns the indexes of the matching rows. Finally, you can use matching_indexes to collect the outputs you want and plot them.

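import numpy as np

# x plays the role of pa['label'], y the one-hot class we are looking for
x = np.array([[1., 0., 0., 0., 0., 0., 0.],
              [0., 1., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 1., 0., 0.],
              [1., 0., 0., 0., 0., 0., 0.],
              [0., 0., 1., 0., 0., 0., 0.],
              [0., 0., 0., 1., 0., 0., 0.]])

y = np.array([[1., 0., 0., 0., 0., 0., 0.]])

# the element-wise product is non-zero only where both x and y have a 1
z = np.multiply(x, y)
# rows of z that contain a non-zero entry are the matching labels (here 0 and 3)
matching_indexes = np.where(z.any(axis=1))[0]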
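The same multiplication trick can select both classes at once: if y has a 1 at position 3 (dress) and at position 4 (coat), a label row keeps a non-zero entry after multiplying only when its single 1 sits at one of those positions. A minimal sketch, assuming six of the quoted label rows as pa['label'] and random numbers as a hypothetical stand-in for pa['output'] (one flattened 28x28 image per row):

```python
import numpy as np

# Stand-in for pa['label']: the first six label rows quoted in the question.
labels = np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                   [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
                   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                   [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=np.float32)

# Hypothetical stand-in for pa['output']: one flattened 28x28 image per row.
outputs = np.random.default_rng(0).random((len(labels), 28 * 28), dtype=np.float32)

# Multi-hot selector: 1 at class 3 (dress) and class 4 (coat).
y = np.zeros((1, 10), dtype=np.float32)
y[0, [3, 4]] = 1.0

# Broadcasting multiplies every label row by y.
z = np.multiply(labels, y)
matching_indexes = np.where(z.any(axis=1))[0][:10]  # keep at most 10 images

selected = outputs[matching_indexes].reshape(-1, 28, 28)
print(matching_indexes, selected.shape)  # [0 2 5] (3, 28, 28)
```

Each selected[k] can then be shown with plt.imshow(selected[k]) followed by plt.show(), as in the question's loop.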