I created a function that takes X, y, and a batch size as input and yields mini-batches as output, written in Cython to speed up the process.
import numpy as np
cimport cython
cimport numpy as np

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def create_mini_batches(np.ndarray[DTYPE_t, ndim=2] X, np.ndarray[DTYPE_t, ndim=2] y, int batch_size):
    cdef int m
    cdef double num_of_batch
    cdef np.ndarray[DTYPE_t, ndim=2] shuffle_X
    cdef np.ndarray[DTYPE_t, ndim=2] shuffle_y
    cdef int permutation
    X, y = X.T, y.T
    m = X.shape[0]
    num_of_batch = m // batch_size
    permutation = list(np.random.permutation(m))
    shuffle_X = X[permutation, :]
    shuffle_y = y[permutation, :]
    for t in range(num_of_batch):
        mini_x = shuffle_X[t * batch_size: (t + 1) * batch_size, :]
        mini_y = shuffle_y[t * batch_size: (t + 1) * batch_size, :]
        yield (mini_x.T, mini_y.T)
    if m % batch_size != 0:
        mini_x = shuffle_X[m // batch_size * batch_size:, :]
        mini_y = shuffle_y[m // batch_size * batch_size:, :]
        yield (mini_x.T, mini_y.T)
When I compile the program with python setup.py build_ext --inplace, the following error shows up:
@cython.boundscheck(False)
def create_mini_batches(np.ndarray[DTYPE_t, ndim=2] X, np.ndarray[DTYPE_t, ndim=2] y, int batch_size):
                        ^
test.pyx:8:24: Buffer types only allowed as function local variables
Can someone explain why this is an error and how to fix it?
Answer:
It's a slightly confusing error message in this case, but you're getting it because create_mini_batches is a generator rather than a regular function. This means that Cython has to create an internal data structure to hold the generator's state while it runs.
Typed NumPy array variables (e.g. np.ndarray[DTYPE_t, ndim=2]) were implemented in a way that makes it very hard to handle their reference counting correctly. Therefore Cython can only handle them as local variables in a regular function. It cannot store them in a class, and thus cannot use them in a generator.
To solve it, you either need to drop the typing or switch to the more recent typed memoryviews, which were designed better and so don't have this limitation.
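As a rough illustration of the first option, here is a minimal sketch of the generator with the buffer typing dropped. It is plain Python, so it should also compile unchanged as a .pyx file. The single `range(0, m, batch_size)` loop is my own simplification of the two-part loop in the question, not something from the original code. For the memoryview option, the arguments would instead be declared along the lines of `double[:, :] X`, and would most likely need wrapping in `np.asarray(...)` before the fancy indexing, since memoryviews do not accept an index array.

```python
import numpy as np


def create_mini_batches(X, y, batch_size):
    """Yield shuffled (mini_x, mini_y) pairs; columns of X and y are samples."""
    # No np.ndarray[...] buffer types appear here, so Cython has nothing
    # it would need to store in the generator's internal state object.
    X, y = X.T, y.T  # after transposing, each row is one sample
    m = X.shape[0]
    permutation = np.random.permutation(m)  # same shuffle for X and y
    shuffle_X = X[permutation, :]
    shuffle_y = y[permutation, :]
    # Stepping by batch_size covers the full batches and, if m is not a
    # multiple of batch_size, the smaller final batch as well.
    for start in range(0, m, batch_size):
        mini_x = shuffle_X[start:start + batch_size, :]
        mini_y = shuffle_y[start:start + batch_size, :]
        yield mini_x.T, mini_y.T
```

Most of the work here happens inside NumPy's indexing and slicing, so dropping the types is unlikely to cost much speed in this particular function.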