I don't understand how the second bracket works

Time:04-18

This piece of code plots a series of data points, coloring each one by the class it belongs to. X_train is an array of shape (115, 2) and Y_train is an array of shape (115,) holding the corresponding class labels. My question is: what exactly does [Y_train == i] do?

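import matplotlib.pyplot as plt

# X_train, Y_train and iris (e.g. from sklearn.datasets.load_iris) are assumed to be defined earlier.
colors = ["red", "greenyellow", "blue"]
for i in range(len(colors)):
    xs = X_train[:, 0][Y_train == i]   # column 0 of the samples in class i
    ys = X_train[:, 1][Y_train == i]   # column 1 of the samples in class i
    plt.scatter(xs, ys, c=colors[i])

plt.legend(iris.target_names)
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")
plt.show()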

User response:

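In Python, bool is a subclass of int: True behaves like 1 and False like 0.

When Y_train is a single Python value, Y_train == i evaluates to either False or True, which can then be used to access index 0 or 1 respectively. (If Y_train is a NumPy array, as in the question, the comparison is element-wise and produces an array of booleans instead.)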

>>> a = ['this string is at index 0', 'this string is at index 1']
>>> a[True]
'this string is at index 1'
>>> a[False]
'this string is at index 0'
>>> a[1 + 2 == 3]  # 1 + 2 == 3 is True
'this string is at index 1'
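Note that this indexing trick relies on the comparison producing one Python bool. A small sketch contrasting that case with a NumPy comparison (the array values here are made up for illustration):

```python
import numpy as np

# bool is a subclass of int, so a single True/False can index position 1/0.
assert issubclass(bool, int)
a = ["index 0", "index 1"]
label = 2
single = a[label == 2]        # scalar comparison -> True -> a[1]

# Comparing a NumPy array with a scalar is element-wise instead:
# the result is an array with one bool per element, not a single bool.
labels = np.array([0, 2, 1, 2])
mask = labels == 2
print(single)
print(mask)
```

Because Y_train in the question is an array, the second behaviour is the one that applies there.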

User response:

When you perform a comparison in NumPy, such as Y_train == i, the result is a boolean mask: an array containing True for every entry whose value matches i, and False for every other entry.

So, for example, with a simple array like:

import numpy as np
y = np.array([1,2,1,3])

If you look at y == 1, the result is:

array([True, False, True, False])
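One way to confirm that y == 1 is element-wise is to build the same mask with an explicit loop. This sketch does that, then counts the matches (True counts as 1) and locates them with np.flatnonzero, which returns the indices of the True entries:

```python
import numpy as np

y = np.array([1, 2, 1, 3])

# Vectorised comparison: one boolean per element of y.
mask = y == 1

# The same mask built by hand, element by element.
manual = np.array([value == 1 for value in y])

same = np.array_equal(mask, manual)
count = mask.sum()                 # True counts as 1, so this counts the matches
positions = np.flatnonzero(mask)   # indices where the mask is True
print(mask, same, count, positions)
```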

Now take a simple two-dimensional array:

x = np.array([[1,2],[3,4],[5,6],[7,8]])
x
Out[10]: 
array([[1, 2],
       [3, 4],
       [5, 6],
       [7, 8]])

You are first selecting one column at a time, for example column 0:

x[:, 0]
Out[11]: array([1, 3, 5, 7])
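x[:, 0] reads as "every row, column 0". A brief sketch showing both columns and their shapes: each slice is 1-D with one value per row, which is why it lines up position by position with a per-row mask such as y == 1:

```python
import numpy as np

x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])

first_col = x[:, 0]     # all rows, column 0
second_col = x[:, 1]    # all rows, column 1

# Each column slice is 1-D with one value per row of x,
# so it has the same length as a per-row label array such as y.
print(first_col, first_col.shape)
print(second_col, second_col.shape)
```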

Then you apply the boolean mask, which keeps only the values in that column whose position is True in the y == 1 mask:

x[:, 0][y == 1]
Out[14]: array([1, 5])

So the above has exactly the same result as:

x[:, 0][[True, False, True, False]]
Out[16]: array([1, 5])
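Tying this back to the question's loop, here is a minimal sketch with a small made-up X_train and Y_train (hypothetical values, not the real iris split). For each class i it collects what X_train[:, 0][Y_train == i] selects and checks it against the single-bracket form X_train[Y_train == i, 0], standard NumPy indexing that filters the rows and picks the column in one step:

```python
import numpy as np

# Hypothetical stand-ins for the real data: 6 samples with 2 features,
# and one class label (0, 1 or 2) per sample.
X_train = np.array([[5.1, 3.5],
                    [7.0, 3.2],
                    [6.3, 3.3],
                    [4.9, 3.0],
                    [6.4, 3.2],
                    [5.8, 2.7]])
Y_train = np.array([0, 1, 2, 0, 1, 2])

colors = ["red", "greenyellow", "blue"]
selected = {}
for i in range(len(colors)):
    mask = Y_train == i                  # True for the samples in class i
    xs = X_train[:, 0][mask]             # chained form, as in the question
    ys = X_train[:, 1][mask]
    # Single-bracket form: filter rows and pick the column in one step.
    assert np.array_equal(xs, X_train[mask, 0])
    assert np.array_equal(ys, X_train[mask, 1])
    selected[colors[i]] = (xs, ys)

for color, (xs, ys) in selected.items():
    print(color, xs, ys)
```

Each (xs, ys) pair is what one plt.scatter call receives, so every sample is drawn once, in the color of its class. Writing X_train[Y_train == i] instead would keep whole rows (both columns) of class i.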