Cython Buffer types only allowed as function local variables

Time:12-05

I wrote a Cython function that takes X, y and a batch size as input and yields mini-batches as output, to speed up the process.

import numpy as np
cimport cython
cimport numpy as np

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def create_mini_batches(np.ndarray[DTYPE_t, ndim=2] X, np.ndarray[DTYPE_t, ndim=2] y, int batch_size):
    
    cdef int m
    cdef double num_of_batch
    cdef np.ndarray[DTYPE_t, ndim=2] shuffle_X
    cdef np.ndarray[DTYPE_t, ndim=2] shuffle_y
    cdef int permutation

    X, y = X.T, y.T

    m = X.shape[0]
    num_of_batch = m // batch_size

    permutation = list(np.random.permutation(m))
    shuffle_X = X[permutation, :]
    shuffle_y = y[permutation, :]

    for t in range(num_of_batch):
        mini_x = shuffle_X[t * batch_size: (t + 1) * batch_size, :]
        mini_y = shuffle_y[t * batch_size: (t + 1) * batch_size, :]
        yield (mini_x.T, mini_y.T)
    
    if m % batch_size != 0:
        mini_x = shuffle_X[m // batch_size * batch_size: , :]
        mini_y = shuffle_y[m // batch_size * batch_size: , :]
        yield (mini_x.T, mini_y.T)

When I compile the program with python setup.py build_ext --inplace, the following error shows up:

@cython.boundscheck(False)
def create_mini_batches(np.ndarray[DTYPE_t, ndim=2] X, np.ndarray[DTYPE_t, ndim=2] y, int batch_size):
                       ^

test.pyx:8:24: Buffer types only allowed as function local variables

Can someone explain why this is an error and how to fix it?

CodePudding user response:

It's a slightly confusing error message in this case, but you're getting it because create_mini_batches is a generator (it contains yield) rather than a regular function. This means that Cython has to create an internal data structure to hold the generator's state, including its arguments and local variables, between yields.

Typed NumPy array variables (e.g. np.ndarray[DTYPE_t, ndim=2]) were implemented in a way that makes it very hard to handle their reference counting correctly. Therefore Cython can only handle them as local variables in a regular function. It cannot store them in a class, and thus cannot use them in a generator.

To solve it, you either need to drop the typing or switch to the more recent typed memoryviews (e.g. double[:, :]), which were designed better and so don't have this limitation.
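
As a rough illustration of the "drop the typing" option, here is a minimal sketch of the same generator with no buffer types at all. It is written as plain Python, which should also be acceptable in a .pyx file, and it assumes that examples are stored as columns, as the transposes in the question suggest. The rng argument is a hypothetical addition so the shuffle can be seeded; it is not part of the original code. Since the work here is whole-array slicing done inside NumPy, the type declarations were probably not buying much speed anyway.

```python
import numpy as np


def create_mini_batches(X, y, batch_size, rng=None):
    """Yield shuffled (mini_x, mini_y) pairs; examples are stored as columns."""
    # rng is a hypothetical extra parameter, added so the shuffle can be seeded.
    rng = np.random.default_rng() if rng is None else rng

    m = X.shape[1]                    # number of examples (columns)
    permutation = rng.permutation(m)  # integer index array, not a single C int
    shuffle_X = X[:, permutation]
    shuffle_y = y[:, permutation]

    # Slicing past the end of an array is safe, so the last batch is just shorter.
    for start in range(0, m, batch_size):
        stop = start + batch_size
        yield shuffle_X[:, start:stop], shuffle_y[:, start:stop]
```

Indexing the columns directly avoids transposing on the way in and again on every yield, and stepping range by batch_size covers the leftover batch without a separate branch.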
