So I'm trying to plot two plots on the same graph, with one y axis on either side, sharing the same x axis. I've done this before and hence knew how to do it (or so I thought). Now I'm trying to put it inside a function, since I need to make a lot of plots and wanted a more modular solution. When I run the same thing inside a function, it throws the error `'numpy.ndarray' object has no attribute 'twinx'`. For some reason `ax1` is of class `numpy.ndarray`, when it should actually be `matplotlib.axes._subplots.AxesSubplot`.
Please help.
```python
import matplotlib.pyplot as plt

def pumped_up_plotting(data, colname1, colname2):
    fig, ax1 = plt.subplots(13, 4, figsize=(5*4, 5*13))
    print("Type of ax1:", type(ax1))
    ax2 = ax1.twinx()
    plt.subplots_adjust(hspace=0.3)
    for i in range(0, 13):
        for j in range(0, 4):
            id_ = profiles[i*4 + j]
            samp = data[data["profile_id"] == id_]
            ax1[i, j].plot(samp["time"], samp[colname1], label=colname1, color="blue")
            ax2[i, j].plot(samp["time"], samp[colname2], label=colname2, color="yellow")
            ax1[i, j].set_xlabel("Time")
            lines_1, labels_1 = ax1.get_legend_handles_labels()
            lines_2, labels_2 = ax2.get_legend_handles_labels()
            lines = lines_1 + lines_2
            labels = labels_1 + labels_2
            ax1[i, j].legend(lines, labels, loc=0)

pumped_up_plotting(df, "rotor", "motor_work")
```
Answer:
`ax1` is a numpy array because you are creating a grid of 13 by 4 axes instead of a single Axes. You do this by passing 13 and 4 as the first two arguments (`nrows` and `ncols`) to `plt.subplots`. I'm not sure what you intended those numbers to do, but if you only need one plot, deleting them makes `ax1.twinx()` work.

As of now, `ax1` is a numpy array with 13 rows and 4 columns containing the individual Axes objects.
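You can confirm the shape with a short check. This is a minimal sketch; the `Agg` backend is only chosen so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# nrows=13, ncols=4: the result is a 2-D numpy array of Axes, not an Axes
fig, ax1 = plt.subplots(13, 4)
print(type(ax1))   # <class 'numpy.ndarray'>
print(ax1.shape)   # (13, 4) -> 13 rows, 4 columns

# Without grid arguments you get a single Axes, which does have twinx()
fig_single, ax = plt.subplots()
print(hasattr(ax, "twinx"), hasattr(ax1, "twinx"))  # True False
```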
I'll try to explain what went wrong. To keep it simple, I'll use a 2 x 2 grid instead of 13 x 4, so `fig, ax1 = plt.subplots(2, 2)`.
Then `ax1` will look like this:

```
array([[<AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
```
If you try to call `ax1.twinx()`, it won't work, because `ax1` is not an Axes but the array containing all 4 Axes of the grid.
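To see the failure in isolation, here is a minimal sketch with the same 2 x 2 grid (again using the `Agg` backend only so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, just for this check
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots(2, 2)

try:
    ax1.twinx()  # ax1 is the whole 2 x 2 ndarray, not a single Axes
except AttributeError as err:
    message = str(err)
    print(message)  # the same error as in the question
```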
So to create a twin of the first Axes, you would call `ax1[0, 0].twinx()`. Since you want a twin for every Axes and not just the first one, you can do it inside a nested loop over the rows and columns of the numpy array. You already have that loop, so you can just put the line inside it.
Inside the loop, the line looks like this:

```python
ax2 = ax1[i, j].twinx()
```
Here we take the individual Axes out of the numpy array by indexing (as you were already doing) and call `twinx` on it. This returns a twin Axes, which we save as `ax2`. Note that this is somewhat confusing, since `ax2` is a single Axes object while `ax1` is an array containing Axes objects. I would personally rename `ax1` to `axs` so it's clear that this variable holds multiple axes.
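A small check of what `twinx` gives back inside the loop, as a sketch on a 2 x 2 grid using the suggested `axs` name: each call returns one new Axes that shares its x axis with the cell it was created from. `get_shared_x_axes().joined(...)` is used here only to verify that sharing.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, just for this check
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

fig, axs = plt.subplots(2, 2)
twins = []
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        ax2 = axs[i, j].twinx()  # index first, then twin that single Axes
        twins.append(ax2)
        # ax2 is a single Axes whose x axis is shared with axs[i, j]
        print(i, j, isinstance(ax2, Axes), axs[i, j].get_shared_x_axes().joined(axs[i, j], ax2))

print(len(fig.axes))  # 8: the 4 grid Axes plus their 4 twins
```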
Because `ax2` is already a single Axes object, we can call the plotting functions directly on it and don't have to index it.
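The legend step is the one part of the loop not spelled out above: each Axes only reports its own lines, so the handles from both are concatenated before calling `legend`. A minimal sketch for a single pair, with made-up data and the column names from the question reused as labels:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, just for this sketch
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data standing in for one profile's "time", colname1 and colname2
t = np.linspace(0, 10, 50)
y1, y2 = np.sin(t), 100 * np.cos(t)

fig, ax = plt.subplots()
ax2 = ax.twinx()
ax.plot(t, y1, label="rotor", color="blue")
ax2.plot(t, y2, label="motor_work", color="yellow")

# Each Axes only knows its own lines, so gather both and merge them
lines_1, labels_1 = ax.get_legend_handles_labels()
lines_2, labels_2 = ax2.get_legend_handles_labels()
legend = ax.legend(lines_1 + lines_2, labels_1 + labels_2, loc=0)

print([text.get_text() for text in legend.get_texts()])  # ['rotor', 'motor_work']
```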
```python
import matplotlib.pyplot as plt

def pumped_up_plotting(data, colname1, colname2):
    fig, ax1 = plt.subplots(13, 4, figsize=(5*4, 5*13))
    print("Type of ax1:", type(ax1))
    plt.subplots_adjust(hspace=0.3)
    for i in range(0, 13):
        for j in range(0, 4):
            ax2 = ax1[i, j].twinx()  # twin of this one cell only
            id_ = profiles[i*4 + j]  # profiles comes from outside the function, as in the question
            samp = data[data["profile_id"] == id_]
            ax1[i, j].plot(samp["time"], samp[colname1], label=colname1, color="blue")
            ax2.plot(samp["time"], samp[colname2], label=colname2, color="yellow")
            ax1[i, j].set_xlabel("Time")
            # ax1 is still the whole array, so index it before asking for legend handles
            lines_1, labels_1 = ax1[i, j].get_legend_handles_labels()
            lines_2, labels_2 = ax2.get_legend_handles_labels()
            lines = lines_1 + lines_2
            labels = labels_1 + labels_2
            ax1[i, j].legend(lines, labels, loc=0)

pumped_up_plotting(df, "rotor", "motor_work")
```
In my opinion, a clearer way to write this would be:
```python
import matplotlib.pyplot as plt

def pumped_up_plotting(data, colname1, colname2):
    fig, axs = plt.subplots(13, 4, figsize=(5*4, 5*13))
    plt.subplots_adjust(hspace=0.3)
    for i, row in enumerate(axs):     # one row of 4 Axes at a time
        for j, ax in enumerate(row):  # ax is a single Axes
            ax2 = ax.twinx()
            id_ = profiles[i*4 + j]   # profiles comes from outside the function, as in the question
            samp = data[data["profile_id"] == id_]
            ax.plot(samp["time"], samp[colname1], label=colname1, color="blue")
            ax2.plot(samp["time"], samp[colname2], label=colname2, color="yellow")
            ax.set_xlabel("Time")
            # merge the legend entries of both y axes into one legend
            lines_1, labels_1 = ax.get_legend_handles_labels()
            lines_2, labels_2 = ax2.get_legend_handles_labels()
            lines = lines_1 + lines_2
            labels = labels_1 + labels_2
            ax.legend(lines, labels, loc=0)

pumped_up_plotting(df, "rotor", "motor_work")
```
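If you'd rather not rely on the global `profiles` and the `i*4 + j` arithmetic, one possible variant (a sketch, not the only way) passes the profile ids in and walks the grid with numpy's `axs.flat`. The `nrows`/`ncols` parameters and the toy DataFrame below are hypothetical, made up only so the example runs on its own:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, just for this sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def pumped_up_plotting(data, profiles, colname1, colname2, nrows, ncols):
    """Plot colname1 (left y axis) and colname2 (right y axis) against time,
    one subplot per profile id."""
    fig, axs = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows),
                            squeeze=False)  # always a 2-D array, even for 1 x 1
    fig.subplots_adjust(hspace=0.3)
    # axs.flat walks the grid row by row, the same order as profiles[i*ncols + j]
    for ax, id_ in zip(axs.flat, profiles):
        samp = data[data["profile_id"] == id_]
        ax2 = ax.twinx()
        ax.plot(samp["time"], samp[colname1], label=colname1, color="blue")
        ax2.plot(samp["time"], samp[colname2], label=colname2, color="yellow")
        ax.set_xlabel("Time")
        lines_1, labels_1 = ax.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax.legend(lines_1 + lines_2, labels_1 + labels_2, loc=0)
    return fig, axs


# Hypothetical toy data: 4 profiles, 20 samples each
rng = np.random.default_rng(0)
profiles = [1, 2, 3, 4]
df = pd.DataFrame({
    "profile_id": np.repeat(profiles, 20),
    "time": np.tile(np.arange(20), len(profiles)),
    "rotor": rng.normal(size=80),
    "motor_work": rng.normal(size=80),
})
fig, axs = pumped_up_plotting(df, profiles, "rotor", "motor_work", nrows=2, ncols=2)
```

`zip` stops at the shorter of its two inputs, so if there are fewer profiles than grid cells the leftover subplots simply stay empty.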