Python argmax of dot product of weighted matrix and vector (mnist)

Time:07-11

What does argmax mean in this context? I am following the tutorial in this colab notebook: https://colab.research.google.com/github/chokkan/deeplearningclass/blob/master/mnist.ipynb

for x, y in zip(Xtrain, Ytrain):
    y_pred = np.argmax(np.dot(W, x))

It looks like this is saying that for every record x and its truth value y, in the vectors Xtrain and Ytrain, take the max value of the dot product of the weighted matrix W and the record x. Does this mean it takes the max of the weighted matrix?

It also looks like a 1 was appended to each flattened vector:

import numpy as np

def image_to_vector(X):
    X = np.reshape(X, (len(X), -1))     # Flatten: (N x 28 x 28) -> (N x 784)
    return np.c_[X, np.ones(len(X))]    # Append 1: (N x 784) -> (N x 785)

Xtrain = image_to_vector(data['train_x'])

Why would that be?

Thank you!

Answer:

For simplicity, you can treat it as a sort of y = W * x + bias. The additional column of ones does not depend on the input, so the weights that multiply it act as a bias term.
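A minimal NumPy sketch of that equivalence, using random stand-in data rather than the notebook's trained weights (the names b, W_aug and x_aug are illustrative, not from the notebook): folding a bias vector b into W as an extra column, and appending a 1 to x, gives the same scores as W x + b.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sizes matching the MNIST setup: 10 classes, 28*28 = 784 pixel features.
n_classes, n_features = 10, 784

W = rng.normal(size=(n_classes, n_features))  # weights for the raw pixels
b = rng.normal(size=n_classes)                # separate bias vector (illustrative)
x = rng.random(n_features)                    # one flattened image (random stand-in)

# Explicit form: scores = W x + b
scores_explicit = W @ x + b

# Augmented form: b becomes the last column of W, and x gets a trailing 1
W_aug = np.c_[W, b]          # shape (10, 785)
x_aug = np.append(x, 1.0)    # shape (785,)
scores_augmented = W_aug @ x_aug

print(W_aug.shape, x_aug.shape)                        # (10, 785) (785,)
print(np.allclose(scores_explicit, scores_augmented))  # True
```

This is why a single 10 x 785 matrix W is enough: its last column plays the role of b.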

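Now, the weight matrix W represents a fully connected layer with 785 (28*28 + 1) inputs and 10 outputs (7850 weights total). The dot product of W and x is a vector of length 10 containing one score for each possible class (each digit, in the MNIST case). Applying argmax returns the index of the highest score, not the score itself; since the classes are the digits 0 to 9, that index is the prediction. So it does not take the max of the weight matrix; it picks the class with the largest score.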
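A small sketch of the prediction step with random stand-in values (not trained weights), showing the shapes involved and the difference between np.argmax (position of the largest score) and np.max (the largest score itself):

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins: a 10 x 785 weight matrix and one augmented input vector.
W = rng.normal(size=(10, 785))
x = np.append(rng.random(784), 1.0)  # 784 pixels followed by the constant 1

scores = np.dot(W, x)        # shape (10,): one score per digit 0-9
y_pred = np.argmax(scores)   # index of the largest score = predicted digit
best = np.max(scores)        # the largest score itself (not the prediction)

print(scores.shape)              # (10,)
print(scores[y_pred] == best)    # True

# A tiny hand-made example to make the difference concrete
toy_scores = np.array([0.1, 2.5, -1.0, 0.7])
print(np.argmax(toy_scores))  # 1   (position of the largest value)
print(np.max(toy_scores))     # 2.5 (the largest value)
```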
