How does this python normalization code work?

Time:10-22

  cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

While learning about normalization for image recognition, I have seen many people use this code. I know this line normalizes the confusion matrix so that it contains only numbers between 0 and 1, which lets you read the fraction of correctly classified samples straight from the matrix. I'm not very good at math, but I'd like to know exactly how this line works. If anyone can help me, I'd appreciate it!

CodePudding user response:

It sums each row (the sum runs along axis 1) and then uses broadcasting to divide every element of a row by that row's sum.

So suppose you had:

>>> arr = np.arange(4*5).reshape(4, 5)
>>> arr
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

So first, it sums along the axis:

>>> arr.sum(1)
array([10, 35, 60, 85])

Note, you can't broadcast these two arrays with their current shapes:

>>> arr / arr.sum(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (4,5) (4,)
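
The rule behind that error is that broadcasting lines the two shapes up from the right, and each pair of dimensions must either be equal or have one side equal to 1. A rough way to check this without doing any arithmetic is sketched below. It assumes NumPy 1.20 or newer for np.broadcast_shapes, and can_broadcast is a hypothetical helper name used only for illustration:

```python
import numpy as np

def can_broadcast(shape_a, shape_b):
    """Hypothetical helper: True if the two shapes are broadcast-compatible."""
    try:
        np.broadcast_shapes(shape_a, shape_b)  # raises ValueError on mismatch
        return True
    except ValueError:
        return False

# (4, 5) vs (4,): the rightmost pair is 5 vs 4, neither equal nor 1
fails = can_broadcast((4, 5), (4,))

# (4, 5) vs (4, 1): the pairs are 5 vs 1 and 4 vs 4, both acceptable
works = can_broadcast((4, 5), (4, 1))

print(fails, works)
```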

The trailing dimensions (5 and 4) neither match nor equal 1, so you add a new trailing axis of length 1 to the sums, giving them shape (4, 1):

>>> arr.sum(1)[:, np.newaxis]
array([[10],
       [35],
       [60],
       [85]])
>>> arr.sum(1)[:, np.newaxis].shape
(4, 1)
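
Indexing with np.newaxis is only one way to get that (4, 1) shape. As a small sketch, not something the original line uses, the keepdims=True argument of sum and a plain reshape should give the same column of row sums:

```python
import numpy as np

arr = np.arange(4 * 5).reshape(4, 5)

# Three spellings of "row sums as a column"
col_a = arr.sum(axis=1)[:, np.newaxis]   # sum, then add a trailing axis
col_b = arr.sum(axis=1, keepdims=True)   # keep the summed axis with length 1
col_c = arr.sum(axis=1).reshape(-1, 1)   # sum, then reshape to one column

print(col_a.shape, col_b.shape, col_c.shape)
```

Which one to use is mostly a matter of taste. keepdims=True has the advantage of not needing a second indexing step.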

So now, the broadcasting division works:

>>> arr
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])
>>> arr.sum(1)[:, np.newaxis]
array([[10],
       [35],
       [60],
       [85]])
>>> arr / arr.sum(1)[:, np.newaxis]
array([[0.        , 0.1       , 0.2       , 0.3       , 0.4       ],
       [0.14285714, 0.17142857, 0.2       , 0.22857143, 0.25714286],
       [0.16666667, 0.18333333, 0.2       , 0.21666667, 0.23333333],
       [0.17647059, 0.18823529, 0.2       , 0.21176471, 0.22352941]])
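
To tie this back to the question, here is the same line applied to a small confusion matrix. The numbers are made up for illustration, and the sketch assumes the common convention that rows are true classes and columns are predicted classes (the question does not say which convention it uses):

```python
import numpy as np

# Made-up 3-class confusion matrix (assumption: rows = true class,
# columns = predicted class)
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [0, 3, 7]])

# The line from the question: divide each row by that row's total
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

row_totals = cm_norm.sum(axis=1)   # every row now sums to 1
per_class = np.diag(cm_norm)       # fraction of each true class classified correctly

print(cm_norm)
print(row_totals)
print(per_class)
```

Under that convention, the diagonal of the normalized matrix holds the fraction of each class that was classified correctly. Two caveats: a class with no samples gives a row total of 0, so that row comes out as NaN, and under Python 3 the astype('float') call is probably redundant because / already performs true division on integer arrays.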

Read more about broadcasting in the NumPy docs.
