Numba failure with np.mean

For some reason, Numba fails when I add an axis argument to np.mean. For instance, this gives an error:

import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a,-1)

b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))

TypingError: Invalid use of Function(<function mean at 0x000002949B28E1E0>) with argument(s) of type(s): (array(int32, 2d, C), Literal[int](1))
 * parameterized
In definition 0:
    AssertionError: 
    raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
In definition 1:
    AssertionError: 
    raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function mean at 0x000002949B28E1E0>)
[2] During: typing of call at C:/Users/U374235/test.py (11)

However, this works perfectly:

import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a)

b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))

Answer

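In nopython mode, Numba supports np.mean() only with its first argument (the array itself). Optional arguments such as axis are not supported, which is why the call with an axis fails during typing.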

You can get the same result by looping over the rows yourself:

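import numpy as np
from numba import njit, prange

a = np.array([[0, 1, 2], [3, 4, 5]])
res_numpy = np.mean(a, -1)

@njit(parallel=True)
def mean_numba(a):
    # Preallocate the output: appending to a shared list inside prange is not thread-safe
    res = np.empty(a.shape[0])
    for i in prange(a.shape[0]):
        # Mean of row i, equivalent to np.mean(a, -1) for a 2-D array
        res[i] = a[i, :].mean()
    return res

print(np.array_equal(res_numpy, mean_numba(a)))  # True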

Related GitHub issue: https://github.com/numba/numba/issues/1269
