As part of a larger module, I'd like to save a series of images using matplotlib. There will only be a small number of images (1-3), so I'm trying to set the number of columns passed to plt.subplot dynamically. When I do this for three images, everything is fine: I get a row of 3 images.
However, when I try the same thing with only two images, I get a single image on the right-hand side of the figure.
I've extracted the key method from the original code and written an example (below) that feeds an image array to that method in exactly the same format as the actual code.
Why am I able to get a row of 3 images, but not of 2 images?
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: suppresses the plot window
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps

def generate_images(test_input, path_filename):
    plt.figure(figsize=(15, 6))

    # This configuration works - gives 3 images
    #display_list = [test_input[0], test_input[0], test_input[0]]
    #title = ['Test Image 1', 'Test Image 2', 'Test Image 3']

    # This configuration does not work - only gives 1 image on the right side
    # of the figure
    display_list = [test_input[0], test_input[0]]
    title = ['Test Image 1', 'Test Image 2']

    for i in range(len(title)):
        # Here is where I tried to dynamically create my subplot dimensions
        plt.subplot(1, len(title), i + 1)
        plt.title(title[i])
        # Getting the pixel values in the [0, 1] range to plot
        # (assumes inputs in [-1, 1]).
        plt.imshow(display_list[i] * 0.5 + 0.5, cmap=plt.get_cmap('gray'))
        plt.axis('off')
        plt.tight_layout()

    plt.savefig(path_filename, dpi=200)
    plt.close()

    print()
    print("Image Shape:", test_input.shape)
    print("len(title):", len(title))

# Get sample image
file_path = <path to input image>
im = Image.open(file_path)

# Rescale and change from RGB to grayscale
im2 = ImageOps.grayscale(im)
im2 = im2.resize((256, 256))
num_array = np.asarray(im2)

# Convert to an array of images with shape (1, 256, 256, 1)
num_array = num_array[np.newaxis, :, :, np.newaxis]

out_path = <path and filename of output image>
generate_images(num_array, out_path)
Answer:
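Both images are plotted; the problem comes from where plt.tight_layout() is called. In your code it is called inside the for loop, i.e. before all of the axes have been added. tight_layout() adjusts the subplot parameters based on the axes that exist at the moment it is called, so call it once, after all axes have been added. Moving the call to after the for loop should solve your problem.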
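A minimal sketch of the corrected function, assuming the same input layout as in the question (a batch array of shape (1, H, W, 1) with values in [-1, 1]). The only change that matters is that plt.tight_layout() now runs once, after the loop. The n_images parameter, the returned axes count and the synthetic input are hypothetical additions so the example runs without an input image and can be checked; np.squeeze drops the trailing channel axis so imshow always receives a 2-D array.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np


def generate_images(test_input, path_filename, n_images=2):
    """Save n_images copies of test_input[0] side by side in one row.

    n_images is a hypothetical parameter standing in for len(title) in the
    question; the return value (number of axes saved) exists only for checking.
    """
    fig = plt.figure(figsize=(15, 6))
    display_list = [test_input[0]] * n_images
    title = ['Test Image {}'.format(i + 1) for i in range(n_images)]

    for i in range(len(title)):
        plt.subplot(1, len(title), i + 1)
        plt.title(title[i])
        # Rescale to [0, 1] (assumes inputs in [-1, 1]) and drop the trailing
        # channel axis so imshow receives a 2-D array.
        image = np.squeeze(display_list[i], axis=-1) * 0.5 + 0.5
        plt.imshow(image, cmap='gray')
        plt.axis('off')

    # Called once, after the loop, so the layout accounts for every axes.
    plt.tight_layout()
    plt.savefig(path_filename, dpi=200)
    n_axes = len(fig.axes)
    plt.close(fig)
    return n_axes


# Usage: a synthetic batch holding one 256x256 single-channel image in [-1, 1].
batch = np.random.uniform(-1.0, 1.0, size=(1, 256, 256, 1))
out_path = os.path.join(tempfile.gettempdir(), 'two_images.png')
print(generate_images(batch, out_path, n_images=2))  # → 2
```

With the call moved out of the loop, the same code produces a row of two images as well as a row of three.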
You can find more information about this function in the matplotlib documentation: matplotlib.pyplot.tight_layout
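If you prefer matplotlib's object-oriented interface, one variant (not part of the original answer) is to create every axes up front with plt.subplots(1, n) and call fig.tight_layout() once at the end, so the layout is never computed for a partially built figure. The helper name save_row and the synthetic data are hypothetical; squeeze=False keeps the returned axes array 2-D, so the same loop works for one, two or three images.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np


def save_row(images, titles, path_filename):
    """Hypothetical helper: save the given images side by side in one row."""
    # plt.subplots creates all axes at once; squeeze=False keeps `axes`
    # 2-D even for a single image, so axes[0] is always the row of axes.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 6), squeeze=False)
    for ax, img, title in zip(axes[0], images, titles):
        ax.imshow(np.squeeze(img), cmap='gray')  # (H, W, 1) -> (H, W)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()  # every axes already exists at this point
    fig.savefig(path_filename, dpi=200)
    n_axes = len(fig.axes)
    plt.close(fig)
    return n_axes


img = np.random.uniform(0.0, 1.0, size=(256, 256, 1))
out_path = os.path.join(tempfile.gettempdir(), 'row.png')
print(save_row([img, img], ['Test Image 1', 'Test Image 2'], out_path))  # → 2
```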