I have trouble understanding how `mplcursors` cursors work. Let me give an example.
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mplcursors

# IPython/Jupyter magic: show figures in an interactive Qt window
%matplotlib qt5

def random_data(rows_count):
    data = []
    for i in range(rows_count):
        row = {}
        row["x"] = np.random.uniform()
        row["y"] = np.random.uniform()
        if (i % 2 == 0):
            row["type"] = "sith"
            row["color"] = "red"
        else:
            row["type"] = "jedi"
            row["color"] = "blue"
        data.append(row)
    return pd.DataFrame(data)

data_df = random_data(30)
fig, ax = plt.subplots(figsize=(8, 8))
ax = plt.gca()
types = ["jedi", "sith"]
for scat_type in types:
    local_data_df = data_df.loc[data_df["type"] == scat_type]
    scat = ax.scatter(local_data_df["x"],
                      local_data_df["y"],
                      c=local_data_df["color"],
                      label=scat_type)
    cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)

    @cursor.connect("add")
    def on_add(sel):
        annotation = (local_data_df.iloc[sel.index]["type"] +
                      "\n" + str(local_data_df.iloc[sel.index]["x"]) +
                      "\n" + str(local_data_df.iloc[sel.index]["y"]))
        sel.annotation.set(text=annotation)

ax.legend()
plt.title("a battle of Force users")
plt.xlabel("x")
plt.ylabel("y")
plt.xlim(-1, 2)
plt.ylim(-1, 2)
ax.set_aspect('equal', adjustable='box')
plt.show()
```
This code is supposed to generate a DataFrame in which each row has random properties `x` and `y`, a `type` that is `jedi` or `sith`, and a `color` that is `blue` or `red` depending on whether the row is a `jedi` or a `sith`. It then scatterplots the jedis in their color and attaches a cursor to them, scatterplots the siths in their color and attaches another cursor to them, and displays a legend box telling the reader that blue points correspond to `jedi` rows and red ones to `sith` rows.
However, when I hover over the points, the annotations say that all of them are `sith`, and the coordinates look wrong.
I would like to understand why the code does not do what I want.
Just to clarify: I call `.scatter()` once for each type (`jedi` or `sith`) and then try to attach a cursor to each of the plots, because when I called `scatter` on the whole `data_df`, `.legend()` did not display what I want.
I hope your answer will be enough for me to write code that displays the `jedi` and `sith` points with the right annotations and the right legend box.
Answer:
There are a lot of strange things going on.
One confusion is the idea that assigning `local_data_df` inside a `for` loop creates a variable that is local to one iteration of the loop. A `for` loop does not create a new scope in Python: `local_data_df` is just a global variable that gets overwritten in each iteration. Similarly, defining the function `on_add` inside the `for` loop doesn't make it local; `on_add` is also global and gets redefined in each iteration.
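A minimal plain-Python illustration of this scoping behavior, independent of matplotlib (the names `handlers`, `kind` and `selected` are made up for the example):

```python
handlers = []
for kind in ["jedi", "sith"]:
    selected = kind.upper()   # rebinds one module-level name on every iteration

    def on_add():             # 'def' also rebinds one module-level name
        return selected

    handlers.append(on_add)

print(selected)                    # SITH: the name outlives the loop
print(on_add is handlers[-1])      # True: 'on_add' refers to the last definition
print(handlers[0] is handlers[1])  # False: two separate function objects were created
```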
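Another confusion is expecting the connected function to remember the value `local_data_df` had when the cursor was created. Python looks the name up only when the function runs, i.e. when you hover over a point; by then the loop has finished, so `local_data_df` refers to the `sith` rows for both cursors. That is why every annotation says `sith` and the coordinates come from the wrong rows.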
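A sketch of this late lookup in plain Python. Besides storing the data on the artist (the approach used in the rewrite below), two standard ways to capture a value at definition time are a default argument and `functools.partial`; `describe` is a hypothetical stand-in for the annotation callback, not part of mplcursors.

```python
from functools import partial

def describe(kind, index):
    # hypothetical stand-in for an annotation callback
    return f"{kind} point #{index}"

late, with_default, with_partial = [], [], []
for kind in ["jedi", "sith"]:
    # looks 'kind' up when called, i.e. after the loop has finished
    late.append(lambda index: describe(kind, index))
    # default values are evaluated when the lambda is created
    with_default.append(lambda index, kind=kind: describe(kind, index))
    # partial stores the current value of 'kind' inside the partial object
    with_partial.append(partial(describe, kind))

print([f(0) for f in late])          # ['sith point #0', 'sith point #0']
print([f(0) for f in with_default])  # ['jedi point #0', 'sith point #0']
print([f(0) for f in with_partial])  # ['jedi point #0', 'sith point #0']
```

Since mplcursors calls the connected function with a single `sel` argument, either form could in principle bind one data frame per cursor, e.g. `cursor.connect("add", partial(some_callback, local_data_df))` with a hypothetical `some_callback(df, sel)`.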
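Further, note that `sel.index` is not a label in the data frame's index but the position of the hovered point within its scatter plot (0 for the first point passed to `scatter`, 1 for the second, and so on). You can reset the index of the "local df" so that its labels follow the same order as `sel.index`; with `.iloc`, which is purely positional, the lookup also works without the reset.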
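A small pandas sketch (toy data, not the random data from the question) showing why a position such as `sel.index` goes with `.iloc`, or with `.loc` only after `reset_index`:

```python
import pandas as pd

df = pd.DataFrame({"type": ["sith", "jedi", "sith", "jedi"],
                   "x": [0.1, 0.2, 0.3, 0.4]})
jedi = df.loc[df["type"] == "jedi"]   # filtering keeps the original labels
print(jedi.index.tolist())            # [1, 3]

position = 0                          # what sel.index would be for the first jedi point
print(jedi.iloc[position]["x"])       # 0.2: .iloc is positional
print(position in jedi.index)         # False: jedi.loc[0] would raise KeyError

jedi_reset = jedi.reset_index(drop=True)
print(jedi_reset.index.tolist())      # [0, 1]
print(jedi_reset.loc[position]["x"])  # 0.2: labels now coincide with positions
```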
To mimic your local variable, you can attach extra data to the `scat` object. For example, `scat.my_data = local_data_df` adds that data frame as an attribute of the object that holds the scatter element (the `PathCollection` that contains all the information matplotlib needs to draw the scatter points). Although the variable `scat` gets overwritten, there is one `PathCollection` for each call to `ax.scatter` (you can also access them via `ax.collections`). Inside the callback, `sel.artist` is the `PathCollection` of the hovered point, so `sel.artist.my_data` retrieves the matching data frame.
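A headless sketch (Agg backend, toy data) checking that each `ax.scatter` call leaves its own `PathCollection` in `ax.collections`, that the custom attribute survives after `scat` is rebound, and that each collection's offsets follow the row order of the stored data frame, which is the order `sel.index` counts in. The attribute name `my_data` is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [0.1, 0.2, 0.3, 0.4],
                   "y": [0.5, 0.6, 0.7, 0.8],
                   "type": ["sith", "jedi", "sith", "jedi"]})

fig, ax = plt.subplots()
for kind in ["jedi", "sith"]:
    subset = df.loc[df["type"] == kind].reset_index(drop=True)
    scat = ax.scatter(subset["x"], subset["y"], label=kind)
    scat.my_data = subset  # plain Python attribute on the PathCollection

# 'scat' now refers to the sith collection only, but both collections are kept
for collection in ax.collections:
    offsets = np.asarray(collection.get_offsets())
    same_order = np.array_equal(offsets, collection.my_data[["x", "y"]].to_numpy())
    print(collection.get_label(), collection.my_data["type"].tolist(), same_order)
# jedi ['jedi', 'jedi'] True
# sith ['sith', 'sith'] True
```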
Here is a rewrite of your code, trying to stay as close as possible to the original:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mplcursors

def random_data(rows_count):
    df = pd.DataFrame({'x': np.random.uniform(0, 1, rows_count),
                       'y': np.random.uniform(0, 1, rows_count),
                       'type': np.random.choice(['sith', 'jedi'], rows_count)})
    df['color'] = df['type'].replace({'sith': 'red', 'jedi': 'blue'})
    return df

def on_add(sel):
    # sel.artist is the hovered PathCollection; fetch the data stored on it
    local_data_df = sel.artist.my_data
    annotation = (local_data_df.iloc[sel.index]["type"] +
                  "\n" + str(local_data_df.iloc[sel.index]["x"]) +
                  "\n" + str(local_data_df.iloc[sel.index]["y"]))
    sel.annotation.set(text=annotation)

data_df = random_data(30)
fig, ax = plt.subplots(figsize=(8, 8))
types = ["jedi", "sith"]
for scat_type in types:
    # labels become 0, 1, 2, ... like sel.index (with .iloc this is a safeguard)
    local_data_df = data_df.loc[data_df["type"] == scat_type].reset_index(drop=True)
    scat = ax.scatter(local_data_df["x"],
                      local_data_df["y"],
                      c=local_data_df["color"],
                      label=scat_type)
    scat.my_data = local_data_df  # store the data in the scat object
    cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)
    cursor.connect("add", on_add)

ax.legend()
ax.set_title("a battle of Force users")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_xlim(-1, 2)
ax.set_ylim(-1, 2)
ax.set_aspect('equal', adjustable='box')
plt.show()
```
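To check the callback logic without hovering in a GUI, here is a minimal sketch that runs the same lookups as `on_add` on stand-in objects. `FakeAnnotation` and the `SimpleNamespace` objects are hypothetical substitutes for the `Selection` and annotation that mplcursors normally passes; they provide only the attributes `on_add` uses.

```python
from types import SimpleNamespace
import pandas as pd

def on_add(sel):
    # same lookups as the answer's callback
    local_data_df = sel.artist.my_data
    row = local_data_df.iloc[sel.index]
    sel.annotation.set(text=row["type"] + "\n" + str(row["x"]) + "\n" + str(row["y"]))

class FakeAnnotation:
    """Hypothetical stand-in that only records the text passed to set()."""
    def __init__(self):
        self.text = None

    def set(self, text):
        self.text = text

jedi = pd.DataFrame({"x": [0.2, 0.4], "y": [0.6, 0.8], "type": ["jedi", "jedi"]})
artist = SimpleNamespace(my_data=jedi)  # stands in for the PathCollection
sel = SimpleNamespace(artist=artist, index=1, annotation=FakeAnnotation())
on_add(sel)
print(sel.annotation.text.splitlines())  # ['jedi', '0.4', '0.8']
```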