Struggling with NumPy's where()

Time:10-25

I somehow got involved with primitive AI and came across this code, which I am having a hard time understanding. I read some sites, but none seemed to have the answer I am looking for. :( Could anyone explain the np.where() function in this scenario? My guess is that this line of code makes child_pos an empty 2-D array if curr_node.get_curr_child() equals 0, but I am not sure... I'm grateful for every response.

The code in question is:

child_pos = np.where(np.asarray(curr_node.get_curr_child()) == 0)[0][0]

Answer:

Setting your code aside for a moment: when np.where is called with only a condition, it returns the positions of the elements for which that condition is True, as a tuple containing one index array per dimension.

For example, suppose we have:

import numpy as np

matrix = np.array([[1., 1., 1.],
                   [1., 0., 1.],
                   [1., 1., 0.]])

Running np.where(matrix == 0) returns:

(array([1, 2], dtype=int64), 
 array([1, 2], dtype=int64))

This gives you the row and column positions of every 0 in the original 2-D array. The first array holds the row indices and the second holds the column indices, so reading them pairwise, the zeros are at (1, 1) and (2, 2).
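To make the row/column pairing concrete, here is a small runnable sketch using the same 3×3 matrix. It zips the two index arrays into (row, col) coordinates, uses the returned tuple to index the matrix again, and compares the result with np.argwhere, which returns the same positions as a single array of shape (number of matches, ndim). The variable names are only illustrative.

```python
import numpy as np

# Same matrix as in the example
matrix = np.array([[1., 1., 1.],
                   [1., 0., 1.],
                   [1., 1., 0.]])

rows, cols = np.where(matrix == 0)   # one index array per axis

# Pair the row and column indices into (row, col) coordinates
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)                        # [(1, 1), (2, 2)]

# Indexing with the same arrays pulls out exactly the matching values
print(matrix[rows, cols])            # [0. 0.]

# np.argwhere reports the same positions, one row per match
print(np.argwhere(matrix == 0).tolist())  # [[1, 1], [2, 2]]
```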

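The same logic applies in any number of dimensions: the tuple always contains one index array per dimension of the input, so a 1-D input gives a tuple with a single array.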
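A quick sketch of the dimension point, using made-up 1-D and 3-D arrays (vec and cube are illustrative names, not taken from the question). The length of the returned tuple always matches the number of dimensions of the input.

```python
import numpy as np

vec = np.array([3, 0, 5, 0])               # 1-D input
cube = np.zeros((2, 2, 2))                 # 3-D input...
cube[0, 1, 1] = 7                          # ...with a single 7 in it

idx_1d = np.where(vec == 0)
idx_3d = np.where(cube == 7)

# The tuple always has one index array per dimension of the input
print(len(idx_1d), len(idx_3d))            # 1 3

# 1-D: a single array of positions
print(idx_1d[0].tolist())                  # [1, 3]

# 3-D: the k-th entry of each array gives the k-th match's coordinates
print([axis.tolist() for axis in idx_3d])  # [[0], [1], [1]]
```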

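Returning to your code: you convert the result of get_curr_child() into a NumPy array, compare it with 0, and pass the resulting boolean array to np.where. The first [0] selects the index array for the first dimension, and the second [0] takes the first entry of that array. Assuming get_curr_child() returns a flat (1-D) sequence, child_pos is therefore a single integer: the position of the first element equal to 0. It is not an empty 2-D array. If no element equals 0, the index array is empty and the second [0] raises an IndexError.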
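The sketch below reproduces the line from the question step by step. The question doesn't show what get_curr_child() returns, so FakeNode and first_zero_position are hypothetical stand-ins that assume it returns a flat list of numbers.

```python
import numpy as np

class FakeNode:
    """Hypothetical stand-in for the question's node object.

    Assumption: get_curr_child() returns a flat list of numbers.
    """
    def __init__(self, child):
        self._child = child

    def get_curr_child(self):
        return self._child


def first_zero_position(node):
    arr = np.asarray(node.get_curr_child())   # list -> 1-D ndarray
    matches = np.where(arr == 0)              # tuple holding one index array
    return matches[0][0]                      # first axis, first match


curr_node = FakeNode([4, 7, 0, 2, 0])

# The original one-liner and the step-by-step version agree
child_pos = np.where(np.asarray(curr_node.get_curr_child()) == 0)[0][0]
print(child_pos)                              # 2
print(first_zero_position(curr_node))         # 2

# With no zero at all, the index array is empty and [0][0] fails
try:
    first_zero_position(FakeNode([1, 2, 3]))
except IndexError as err:
    print("no zero found:", err)
```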
