I am trying to use the np.where() function with nested lists. I would like to find the index, in the first layer of the nested list, of the element that satisfies a given condition.
For example, if I have the following code
arr = [[1,1], [2,2],[3,3]]
a = np.where(arr == [2,2])
then ideally I would like the code to return 'a' as 1, since [2,2] is at index 1 of the nested list.
However, I am just getting an empty array back as a result.
Of course, I can make it work easily with an external for loop such as
for n in range(len(arr)):
    if arr[n] == [2,2]:
        a = n
but I would like to implement this simply within the function np.where(write the entire code here).
Is there a way to do this?
CodePudding user response:
The best solution is the one mentioned by @Michael Szczesny, but you can also do this using np.where:
a = np.where(np.array(arr) == [2, 2])[0]
resulted_ind = np.where(np.bincount(a) == 2)[0] # --> [1]
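To see why the two lines above work, here is a small sketch that prints the intermediate values. The names rows, counts and full_match are illustrative and not part of the original answer, and minlength is added only so that counts has one entry per row.

```python
import numpy as np

arr = [[1, 1], [2, 2], [3, 3]]
target = [2, 2]

# Row index of every single element that matches the target element-wise.
rows = np.where(np.array(arr) == target)[0]

# Number of matching elements in each row.
counts = np.bincount(rows, minlength=len(arr))

# A row is a full match only when every one of its elements matched.
full_match = np.where(counts == len(target))[0]
print(rows, counts, full_match)
```

Comparing the count against len(target) instead of a hard-coded 2 keeps the idea usable for rows of other lengths.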
CodePudding user response:
Well, you can write your own function to do so.
You'll need to:
- Find every row equal to what you are looking for
- Get the indices of the found rows (you can use where)
numpy comparison
You can use a comparison operator to see whether each element satisfies a condition. For example:
import numpy as np

np_arr = np.array([1, 2, 3, 4, 5])
print(np_arr < 3)
This returns a boolean array in which each element is True or False, depending on whether the condition is satisfied for that element:
[ True True False False False]
For a 2D array you'll get a 2D boolean array:
to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print(np_arr == to_find)
The result is:
[[False False]
 [ True  True]
 [False False]
 [ True  True]]
Now we are looking for rows in which all values are True, so we can use the all method of ndarray and tell it which axis to test along. The axis can be 0 (down each column), 1 (across each row), or omitted (the whole array at once). We want one result per row, so we use axis=1:
to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print((np_arr == to_find).all(axis=1))
The result is:
[False True False True]
Get indices of the True values
Finally, you are looking for the indices where the values are True:
np.where((np_arr == to_find).all(axis=1))
The result would be:
(array([1, 3]),)
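The steps above can be wrapped into one small helper. This is only a sketch: the name find_rows is hypothetical, and it assumes the input is a rectangular 2D structure whose rows have the same length as the row being searched for.

```python
import numpy as np

def find_rows(array_2d, row):
    """Return the indices of the rows of array_2d that equal row."""
    array_2d = np.asarray(array_2d)
    row = np.asarray(row)
    # Element-wise comparison, then require every element of a row to match.
    matches = (array_2d == row).all(axis=1)
    # np.where returns a tuple with one array per dimension; take the first.
    return np.where(matches)[0]

indices = find_rows([[1, 1], [2, 2], [3, 3], [2, 2]], [2, 2])
print(indices)
```

When no row matches, the helper returns an empty array rather than raising an error.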
CodePudding user response:
numpy runs in Python, so you can use both basic Python lists and numpy arrays (which are more like MATLAB matrices).
A list of lists:
In [43]: alist = [[1,1], [2,2],[3,3]]
A list has an index method, which tests against each element of the list (the elements here are 2-element lists):
In [44]: alist.index([2,2])
Out[44]: 1
In [45]: alist.index([2,3])
Traceback (most recent call last):
Input In [45] in <cell line: 1>
alist.index([2,3])
ValueError: [2, 3] is not in list
alist==[2,2] returns False, because the list as a whole is not equal to the [2,2] list. A list comparison yields a single boolean, not an element-wise result.
If we make an array from that list:
In [46]: arr = np.array(alist)
In [47]: arr
Out[47]:
array([[1, 1],
       [2, 2],
       [3, 3]])
we can do an == test, but it compares the numeric elements one by one.
In [48]: arr == np.array([2,2])
Out[48]:
array([[False, False],
       [ True,  True],
       [False, False]])
Underlying this comparison is the concept of broadcasting, which allows it to compare a (3,2) array with a (2,) array (a 2d with a 1d). Here it's trivial, but it can be much more complicated.
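As a rough illustration of what broadcasting does in this case, the sketch below expands the 1d array to the 2d shape explicitly and checks that the comparison gives the same result. The names row, stretched and same are made up for this example.

```python
import numpy as np

arr = np.array([[1, 1], [2, 2], [3, 3]])   # shape (3, 2)
row = np.array([2, 2])                     # shape (2,)

# The (2,) array is treated as (1, 2) and repeated along the first axis.
stretched = np.broadcast_to(row, arr.shape)

# Comparing against the explicit expansion matches the broadcast comparison.
same = np.array_equal(arr == row, arr == stretched)
print(stretched.shape, same)
```

numpy does not actually copy the data when broadcasting; np.broadcast_to returns a read-only view, which is why this stays cheap for large arrays.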
To find rows where all values are True, use:
In [50]: (arr == np.array([2,2])).all(axis=1)
Out[50]: array([False, True, False])
and where finds the True values in that array (the result is a tuple containing 1 array):
In [51]: np.where(_)
Out[51]: (array([1]),)
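Since the question asks for a plain 1 rather than a tuple of arrays, one possible way to unpack the result is sketched below. np.flatnonzero is a swapped-in alternative to np.where for 1d masks, and returning None when nothing matches is an assumption about the desired behaviour.

```python
import numpy as np

arr = np.array([[1, 1], [2, 2], [3, 3]])
mask = (arr == np.array([2, 2])).all(axis=1)

# For a 1d mask this gives the same indices as np.where(mask)[0].
hits = np.flatnonzero(mask)

# Take the first match as a Python int, or None if no row matched.
a = int(hits[0]) if hits.size else None
print(a)
```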
In Octave the equivalent is (note that Octave indices are 1-based, so its row 2 corresponds to index 1 in numpy):
>> arr = [[1,1];[2,2];[3,3]]
arr =
1 1
2 2
3 3
>> all(arr == [2,2],2)
ans =
0
1
0
>> find(all(arr == [2,2],2))
ans = 2