What is incorrect about the way I'm dynamically setting the number of subplots in matplotlib

Time:03-06

As part of a larger module, I'd like to save a series of images using matplotlib. There will only be a small number of images (1 to 3), so I'm trying to set the number of columns passed to plt.subplot dynamically. When I do this for three images, everything is fine: I get a row of 3 images.

However, when I do the same with only two images, I get a single image on the right-hand side of the figure.

I've extracted the key method from the original code and written an example (see below) that passes it an image array in exactly the same format as the actual code.

Why am I able to get a row of 3 images, but not of 2 images?

import matplotlib
matplotlib.use('Agg') # non-interactive backend (suppresses plot); select it before importing pyplot
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image, ImageOps

def generate_images(test_input, path_filename):

    plt.figure(figsize=(15, 6))

    # This configuration works - Gives 3 images
    #display_list = [test_input[0], test_input[0], test_input[0]]
    #title = ['Test Image 1', 'Test Image 2', 'Test Image 3']

    # This configuration does not work - Only gives 1 image on right side
    # of subplot
    display_list = [test_input[0], test_input[0]]
    title = ['Test Image 1', 'Test Image 2']

    for i in range(len(title)):

        # Here is where I tried to dynamically create my subplot dimensions
        plt.subplot(1, len(title), i + 1)
        plt.title(title[i])


        # Getting the pixel values in the [0, 1] range to plot.
        plt.imshow(display_list[i] * 0.5 + 0.5, cmap=plt.get_cmap('gray'))
        plt.axis('off')
        plt.tight_layout()

    plt.savefig(path_filename, dpi=200)
    plt.close()

    print()
    print("Image Shape:",test_input.shape)
    print("len(title):",len(title))


# Get Sample Image 
file_path = <path to input image>
im = Image.open(file_path)

# Rescale and change from RGB to grayscale 
im2 = ImageOps.grayscale(im)
im2 = im2.resize((256,256))
num_array = np.asarray(im2)

# Convert to an array of images
num_array = num_array[np.newaxis,:,:,np.newaxis]

out_path = <path and filename of output image>
generate_images(num_array,out_path)
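The example depends on an image file whose path is elided above. As a hedged, self-contained stand-in, the sketch below builds a random RGB image in memory (an assumption, not the original input) and applies the same preprocessing, so the shape of the array passed to generate_images can be checked: a batch axis in front and a single channel axis at the end.

```python
import numpy as np
from PIL import Image, ImageOps

# Hypothetical stand-in for the file at <path to input image>: a random RGB image.
rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(300, 400, 3), dtype=np.uint8)
im = Image.fromarray(rgb)  # uint8 array of shape (H, W, 3) -> RGB image

# Same preprocessing as in the question: grayscale, then resize to 256x256.
im2 = ImageOps.grayscale(im)
im2 = im2.resize((256, 256))
num_array = np.asarray(im2)  # shape (256, 256), dtype uint8

# Add a leading batch axis and a trailing channel axis.
num_array = num_array[np.newaxis, :, :, np.newaxis]

print(num_array.shape)     # (1, 256, 256, 1)
print(num_array[0].shape)  # (256, 256, 1)
print(num_array.dtype)     # uint8
```

Each element test_input[0] handed to plt.imshow is therefore a (256, 256, 1) array, which matplotlib treats as a single-channel image and maps through the colormap.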

Answer:

Both images are actually plotted; for some reason, the problem arises from calling plt.tight_layout() inside the loop.

Try calling this function once, after all axes have been added.

Note that the call to plt.tight_layout() now comes after the for loop; this should solve your problem.
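The screenshot of the corrected code did not survive. A minimal sketch of what the answer describes, assuming the two-image configuration from the question and an all-zero dummy input, moves plt.tight_layout() after the loop. It also selects the Agg backend before importing pyplot and returns the number of axes on the figure; that return value, the dummy input, and the temporary output path are additions for illustration only.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; select it before importing pyplot
import matplotlib.pyplot as plt
import numpy as np


def generate_images(test_input, path_filename):
    plt.figure(figsize=(15, 6))

    display_list = [test_input[0], test_input[0]]
    title = ["Test Image 1", "Test Image 2"]

    for i in range(len(title)):
        # One row, len(title) columns; subplot indices are 1-based.
        plt.subplot(1, len(title), i + 1)
        plt.title(title[i])
        plt.imshow(display_list[i] * 0.5 + 0.5, cmap="gray")
        plt.axis("off")

    # Lay out the figure once, after every axes has been added.
    plt.tight_layout()
    n_axes = len(plt.gcf().axes)  # illustration only: count the axes that were drawn
    plt.savefig(path_filename, dpi=200)
    plt.close()
    return n_axes


# Usage with a dummy input in the same (1, 256, 256, 1) format as the question.
out_path = os.path.join(tempfile.mkdtemp(), "two_images.png")
dummy_input = np.zeros((1, 256, 256, 1), dtype=np.uint8)
n_axes = generate_images(dummy_input, out_path)
print(n_axes)  # 2
```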

For more information about this function, see the matplotlib documentation for matplotlib.pyplot.tight_layout.
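A related option, not part of the original answer, is to create all the axes up front with plt.subplots, so the layout is only ever computed for the complete grid. The helper name save_image_row is hypothetical; squeeze=False keeps the returned axes array two-dimensional, so the same loop works for one, two, or three images.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np


def save_image_row(images, titles, path_filename):
    """Save len(images) images side by side in one row (hypothetical helper)."""
    # squeeze=False always returns a 2-D array of Axes, even for a single image.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 6), squeeze=False)
    for ax, img, title in zip(axes[0], images, titles):
        ax.imshow(np.squeeze(img), cmap="gray")  # (256, 256, 1) -> (256, 256)
        ax.set_title(title)
        ax.axis("off")
    # All axes already exist here, so tight_layout sees the full grid.
    fig.tight_layout()
    fig.savefig(path_filename, dpi=200)
    n_axes = len(fig.axes)
    plt.close(fig)
    return n_axes


# Usage: one, two, and three images with the same code path.
out_dir = tempfile.mkdtemp()
image = np.zeros((256, 256, 1), dtype=np.uint8)
counts = []
for n in (1, 2, 3):
    path = os.path.join(out_dir, f"row_{n}.png")
    titles = [f"Test Image {k + 1}" for k in range(n)]
    counts.append(save_image_row([image] * n, titles, path))
print(counts)  # [1, 2, 3]
```

Working with explicit Figure and Axes objects also avoids relying on pyplot's implicit "current figure" state, which is one less thing to track in a larger module.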