I would like to vectorize a function which applies some simple filtering logic to a numpy array:
import numpy

def func(arr, a, b):
    sub = arr[arr[:,a] > b]
    mean = numpy.mean(sub, axis=0)
    return mean

a = numpy.array([0,1,2])
b = numpy.array([0,2,0])
arr = numpy.array([[0,2,3],[4,4,0]])
out = map(func, arr, a, b)
print(list(out))
However, this raises an error:
    sub = arr[arr[:,a] > b]
IndexError: too many indices for array
This seems to break because it's not using map properly to iterate over a and b. How can I pass a numpy array together with a and b to map?
My expected output is an array of arrays (3 arrays long), each being the mean of the filtered rows of arr. Note that the actual logic I want to perform is more complicated than just calculating means, but this should get the point across.
Answer:
Check what map is feeding your func:
In [31]: def func(arr, a, b):
    ...:     print(arr,a,b)
    ...:     return 1
    ...:
In [32]: a = numpy.array([0,1,2])
...: b = numpy.array([0,2,0])
...: arr = numpy.array([[0,2,3],[4,4,0]])
...:
...: out = map(func, arr, a, b)
...: list(out)
[0 2 3] 0 0
[4 4 0] 1 2
Out[32]: [1, 1]
Transpose arr so it's (3,2):
In [33]: out = map(func, arr.T, a, b)
...: list(out)
[0 4] 0 0
[2 4] 1 2
[3 0] 2 0
Out[33]: [1, 1, 1]
map iterates over all of its arguments, arr included, not just a and b, and it stops when the shortest one runs out.
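That is also where the question's IndexError comes from: map hands func one row of arr at a time, a row is a 1-D array, and arr[:,a] then indexes it with one index too many. A quick check (first_row and raised are illustrative names; the exact error text varies between NumPy versions):

```python
import numpy as np

arr = np.array([[0, 2, 3], [4, 4, 0]])

# Iterating over a 2-D array yields its rows, each a 1-D array of shape (3,).
first_row = next(iter(arr))
print(first_row.shape)

# Inside func, arr[:, a] is applied to such a row: two indices on a 1-D array.
raised = False
try:
    first_row[:, 0]
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
```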
map's iteration is the same sort of iteration we get from zip:
In [34]: list(zip(arr,a,b))
Out[34]: [(array([0, 2, 3]), 0, 0), (array([4, 4, 0]), 1, 2)]
In [35]: list(zip(arr.T,a,b))
Out[35]: [(array([0, 4]), 0, 0), (array([2, 4]), 1, 2), (array([3, 0]), 2, 0)]
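To make the zip comparison concrete, here is a small sketch showing that map(f, x, y, z) produces the same calls as f(*args) for each tuple from zip(x, y, z). show_args is a hypothetical stand-in for func that just reports what it received:

```python
import numpy as np

def show_args(arr, a, b):
    # Hypothetical placeholder: return the arguments as plain Python values.
    return (arr.tolist(), int(a), int(b))

a = np.array([0, 1, 2])
b = np.array([0, 2, 0])
arr = np.array([[0, 2, 3], [4, 4, 0]])

# Both forms stop after the shortest input (arr, which has only 2 rows).
via_map = list(map(show_args, arr, a, b))
via_zip = [show_args(*args) for args in zip(arr, a, b)]
print(via_map == via_zip)
print(len(via_map))
```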
Leave arr outside of the map, and have func take it as a global:
In [36]: def func(a, b):
    ...:     sub = arr[arr[:,a] > b]
    ...:     mean = numpy.mean(sub, axis=0)
    ...:     return mean
    ...:
In [37]: list(map(func,a,b))
Out[37]: [array([4., 4., 0.]), array([4., 4., 0.]), array([0., 2., 3.])]
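If you'd rather not rely on a global, one variation (not part of the original answer) is to bind arr with functools.partial, which keeps the question's three-argument func unchanged while map iterates only over a and b. np.stack then turns the list of results into the single array the question asks for:

```python
from functools import partial

import numpy as np

def func(arr, a, b):
    # Same filtering logic as the question: keep rows whose column a exceeds b.
    sub = arr[arr[:, a] > b]
    return np.mean(sub, axis=0)

a = np.array([0, 1, 2])
b = np.array([0, 2, 0])
arr = np.array([[0, 2, 3], [4, 4, 0]])

# partial(func, arr) fixes the first argument, so map only feeds it a and b.
out = list(map(partial(func, arr), a, b))

# Stack the three per-pair means into one (3, 3) array.
result = np.stack(out)
print(result)
```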
From the map docs:
map(func, *iterables) --> map object
Make an iterator that computes the function using arguments from
each of the iterables. Stops when the shortest iterable is exhausted.
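Applied to the arrays in this question, the "shortest iterable" rule is why the first attempt made only 2 calls while the transposed one made 3. A sketch that counts the calls (count_calls is a hypothetical helper for illustration):

```python
import numpy as np

a = np.array([0, 1, 2])
b = np.array([0, 2, 0])
arr = np.array([[0, 2, 3], [4, 4, 0]])

def count_calls(*iterables):
    # Number of times map invokes a function when given these iterables.
    return sum(1 for _ in map(lambda *args: None, *iterables))

print(count_calls(arr, a, b))    # arr has only 2 rows
print(count_calls(arr.T, a, b))  # arr.T has 3 rows
print(count_calls(a, b))         # a and b both have 3 elements
```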
Let's add a print to get a clearer idea of what your func is doing:
In [56]: def func(a, b):
    ...:     sub = arr[arr[:,a] > b]
    ...:     print(a,b,sub)
    ...:     mean = numpy.mean(sub, axis=0)
    ...:     return mean
    ...:
In [57]: list(map(func,a,b))
0 0 [[4 4 0]]
1 2 [[4 4 0]]
2 0 [[0 2 3]]
Out[57]: [array([4., 4., 0.]), array([4., 4., 0.]), array([0., 2., 3.])]
With that indexing, sub is a (1,3) array, so the mean along axis=0 doesn't do anything interesting: it just returns that single row.
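A small sketch of the difference: with axis=0 a one-row array comes back unchanged, while a two-row array (what a looser filter, e.g. a=0 with b=-1, would select here) is actually averaged; without axis every element is averaged into one scalar. sub_one and sub_two are illustrative names:

```python
import numpy as np

sub_one = np.array([[4, 4, 0]])             # shape (1, 3), like sub above
sub_two = np.array([[0, 2, 3], [4, 4, 0]])  # shape (2, 3), both rows kept

# axis=0 averages down the rows; with one row that is just the row itself.
print(np.mean(sub_one, axis=0))  # values 4, 4, 0
print(np.mean(sub_two, axis=0))  # values 2, 3, 1.5

# No axis: all elements are averaged into a single number (8/3 here).
print(np.mean(sub_one))
```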
Drop the axis argument and the result is more interesting:
In [59]: def func(a, b):
    ...:     sub = arr[arr[:,a] > b]
    ...:     print(a,b,sub)
    ...:     mean = numpy.mean(sub)
    ...:     return mean
    ...:
In [60]: list(map(func,a,b))
0 0 [[4 4 0]]
1 2 [[4 4 0]]
2 0 [[0 2 3]]
Out[60]: [2.6666666666666665, 2.6666666666666665, 1.6666666666666667]
This indexing of arr selects whole rows: in this case the 2nd row twice and the 1st row once.
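Since the question is about vectorizing, here is one possible approach (not from the original answer) that builds all three row masks at once with broadcasting and computes the axis=0 means without map. It is a sketch that assumes the per-pair reduction is a mean, or something else expressible as masked sums; for the more complicated logic the question mentions, the map/loop form may still be needed:

```python
import numpy as np

a = np.array([0, 1, 2])
b = np.array([0, 2, 0])
arr = np.array([[0, 2, 3], [4, 4, 0]])

# arr[:, a] picks column a[j] for each pair j; comparing with b broadcasts
# across rows, so masks[i, j] = arr[i, a[j]] > b[j], shape (n_rows, n_pairs).
masks = arr[:, a] > b

# Sum of the kept rows for each pair: (n_pairs, n_rows) @ (n_rows, n_cols).
sums = masks.T.astype(float) @ arr
counts = masks.sum(axis=0)  # number of rows kept per pair

# Matches np.mean(sub, axis=0) per pair. A pair that keeps no rows gives
# 0/0 = nan (with a RuntimeWarning), as np.mean of an empty selection does.
means = sums / counts[:, None]
print(means)
```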