Loop in Matplotlib in a single subplot

Time:04-16

I'm currently working on a Principal Component Analysis and I would like to plot correlation circles: three in my case, because I have three principal components and therefore three pairs of components to plot.

The code works, but I would like to show the results as subplots of a single figure (1 row, 3 columns), because right now I get three consecutive figures.

When I try to initialize the Matplotlib figure with fig, ax = plt.subplots(1,3), I get three separate 1x3 grids of subplots, each with only one circle inside. Instead, I would like a single grid with my three circles in three columns on the same row.

My code:

import numpy as np
import matplotlib.pyplot as plt

pcs = pca.components_
def display_circles(pcs, n_comp, pca, axis_ranks, labels=None, label_rotation=0, lims=None):
    # Initialise the matplotlib figure
    fig, ax = plt.subplots(1,3)

    # For each factorial plane
    for d1, d2 in axis_ranks: 
        if d2 < n_comp:

            # Determine the limits of the chart
            if lims is not None :
                xmin, xmax, ymin, ymax = lims
            elif pcs.shape[1] < 30 :
                xmin, xmax, ymin, ymax = -1, 1, -1, 1
            else :
                xmin, xmax, ymin, ymax = min(pcs[d1,:]), max(pcs[d1,:]), min(pcs[d2,:]), max(pcs[d2,:])

            # Add arrows
            plt.quiver(np.zeros(pcs.shape[1]), np.zeros(pcs.shape[1]), pcs[d1,:], pcs[d2,:], angles='xy', scale_units='xy', scale=1, color="grey")
            
            # Display variable names
            if labels is not None:  
                for i,(x, y) in enumerate(pcs[[d1,d2]].T):
                    if x >= xmin and x <= xmax and y >= ymin and y <= ymax :
                        plt.text(x, y, labels[i], fontsize='10', ha='center', va='center', rotation=label_rotation, color="blue", alpha=0.5)
            
            # Display circle
            circle = plt.Circle((0,0), 1, facecolor='none', edgecolor='b')
            plt.gca().add_artist(circle)

            # Label the axes, with the percentage of variance explained
            plt.xlabel('PC{} ({}%)'.format(d1+1, round(100*pca.explained_variance_ratio_[d1],1)))
            plt.ylabel('PC{} ({}%)'.format(d2+1, round(100*pca.explained_variance_ratio_[d2],1)))

            plt.title("Correlation Circle (PC{} and PC{})".format(d1+1, d2+1))
            plt.show(block=False)

display_circles(pcs, num_components, pca, [(0,1), (1,2), (0,2)], labels = header) 

Thanks for the help!

Answer:

Since you didn't provide any data, I can't run your code, so this answer only explains what you need to change to achieve your goal.

With fig, ax = plt.subplots(1,3) you create one figure containing three different Axes, which are stored in the array ax: ax[0] is the subplot on the left, ax[1] the one in the middle and ax[2] the one on the right. You use them to draw on a specific subplot.
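A minimal sketch of what plt.subplots(1, 3) returns, using a headless backend (an assumption so it runs without a display) and placeholder titles. It also shows why plt.* calls are ambiguous here: pyplot functions act on the current Axes, which after plt.subplots(1, 3) is the last one created, whereas Axes methods such as set_title target exactly the subplot you call them on.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# One figure with one row of three subplots; ax is a 1-D array of three Axes
fig, ax = plt.subplots(1, 3)
print(type(ax).__name__, ax.shape)  # ndarray (3,)

# pyplot functions act on the *current* Axes, which is the last one created here
print(plt.gca() is ax[2])  # True
plt.title("drawn via plt")  # lands on ax[2], the right-hand subplot

# Axes methods target an explicit subplot (note the set_ prefix)
ax[0].set_title("left")
ax[1].set_title("center")
```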

Since you want one circle per pair of components, you need an index to target the correct axis. So change for d1, d2 in axis_ranks: to for k, (d1, d2) in enumerate(axis_ranks):. Now you can use the index k to pick the axis to draw on.
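To see what the enumerate loop produces, here is a small sketch using the same axis_ranks list that the question passes to display_circles; the planes list is only there to collect the values.

```python
# The planes passed to display_circles in the question: (PC1, PC2), (PC2, PC3), (PC1, PC3)
axis_ranks = [(0, 1), (1, 2), (0, 2)]

# enumerate yields (index, item) pairs; each item tuple is unpacked into d1, d2
planes = []
for k, (d1, d2) in enumerate(axis_ranks):
    planes.append((k, d1, d2))
    print(f"ax[{k}] shows PC{d1 + 1} vs PC{d2 + 1}")
```

This prints one line per subplot: PC1 vs PC2 on ax[0], PC2 vs PC3 on ax[1] and PC1 vs PC3 on ax[2]. If you prefer not to index at all, for a, (d1, d2) in zip(ax, axis_ranks): pairs each Axes with its plane directly; with enumerate, axis_ranks must not contain more planes than there are subplots, otherwise ax[k] raises an IndexError.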

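Next, replace plt. with ax[k]. Be careful, though, because several functions have a different name as Axes methods: plt.xlabel becomes ax[k].set_xlabel, plt.ylabel becomes ax[k].set_ylabel, plt.title becomes ax[k].set_title, and plt.gca().add_artist(circle) becomes ax[k].add_artist(circle). plt.Circle can stay as it is, since it only creates the patch. The plt.show(block=False) call is also taken out of the loop, because all three circles now belong to the same figure: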

import numpy as np
import matplotlib.pyplot as plt

pcs = pca.components_

def display_circles(pcs, n_comp, pca, axis_ranks, labels=None, label_rotation=0, lims=None):
    # Initialise ONE figure with one row of three subplots; ax holds the three Axes
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    # For each factorial plane; k is the index of the subplot to draw on
    for k, (d1, d2) in enumerate(axis_ranks):
        if d2 < n_comp:

            # Determine the limits of the chart
            if lims is not None:
                xmin, xmax, ymin, ymax = lims
            elif pcs.shape[1] < 30:
                xmin, xmax, ymin, ymax = -1, 1, -1, 1
            else:
                xmin, xmax, ymin, ymax = min(pcs[d1, :]), max(pcs[d1, :]), min(pcs[d2, :]), max(pcs[d2, :])

            # Add arrows from the origin to each variable's coordinates on this plane
            ax[k].quiver(np.zeros(pcs.shape[1]), np.zeros(pcs.shape[1]), pcs[d1, :], pcs[d2, :],
                         angles='xy', scale_units='xy', scale=1, color="grey")

            # Display variable names (only those inside the chart limits)
            if labels is not None:
                for i, (x, y) in enumerate(pcs[[d1, d2]].T):
                    if xmin <= x <= xmax and ymin <= y <= ymax:
                        ax[k].text(x, y, labels[i], fontsize=10, ha='center', va='center',
                                   rotation=label_rotation, color="blue", alpha=0.5)

            # Display the unit circle
            circle = plt.Circle((0, 0), 1, facecolor='none', edgecolor='b')
            ax[k].add_artist(circle)

            # add_artist does not take part in autoscaling, so apply the computed limits
            # explicitly, and keep a 1:1 aspect ratio so the circle is not drawn as an ellipse
            ax[k].set_xlim(xmin, xmax)
            ax[k].set_ylim(ymin, ymax)
            ax[k].set_aspect('equal')

            # Label the axes, with the percentage of variance explained
            ax[k].set_xlabel('PC{} ({}%)'.format(d1 + 1, round(100 * pca.explained_variance_ratio_[d1], 1)))
            ax[k].set_ylabel('PC{} ({}%)'.format(d2 + 1, round(100 * pca.explained_variance_ratio_[d2], 1)))

            ax[k].set_title("Correlation Circle (PC{} and PC{})".format(d1 + 1, d2 + 1))

    # Show the single figure once, after all three subplots are drawn
    fig.tight_layout()
    plt.show()

display_circles(pcs, num_components, pca, [(0, 1), (1, 2), (0, 2)], labels=header)
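One more detail: matplotlib's documentation for Axes.add_artist notes that it does not update the data limits used for autoscaling (it points to dedicated methods such as add_patch, or to updating dataLim manually). So when the circle is added with ax[k].add_artist(circle), it helps to set the limits yourself, for example ax[k].set_xlim(xmin, xmax) and ax[k].set_ylim(ymin, ymax) with the values the function already computes, plus ax[k].set_aspect('equal') to keep the circle round; alternatively, add it with ax[k].add_patch(circle). The minimal sketch below compares the data limits after each call on two empty Axes, using a headless backend (an assumption so it runs without a display).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Circle

fig, (ax_artist, ax_patch) = plt.subplots(1, 2)

# add_artist: the circle is drawn, but the data limits are left untouched
ax_artist.add_artist(Circle((0, 0), 1, facecolor="none", edgecolor="b"))

# add_patch: the circle's extent is folded into the data limits
ax_patch.add_patch(Circle((0, 0), 1, facecolor="none", edgecolor="b"))

# dataLim is the bounding box that autoscaling works from
artist_counted = bool(np.isfinite(ax_artist.dataLim.intervalx).all())
patch_x_range = ax_patch.dataLim.intervalx

print("circle counted with add_artist:", artist_counted)  # False
print("x data range with add_patch:", patch_x_range)      # about -1 to 1
```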