Why does numpy.dot give incorrect results?

Time:11-09

This code:

import numpy as np

a = np.array([10], dtype=np.int8)
b = np.array([2],  dtype=np.int8)
print(np.dot(a, b))

a = np.array([10], dtype=np.int8)
b = np.array([5],  dtype=np.int8)
print(np.dot(a, b))

a = np.array([10], dtype=np.int8)
b = np.array([20], dtype=np.int8)
print(np.dot(a, b))

produces the following output:

20
50
-56

It appears that np.dot returns the result in the same data type as its inputs even when the value can't fit. Surely this is a bug? Why doesn't it throw an exception?

CodePudding user response:

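It depends on the dtype used. If I drop dtype=np.int8, so that NumPy picks its default integer type:

# No dtype given: NumPy uses its default integer type, which is wider than int8
a = np.array([10])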
b = np.array([2])
print(np.dot(a, b))

a = np.array([10])
b = np.array([5])
print(np.dot(a, b))

a = np.array([10])
b = np.array([20])
print(np.dot(a, b))

The output is:

20
50
200
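If the inputs have to stay int8, one possible workaround is to widen them just for the computation. This is a minimal sketch, not something from the answers here; the choice of np.int64 as the wider type is an assumption, and any type large enough for the expected result would do.

```python
import numpy as np

a = np.array([10], dtype=np.int8)
b = np.array([20], dtype=np.int8)

# Computed in int8: the true product 200 does not fit and wraps around.
wrapped = np.dot(a, b)

# Assumption: int64 is wide enough for the result. Cast first, then multiply.
widened = np.dot(a.astype(np.int64), b.astype(np.int64))

print(wrapped, wrapped.dtype)
print(widened, widened.dtype)
```

The cast has to happen before np.dot is called; casting the result afterwards would only widen the already wrapped value.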

CodePudding user response:

This happens because numpy.int8 only supports integers in the interval [-128, 128): with just 8 bits you can encode only 256 different numbers. To better understand what happens when an integer n outside this range is cast to numpy.int8, let's define n = np.arange(-2**8, 2**8) and see how this array is converted to numpy.int8:

import numpy as np
import matplotlib.pyplot as plt

n = np.arange(-2**8, 2**8)           # integers from -256 to 255
plt.plot(n, n.astype(np.int8))       # cast to int8; out-of-range values wrap around
plt.xticks([0,  32,  64, 128, 256, -32,  -64, -128, -256])
plt.yticks([0,  32,  64, 128, -32,  -64, -128])
plt.xlabel("n")
plt.ylabel("int8(n)")

plt.show()

[Figure: plot of int8(n) against n]

As you can see, the conversion is a periodic function of period 256; outside the range [-128, 128) you'll have n != numpy.int8(n). In general, numpy.int8(n) == ((n-128)%256)-128. Indeed, for the product 10*20 = 200, ((200-128)%256)-128 == -56.
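The wraparound formula can be checked numerically. The sketch below compares it with an actual cast over the same range of n; using astype for the cast is my choice here, and the variable names are only illustrative.

```python
import numpy as np

n = np.arange(-2**8, 2**8)            # integers from -256 to 255

cast = n.astype(np.int8)              # what the int8 cast actually produces
formula = ((n - 128) % 256) - 128     # the closed-form wraparound

# The two agree for every n in the range.
print(np.array_equal(cast, formula))

# The value from the question: 10 * 20 = 200
print(((200 - 128) % 256) - 128)
```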

PS. If you only need positive integers you can use the type numpy.uint8 to encode numbers in the interval [0,256)
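Note that uint8 only moves the range; it does not remove the wraparound. A small sketch (the values 20 and 30 are picked just to land on either side of 256):

```python
import numpy as np

a = np.array([10], dtype=np.uint8)

fits = np.dot(a, np.array([20], dtype=np.uint8))    # 200 is inside [0, 256)
wraps = np.dot(a, np.array([30], dtype=np.uint8))   # 300 is outside, wraps modulo 256

print(fits, wraps)
```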

CodePudding user response:

The same wraparound happens with ordinary multiplication and addition on int8 arrays.

In [89]: np.array([128], 'int8')*2
Out[89]: array([0], dtype=int8)

In [90]: np.array([127], 'int8')*2
Out[90]: array([-2], dtype=int8)        # same int8 dtype

But if I work with an element of the array, an np.int8 object, the result is promoted.

In [91]: np.array([127], 'int8')[0]*2
Out[91]: 254

In [92]: type(_)
Out[92]: numpy.int32

I think there are cases where this kind of thing raises an error, though I can't produce one offhand.
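If you want an exception for array operations, you can do the range check yourself. The sketch below is one way to do that; checked_dot is a hypothetical helper, not a NumPy function, and it assumes integer inputs whose true result fits in int64.

```python
import numpy as np

def checked_dot(a, b):
    """Hypothetical helper: integer dot product that raises instead of wrapping."""
    a = np.asarray(a)
    b = np.asarray(b)
    out_dtype = np.result_type(a, b)       # dtype np.dot would normally return
    if not np.issubdtype(out_dtype, np.integer):
        return np.dot(a, b)                # non-integer inputs: nothing to check

    # Assumption: int64 is wide enough to hold the true result.
    wide = np.dot(a.astype(np.int64), b.astype(np.int64))

    info = np.iinfo(out_dtype)             # representable range of the narrow dtype
    if np.any(wide < info.min) or np.any(wide > info.max):
        raise OverflowError(f"result does not fit in {out_dtype}")
    return wide.astype(out_dtype)

a = np.array([10], dtype=np.int8)
print(checked_dot(a, np.array([5], dtype=np.int8)))

try:
    checked_dot(a, np.array([20], dtype=np.int8))
except OverflowError as exc:
    print("OverflowError:", exc)
```

The extra cast and comparison cost time and memory, so a check like this is probably best kept to places where a silent wraparound would be harmful.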

This has been discussed in other SO questions, for multiplication if not for np.dot.

Here is an overflow question about 'uint8' dtypes, and a related GitHub issue:

Allow overflow for numpy types

https://github.com/numpy/numpy/issues/8987 "BUG: Integer overflow warning applies to scalars but not arrays"
