I'm trying to write a ufunc with Numba. I read this and incorporated it into my code. My basic code, which runs, is:
import numpy as np
arr = np.arange(15).reshape((3, 5))
def myadd(x, y):
    return x + y
# wrap the Python function as a ufunc with 2 inputs and 1 output (object loop only)
myadd = np.frompyfunc(myadd, 2, 1)
print(myadd.accumulate(arr, dtype=object, axis=1).astype(int))
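For reference, a ufunc built with np.frompyfunc only provides an object loop (its types attribute is ['OO->O']), which is why the working version passes dtype=object and converts back with .astype(int). A minimal NumPy-only check that the accumulated result is the running sum of each row (the names py_add and add_ufunc are just illustrative):

```python
import numpy as np

arr = np.arange(15).reshape((3, 5))

def py_add(x, y):
    return x + y

# frompyfunc: 2 inputs, 1 output; only an object-dtype loop is available
add_ufunc = np.frompyfunc(py_add, 2, 1)
print(add_ufunc.types)  # ['OO->O']

# accumulate along axis=1 applies py_add cumulatively across each row
result = add_ufunc.accumulate(arr, dtype=object, axis=1).astype(int)
print(result)

# same values as NumPy's built-in cumulative sum
print(np.array_equal(result, np.cumsum(arr, axis=1)))  # True
```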
Now, if I use Numba's vectorize decorator instead:
import numpy as np
import numba as nb
arr = np.arange(15).reshape((3, 5))
@nb.vectorize
def myadd(x, y):
    return x + y
print(myadd.accumulate(arr, dtype=object, axis=1).astype(int))
I get this error:
Traceback (most recent call last):
File "<ipython-input-11-ea9f981e42b2>", line 7, in <module>
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))
ValueError: could not find a matching type for myadd.accumulate, requested type has type code 'O'
What is the workaround for this? I'm using Numba 0.54.0 and NumPy 1.20.3, under Spyder in Anaconda with Python 3.8.10 on Windows 10.
Answer:
I found that adding a signature and removing the dtype argument resolves the error.
I don't have an exact answer as to why adding the signature works, but I hope this gets things running and helps you find the answer. (My guess, which I'm not sure about: arr is of integer type while myadd expects float.)
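import numpy as np
import numba as nb
arr = np.arange(15).reshape((3, 5))
# explicit signature: two int64 arguments
@nb.vectorize([(nb.int64, nb.int64)])
def myadd(x, y):
    return x + y
# no dtype=object here: the compiled ufunc only has a native int64 loop
print(myadd.accumulate(arr, axis=1).astype(int))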
[[ 0  1  3  6 10]
 [ 5 11 18 26 35]
 [10 21 33 46 60]]