Matplotlib subplot object being classified/recognized as numpy array. Hence not able to use twinx()

Time:10-12

So I'm trying to draw two plots on the same graph, with one Y axis on either side and a shared X axis. I've done this before and hence knew how to do it (or so I thought). Anyway, I was now trying to implement it inside a function, since I need to make a lot of plots and wanted a more modular solution. When I run the same thing inside a function, it throws the error 'numpy.ndarray' object has no attribute 'twinx'. This is because, for some reason, ax1 is reported as class numpy.ndarray when it should actually be matplotlib.axes._subplots.AxesSubplot.

Please Help.

    def pumped_up_plotting(data, colname1, colname2):
        fig, ax1 = plt.subplots(13, 4, figsize=(5*4, 5*13))
        print("Look here lil bitch: ", type(ax1))
        ax2 = ax1.twinx()
        plt.subplots_adjust(hspace=0.3)

        for i in range(0, 13):
            for j in range(0,4):
                id_ = profiles[i*4 + j]
                samp = data[data["profile_id"] == id_]

                ax1[i, j].plot(samp["time"], samp[colname1], label=colname1, color="blue")
                ax2[i, j].plot(samp["time"], samp[colname2], label=colname2, color="yellow")
                ax1[i, j].set_xlabel("Time")
            
                lines_1, labels_1 = ax1.get_legend_handles_labels()
                lines_2, labels_2 = ax2.get_legend_handles_labels()
            
                lines = lines_1 + lines_2
                labels = labels_1 + labels_2

                ax1[i, j].legend(lines, labels, loc=0)

    pumped_up_plotting(df, "rotor", "motor_work")


CodePudding user response:

ax1 is a numpy array because you are creating a grid of 13 by 4 axes instead of a single axes. You do this by supplying 13 and 4 as the first two arguments to plt.subplots. I'm not sure what you intended those numbers to do, but if you delete them, it should work.

As of now, ax1 is a numpy array with 13 rows and 4 columns containing the individual axes objects.

I'll try to explain what went wrong. To simplify, I'm gonna use a 2 x 2 grid instead of 13 x 4. So fig, ax1 = plt.subplots(2, 2).

Then, ax1 will look like this:

array([[<AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)

If you try to call ax1.twinx(), it won't work because ax1 is actually not an axes object, but the array containing all 4 axes of the grid. So if you wanted to create a twin of the first axes, you would have to call ax1[0,0].twinx(). Since you want to do it for every axes and not just the first one, you can do it inside a nested loop over the rows and columns of the numpy array. Since you are already looping that way, you can just put that line inside your existing loop.
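To see the difference for yourself, here is a minimal sketch comparing the single-axes call with the 2 x 2 grid. The variable names (ax_single, axs, twin_first) and the Agg backend line are my own choices for the illustration, not something from the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# No grid arguments: plt.subplots returns a single Axes, so twinx() is available.
fig_single, ax_single = plt.subplots()
twin_single = ax_single.twinx()

# With a grid: plt.subplots returns a numpy array of Axes objects.
fig_grid, axs = plt.subplots(2, 2)
print(type(axs))   # <class 'numpy.ndarray'>
print(axs.shape)   # (2, 2)

# The array itself has no twinx(); each element of the array does.
print(hasattr(axs, "twinx"))    # False
twin_first = axs[0, 0].twinx()  # twin of the top-left axes only
```

So the error message is accurate: the attribute lookup fails on the array, not on any individual axes.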

The line looks like this: ax2 = ax1[i, j].twinx(). Here, we take the individual axes object from the numpy array by indexing (as you were already doing before) and call twinx on it. This returns a twin axes, which we save as ax2. Note that this is kind of confusing, since ax2 is a single axes object while ax1 is an array containing axes objects. I personally would rename ax1 to axs so it's clear this variable contains multiple axes.

Because ax2 is already a single axis object, we can call the plotting functions directly on it, and don't have to index it.

def pumped_up_plotting(data, colname1, colname2):
    fig, ax1 = plt.subplots(13, 4, figsize=(5*4, 5*13))
    print("Type of ax1 ", type(ax1))
    
    plt.subplots_adjust(hspace=0.3)

    for i in range(0, 13):
        for j in range(0,4):
            ax2 = ax1[i, j].twinx()
            id_ = profiles[i*4 + j]
            samp = data[data["profile_id"] == id_]

            ax1[i, j].plot(samp["time"], samp[colname1], label=colname1, color="blue")
            ax2.plot(samp["time"], samp[colname2], label=colname2, color="yellow")
            ax1[i, j].set_xlabel("Time")
        
            lines_1, labels_1 = ax1[i, j].get_legend_handles_labels()
            lines_2, labels_2 = ax2.get_legend_handles_labels()
        
            lines = lines_1 + lines_2
            labels = labels_1 + labels_2

            ax1[i, j].legend(lines, labels, loc=0)

pumped_up_plotting(df, "rotor", "motor_work")

My way of doing this more clearly would be:

def pumped_up_plotting(data, colname1, colname2):
    fig, axs = plt.subplots(13, 4, figsize=(5*4, 5*13))   
    plt.subplots_adjust(hspace=0.3)

    for i, row in enumerate(axs):
        for j, ax in enumerate(row):
            ax2 = ax.twinx()
            id_ = profiles[i*4 + j]
            samp = data[data["profile_id"] == id_]

            ax.plot(samp["time"], samp[colname1], label=colname1, color="blue")
            ax2.plot(samp["time"], samp[colname2], label=colname2, color="yellow")
            ax.set_xlabel("Time")
        
            lines_1, labels_1 = ax.get_legend_handles_labels()
            lines_2, labels_2 = ax2.get_legend_handles_labels()
        
            lines = lines_1 + lines_2
            labels = labels_1 + labels_2

            ax.legend(lines, labels, loc=0)

pumped_up_plotting(df, "rotor", "motor_work")
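If you want to try the pattern without the original dataset, here is a rough self-contained sketch using made-up data. Everything specific in it is an assumption for illustration: the function name plot_twin_grid, the (time, left, right) tuples, the labels, and the 2 x 2 grid. It also iterates with axs.flat instead of nested loops, and passes squeeze=False so that axs stays a 2D array even for a 1 x 1 grid.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_twin_grid(samples, nrows, ncols):
    """Plot one (time, y_left, y_right) tuple per subplot, each with a twin Y axis."""
    # squeeze=False keeps axs a 2D numpy array regardless of grid size.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False)
    twins = []
    # axs.flat walks the grid row by row, so no manual index arithmetic is needed.
    for ax, (t, y_left, y_right) in zip(axs.flat, samples):
        ax2 = ax.twinx()  # call twinx on the individual axes, not on the array
        ax.plot(t, y_left, label="left", color="blue")
        ax2.plot(t, y_right, label="right", color="orange")
        ax.set_xlabel("Time")

        # Merge the legend entries of both axes into a single legend.
        lines_1, labels_1 = ax.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax.legend(lines_1 + lines_2, labels_1 + labels_2, loc=0)
        twins.append(ax2)
    return fig, axs, twins

# Hypothetical data: four sine/cosine pairs on very different Y scales.
t = np.linspace(0, 10, 50)
samples = [(t, np.sin(t + k), 100 * np.cos(t + k)) for k in range(4)]
fig, axs, twins = plot_twin_grid(samples, 2, 2)
```

Returning the twin axes alongside fig and axs is just a convenience here, in case you want to set the right-hand Y labels afterwards.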