How to square a row in NumPy to go from a 2-d array to a 3-d one where each row was squared?

Time:01-05

I am trying to find a fast way to "square" each row of a 2-d array, i.e. to compute the outer product of every row with itself. The behaviour I would like is something like this:

in[1] import numpy as np
in[2] a = np.array([[1,2,3],
                [4,5,6]])
in[3] some_function(a) # for each row: r = row.reshape(-1,1); r @ r.T
out[1] array([[[ 1,  2,  3],
        [ 2,  4,  6],
        [ 3,  6,  9]],

       [[16, 20, 24],
        [20, 25, 30],
        [24, 30, 36]]])

I need this to compute the softmax derivative for automatic differentiation in a hand-written implementation of a feed-forward neural network. For a single sample, the derivative looks like this:

in[4] def softmax_derivative(x):
          # x is the softmax output for one sample, not the raw logits
          s = x.reshape(-1,1)
          return np.diagflat(s) - np.dot(s,s.T)
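A quick way to sanity-check this formula is to compare it with a finite-difference Jacobian. The sketch below assumes the argument is the softmax *output* `s = softmax(z)` rather than the raw logits `z`, since that is what `diag(s) - s s^T` requires; the helpers `softmax` and `numerical_jacobian` are hypothetical names used only for illustration.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; the result sums to 1.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def softmax_derivative(s):
    # s is the softmax output for one sample, shape (n,).
    s = s.reshape(-1, 1)
    return np.diagflat(s) - s @ s.T

def numerical_jacobian(f, z, eps=1e-6):
    # Central differences: column k approximates d f(z) / d z_k.
    n = z.size
    jac = np.zeros((n, n))
    for k in range(n):
        step = np.zeros(n)
        step[k] = eps
        jac[:, k] = (f(z + step) - f(z - step)) / (2 * eps)
    return jac

z = np.array([0.5, -1.0, 2.0])
analytic = softmax_derivative(softmax(z))
numeric = numerical_jacobian(softmax, z)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

Because each softmax output sums to 1, every row of this Jacobian sums to 0, which is another cheap check.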

To handle a whole batch at once, instead of np.diagflat I am using:

in[7] matrix = np.array([[1,2,3],
                         [4,5,6]])
in[8] matrix.shape
out[2] (2,3)
in[9] Id = np.eye(matrix.shape[-1])
in[10] (matrix[...,np.newaxis] * Id).shape
out[3] (2,3,3)
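To confirm that this broadcast gives the same result as calling `np.diagflat` on each row, the two can be compared directly. A minimal check, assuming `matrix` is the `(2, 3)` array above:

```python
import numpy as np

matrix = np.array([[1, 2, 3],
                   [4, 5, 6]])
Id = np.eye(matrix.shape[-1])

# (2, 3, 1) * (3, 3) broadcasts to (2, 3, 3): slice i is diag(matrix[i]).
diag_batch = matrix[..., np.newaxis] * Id

# Reference: build each diagonal matrix separately and stack them.
diag_loop = np.stack([np.diagflat(row) for row in matrix])

print(diag_batch.shape)                       # (2, 3, 3)
print(np.array_equal(diag_batch, diag_loop))  # True
```

Note that `diag_batch` is a float array (because `np.eye` returns floats), while `diag_loop` keeps the integer dtype; `np.array_equal` compares values, so the check still passes.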

I want the 3-d array of squared rows so that I can subtract it from the 3-d array of diagonal matrices, which I build as in the example above.

I know I can get all the products I need from:

in[11] def get_squared_rows(matrix):
           # flattens the whole matrix, so a (2, 3) input gives a (6, 6) result
           s = matrix.reshape(-1,1)
           return s @ s.T

However, for a (2, 3) input that gives a single (6, 6) matrix, and the per-row (3, 3) products I want are only the blocks along its diagonal. To match the 3-d array of diagonal matrices, I would have to extract those blocks and stack them into an array of shape (n_samples, row, row). I do not know how to do that any faster than a plain loop over the rows of the input matrix.
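For reference, here are two ways to get the `(n_samples, row, row)` array without broadcasting. The first is the plain loop mentioned above; the second pulls the diagonal blocks out of the large `s @ s.T` product by reshaping it to four axes and indexing. Both are sketches with hypothetical function names, and the second mainly illustrates why that route is wasteful: it computes all `(n_samples * row)**2` products only to keep `n_samples * row**2` of them.

```python
import numpy as np

def squared_rows_loop(matrix):
    # Outer product of each row with itself, stacked along a new first axis.
    return np.stack([np.outer(row, row) for row in matrix])

def squared_rows_from_blocks(matrix):
    n, d = matrix.shape
    s = matrix.reshape(-1, 1)
    big = s @ s.T                      # shape (n*d, n*d)
    blocks = big.reshape(n, d, n, d)   # blocks[i, :, j, :] is outer(row i, row j)
    idx = np.arange(n)
    # Advanced indices separated by a slice: the indexed axis moves to the
    # front, giving shape (n, d, d) with blocks[i, :, i, :] in slot i.
    return blocks[idx, :, idx, :]

a = np.array([[1, 2, 3],
              [4, 5, 6]])
print(squared_rows_loop(a).shape)  # (2, 3, 3)
print(np.array_equal(squared_rows_loop(a), squared_rows_from_blocks(a)))  # True
```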

Answer:

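Use broadcasting. Inserting a new axis in two different positions turns each row into a (1, 3) and a (3, 1) view, and multiplying them element-wise gives every row's outer product at once: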

>>> a[:, None, :] * a[:, :, None]
array([[[ 1,  2,  3],
        [ 2,  4,  6],
        [ 3,  6,  9]],

       [[16, 20, 24],
        [20, 25, 30],
        [24, 30, 36]]])
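Putting the pieces together, the broadcast product slots straight into a batched softmax Jacobian. The sketch below assumes `s` holds one softmax output per row (shape `(n_samples, n_classes)`), and `batched_softmax_jacobian` is a hypothetical name. `np.einsum('ij,ik->ijk', s, s)` is shown as an equivalent way to spell the same per-row outer product, and the result is checked against the single-sample formula from the question.

```python
import numpy as np

def batched_softmax_jacobian(s):
    # s: softmax outputs, shape (n_samples, n_classes); each row sums to 1.
    diag = s[..., np.newaxis] * np.eye(s.shape[-1])  # (n, c, c): diag(s_i) per sample
    outer = s[:, :, None] * s[:, None, :]            # (n, c, c): s_i s_i^T per sample
    return diag - outer

def softmax_derivative(x):
    # Single-sample version from the question (x is one softmax output row).
    s = x.reshape(-1, 1)
    return np.diagflat(s) - np.dot(s, s.T)

z = np.array([[0.1, 0.2, 0.7],
              [2.0, -1.0, 0.5]])
s = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # row-wise softmax

jac = batched_softmax_jacobian(s)
outer_einsum = np.einsum('ij,ik->ijk', s, s)

print(jac.shape)  # (2, 3, 3)
print(np.allclose(outer_einsum, s[:, :, None] * s[:, None, :]))  # True
print(all(np.allclose(jac[i], softmax_derivative(s[i])) for i in range(len(s))))  # True
```

Which of the two spellings is faster depends on the array sizes and NumPy version, so it is worth timing both on realistic batch shapes.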