ValueError in numba vectorize for accumulate

Time:09-23

I'm trying to write a ufunc with Numba. I read this and incorporated it into my code. My basic code, which runs, is:

import numpy as np
arr = np.arange(15).reshape((3,5))
def myadd(x, y):
    return x + y
myadd = np.frompyfunc(myadd, 2, 1)
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))
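To make concrete what accumulate computes here: along axis 1 it copies the first column, then sets each later column to myadd(previous result column, current input column). A minimal NumPy-only sketch, assuming myadd is plain addition as its name suggests (myadd_uf is just an illustrative name), checks this against np.cumsum:

```python
import numpy as np

arr = np.arange(15).reshape((3, 5))

def myadd(x, y):
    return x + y

# frompyfunc wraps a plain Python function as an object-dtype ufunc
# with 2 inputs and 1 output
myadd_uf = np.frompyfunc(myadd, 2, 1)

# accumulate along axis 1: the first column is copied, then each later
# column is myadd(previous result column, current input column)
result = myadd_uf.accumulate(arr, dtype=object, axis=1).astype(int)

print(result)
# [[ 0  1  3  6 10]
#  [ 5 11 18 26 35]
#  [10 21 33 46 60]]

# For addition, this is exactly a running sum along the axis
print(np.array_equal(result, np.cumsum(arr, axis=1)))  # True
```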

Now, if I use Numba's vectorize instead:

import numpy as np
import numba as nb
arr = np.arange(15).reshape((3,5))
@nb.vectorize
def myadd(x, y):
    return x + y
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))

I get this error:

Traceback (most recent call last):

  File "<ipython-input-11-ea9f981e42b2>", line 7, in <module>
    print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))

ValueError: could not find a matching type for myadd.accumulate, requested type has type code 'O'

What is the workaround for this? I'm using Numba 0.54.0, Numpy 1.20.3, under Spyder in Anaconda with Python 3.8.10 on Windows 10.

Answer:

Adding an explicit signature and removing the dtype argument resolves the error.

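I don't have a definitive explanation of why the signature helps, but hopefully this gets things running and helps you find the full answer. My best guess: without a signature, nb.vectorize builds a lazily compiled ufunc that only compiles a loop when it is called directly with concrete arguments, so when accumulate is called first there is no loop to use. On top of that, dtype=object requests an object loop, and this ufunc has none; that matches the "type code 'O'" in the error message.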

import numpy as np
import numba as nb
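# Explicit int64 (the default integer on Windows with NumPy 1.x is int32)
arr = np.arange(15, dtype=np.int64).reshape((3, 5))
# Explicit signature: int64 return type, two int64 arguments
@nb.vectorize([nb.int64(nb.int64, nb.int64)])
def myadd(x, y):
    return x + y
# No dtype=object here: the compiled ufunc only has an int64 loop
print(myadd.accumulate(arr, axis=1).astype(int))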
Output:
[[ 0  1  3  6 10]
 [ 5 11 18 26 35]
 [10 21 33 46 60]]