I am trying to implement softmax function but weirdly I am getting two different outputs on MATLAB and on Python:
MATLAB script:
function sm = softmax(Y)
    e_y = exp(Y - max(Y))
    sm = e_y / sum(e_y)
which computes ten times 0.1 as a column vector (correctly, because the sum over each column is 1) if Y is a 10x200 matrix of 501.
Meanwhile this Python script:
import numpy as np
def softmax(y):
    e_y = np.exp(y - np.max(y))
    return e_y / e_y.sum()
y = np.full((10,200), fill_value=501)
print(softmax(y))
computes, on the same input y:
[[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]
[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]
[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]
...
[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]
[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]
[0.0005 0.0005 0.0005 ... 0.0005 0.0005 0.0005]]
which is wrong, since the sum of each column is not 1 but 0.005.
What am I missing?
Answer:
What you actually want is element-wise division rather than matrix division: replace the / operator with ./ in your MATLAB script (see the documentation). As Cris Luengo noticed in their comment, the script was (wrongly) computing a single column vector of 0.1s instead of a matrix of 0.1s.
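To see why the / operator produced a single column, here is a rough NumPy sketch of the two divisions. It assumes that MATLAB's A / B on non-square operands returns a least-squares solution of x*B = A, and it uses np.linalg.lstsq as a stand-in for that behavior; the variable names are only illustrative.

```python
import numpy as np

# Same shapes as in the question: a 10x200 numerator and a 1x200 row of column sums.
e_y = np.ones((10, 200))
col_sums = e_y.sum(axis=0, keepdims=True)  # shape (1, 200), every entry is 10

# Rough analogue of MATLAB's e_y / sum(e_y): solve x @ col_sums = e_y
# in the least-squares sense (transposed to match lstsq's a @ X = b form).
x = np.linalg.lstsq(col_sums.T, e_y.T, rcond=None)[0].T

# Analogue of MATLAB's e_y ./ sum(e_y): plain broadcasting division.
elementwise = e_y / col_sums

print(x.shape)            # (10, 1)
print(elementwise.shape)  # (10, 200)
```

Both results contain only values of about 0.1, which is why the MATLAB output looked plausible at first glance even though its shape was wrong.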
Also, your Python script takes the max (and the sum) over the whole matrix, whereas you should treat each column, which represents a single prediction, separately and normalize it. That is:
def softmax(y):
    e_y = np.exp(y - np.max(y, axis=0))       # subtract each column's max
    return e_y / np.sum(e_y, axis=0)          # normalize each column (axis=0)
You will correctly get a (10,200)-shaped array/matrix composed as follows:
array([[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1],
...,
[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1]])
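As a quick sanity check, the two versions can be compared side by side on the input from the question. This is a minimal sketch; the names softmax_global and softmax_columns are made up here to keep the two variants apart.

```python
import numpy as np

def softmax_global(y):
    # Version from the question: one max and one sum for the whole matrix.
    e_y = np.exp(y - np.max(y))
    return e_y / e_y.sum()

def softmax_columns(y):
    # Corrected version: max and sum are taken per column (axis=0).
    e_y = np.exp(y - np.max(y, axis=0))
    return e_y / np.sum(e_y, axis=0)

y = np.full((10, 200), fill_value=501)

print(softmax_global(y)[0, 0])            # 1 / 2000, one of 2000 equal entries
print(softmax_global(y).sum(axis=0)[0])   # about 0.005: ten entries of 0.0005
print(softmax_columns(y).sum(axis=0)[0])  # about 1.0
```

The global version divides every entry by the sum of all 10 * 200 = 2000 exponentials, which is where the 0.0005 comes from.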
Further demo:
# Allocate a (3,4) matrix composed of 0..11 elements
y = np.asmatrix(np.arange(12)).reshape(3,4)
print(softmax(y))
you'll get:
matrix([[3.29320439e-04, 3.29320439e-04, 3.29320439e-04, 3.29320439e-04],
[1.79802867e-02, 1.79802867e-02, 1.79802867e-02, 1.79802867e-02],
[9.81690393e-01, 9.81690393e-01, 9.81690393e-01, 9.81690393e-01]])
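If you ever need to normalize along rows instead of columns, the axis=0 version above will not broadcast correctly as written. A possible generalization, sketched here with a hypothetical helper softmax_along that is not part of the original answer, uses keepdims=True so that the reduced axis stays in place with size 1:

```python
import numpy as np

def softmax_along(y, axis=0):
    # Hypothetical helper: softmax along an arbitrary axis of a plain ndarray.
    y = np.asarray(y, dtype=float)
    # keepdims=True keeps the reduced axis as size 1, so the subtraction
    # and the division broadcast correctly for any choice of axis.
    e_y = np.exp(y - np.max(y, axis=axis, keepdims=True))
    return e_y / np.sum(e_y, axis=axis, keepdims=True)

y = np.arange(12).reshape(3, 4)
print(softmax_along(y, axis=0).sum(axis=0))  # each column sums to about 1
print(softmax_along(y, axis=1).sum(axis=1))  # each row sums to about 1
```

With axis=0 this should give the same values as the demo above, only as a regular ndarray rather than a matrix.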