How can I test Numba vectorized functions with pytest?

Time:11-18

I'm currently writing a couple of functions, both optimised with Numba (one with @guvectorize and one with @vectorize). I also wrote some tests for both functions, but when I run pytest --cov --cov-report term-missing, the lines reported as missing correspond to the optimised functions.

Is this a problem with how pytest runs the tests on these functions, or is it due to some other problem on my side?

The simpler of the two functions is:

from numba import vectorize

@vectorize(["float64(float64, float64)", "float32(float32, float32)"], nopython=True)
def binarize_mask(mask_data, threshold):
    """Binarize the mask array based on a threshold.

    :param mask_data: Mask array.
    :param threshold: Threshold to apply to the mask.
    """
    # Binarize the mask array
    return 1 if mask_data >= threshold else 0

which I test with the following tests:

  1. For a single value:
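import numpy as np

# `dl` is assumed to be the module under test that defines binarize_mask (its import is not shown)
def test_binarize_mask_return_value():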
    threshold = np.float32(0.5)
    assert dl.binarize_mask(np.float32(0.3), threshold) == 0
    assert dl.binarize_mask(np.float32(0.7), threshold) == 1
  2. For an array:
def test_binarize_mask_float32():
    test_data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32)
    test_mask = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=np.float32)
    binarized = dl.binarize_mask(test_data, 3.0)
    assert binarized.dtype == np.float64
    assert binarized.shape == test_mask.shape
    assert np.all(binarized == test_mask)

Answer:

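Once the code is compiled by Numba, coverage.py can no longer measure coverage on it: coverage.py records which lines run by hooking into the Python interpreter, and the compiled machine code never passes through those hooks. So this is not a problem with pytest or with your tests. Several issues have been reported about this.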

You could simply sweep the untested code under the carpet by excluding it from coverage.py's measurement.
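A minimal sketch of the exclusion approach, using coverage.py's default `# pragma: no cover` marker on the question's function. Putting the marker on the `def` line excludes the whole function body from the report. The try/except fallback to `np.vectorize` is an assumption added only so the snippet runs without Numba installed; it is not how Numba behaves.

```python
import numpy as np

try:
    from numba import vectorize
except ImportError:
    # Hypothetical stand-in so the sketch runs without Numba; not Numba's behaviour.
    def vectorize(signatures, nopython=True):
        return np.vectorize


@vectorize(["float64(float64, float64)", "float32(float32, float32)"], nopython=True)
def binarize_mask(mask_data, threshold):  # pragma: no cover
    """Binarize the mask array based on a threshold.

    The pragma on the def line tells coverage.py to leave this whole
    function body out of the report, so its lines are no longer "missing".
    """
    return 1 if mask_data >= threshold else 0
```

Note that this only hides the lines from the report; the function body is still not measured.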

But I understand you are serious about your code and really want to check your algorithms. In that case you can run your tests twice: once normally to check the compiled code, and once only for coverage, with the environment variable NUMBA_DISABLE_JIT=1 set to turn off Numba's JIT compilation, as described in the Numba documentation.
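A minimal sketch of the two-run setup as a small driver script, assuming pytest-cov is installed (as the question's `--cov` flags imply). The script and the helpers `build_env` and `build_command` are hypothetical. The Numba documentation describes NUMBA_DISABLE_JIT in terms of the `jit()` decorator, so it is worth checking in the second report that the @vectorize and @guvectorize bodies are actually marked as executed.

```python
"""Hypothetical run_tests.py: run pytest once compiled, once for coverage."""
import os
import subprocess
import sys


def build_env(disable_jit):
    """Copy the current environment, setting or clearing NUMBA_DISABLE_JIT."""
    env = dict(os.environ)
    if disable_jit:
        # Numba reads its configuration from the environment at import time,
        # so the variable is set for the child process, not inside a test.
        env["NUMBA_DISABLE_JIT"] = "1"
    else:
        env.pop("NUMBA_DISABLE_JIT", None)
    return env


def build_command(with_coverage):
    """Build the pytest command line; coverage flags only for the second run."""
    cmd = [sys.executable, "-m", "pytest"]
    if with_coverage:
        cmd += ["--cov", "--cov-report", "term-missing"]
    return cmd


def main():
    # Run 1: exercise the real, JIT-compiled functions.
    compiled = subprocess.run(build_command(False), env=build_env(False))
    # Run 2: JIT disabled, only to collect the coverage report.
    traced = subprocess.run(build_command(True), env=build_env(True))
    # Report failure if either run failed.
    return compiled.returncode or traced.returncode


if __name__ == "__main__":
    sys.exit(main())
```

Running the two passes in separate processes keeps the JIT setting from leaking between them.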
