Fast way of saving many images with matplotlib and for loop

Time:11-09

I noticed a big performance gap between the following two ways of loading and saving images with matplotlib. Can anyone explain why, and what is the best (and fastest) way to save images in a Python for loop?

Implementation 1: create a figure outside the for loop, update what is displayed and then save.

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

fig, a = plt.subplots(1, 3, figsize=(30, 20))  # <--------

# list_of_fnames is just a list of file names
for k, fname in enumerate(list_of_fnames):
    with Image.open(fname) as img:
        x = np.array(img)

    y = process_image_fn1(x)
    z = process_image_fn2(x)

    a[0].imshow(x)
    a[1].imshow(y)
    a[2].imshow(z)

    output_filename = f'results_{k}.png'
    plt.savefig(output_filename, dpi=320, format='png', transparent=False, bbox_inches='tight', pad_inches=0)

Implementation 2: create a figure inside the for loop, save it, and finally destroy it.

# list_of_fnames is just a list of file names
for k, fname in enumerate(list_of_fnames):
    with Image.open(fname) as img:
        x = np.array(img)

    y = process_image_fn1(x)
    z = process_image_fn2(x)

    fig, a = plt.subplots(1, 3, figsize=(30, 20))  # <--------
    a[0].imshow(x)
    a[1].imshow(y)
    a[2].imshow(z)

    output_filename = f'results_{k}.png'
    plt.savefig(output_filename, dpi=320, format='png', transparent=False, bbox_inches='tight', pad_inches=0)

    plt.close()  # <--------
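The plt.close() call is what keeps Implementation 2 bounded: pyplot keeps every figure created with plt.subplots registered until it is closed. The sketch below illustrates this under assumptions not in the original: it uses the non-interactive Agg backend and random arrays as stand-ins for the loaded images, and count_open_figures is a hypothetical helper that reports how many figures pyplot still holds.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np


def count_open_figures(n_iter, close_each):
    """Hypothetical helper: create one figure per iteration, as in
    Implementation 2, and return how many figures pyplot still holds."""
    plt.close("all")  # start from a clean state
    rng = np.random.default_rng(0)
    for _ in range(n_iter):
        fig, a = plt.subplots(1, 3, figsize=(6, 2))
        for ax in a:
            ax.imshow(rng.random((8, 8)))  # random stand-in for x, y, z
        if close_each:
            plt.close(fig)  # release the figure, as Implementation 2 does
    return len(plt.get_fignums())


print(count_open_figures(5, close_each=False))  # 5
print(count_open_figures(5, close_each=True))   # 0
```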

Answer:

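The first option is slow because every call to imshow adds a new AxesImage to the axes instead of replacing the one from the previous iteration. After k iterations each axis holds k images, and savefig has to render all of them, so every save takes longer than the one before. Implementation 2 avoids this by starting from an empty figure each time, but has to create and destroy a figure on every iteration. The first option can be improved in a couple of ways.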
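A quick way to see what repeated imshow calls do to an axis: the sketch below calls imshow several times on one axis and counts the AxesImages it holds. It assumes the Agg backend and random arrays in place of real images; images_on_axis is a hypothetical helper, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, suitable for scripts
import matplotlib.pyplot as plt
import numpy as np


def images_on_axis(n_iter, remove_old):
    """Hypothetical helper: call imshow n_iter times on one axis, like the
    loop in Implementation 1, and return how many AxesImages remain."""
    fig, ax = plt.subplots()
    rng = np.random.default_rng(0)
    for _ in range(n_iter):
        if remove_old:
            # Copy to a list first, because each remove() shrinks ax.images.
            for im in list(ax.images):
                im.remove()
        ax.imshow(rng.random((8, 8)))  # random stand-in for a loaded image
    count = len(ax.images)
    plt.close(fig)
    return count


print(images_on_axis(10, remove_old=False))  # 10
print(images_on_axis(10, remove_old=True))   # 1
```

Without removal, the count grows by one per iteration, which is exactly the extra rendering work savefig has to do.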

  1. Remove the previously plotted AxesImages (returned by imshow) at the start of each iteration, so the number of images on each axis does not keep growing.
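import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

fig, a = plt.subplots(1, 3, figsize=(30, 20))  # <--------

# list_of_fnames is just a list of file names
for k, fname in enumerate(list_of_fnames):

    # Remove the images drawn in the previous iteration. On the first pass the
    # axes are empty, so nothing happens (ax.images.pop() would raise
    # IndexError there, and recent Matplotlib releases no longer allow
    # mutating ax.images directly).
    for ax in a:
        for im in list(ax.images):
            im.remove()

    with Image.open(fname) as img:
        x = np.array(img)

    y = process_image_fn1(x)
    z = process_image_fn2(x)

    a[0].imshow(x)
    a[1].imshow(y)
    a[2].imshow(z)

    output_filename = f'results_{k}.png'
    plt.savefig(output_filename, dpi=320, format='png', transparent=False, bbox_inches='tight', pad_inches=0)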

  2. Alternatively, create the AxesImages once per axis and, rather than replotting them on every iteration, use .set_array() to change the data each AxesImage displays.
fig, a = plt.subplots(1, 3, figsize=(30, 20))  # <--------

# list_of_fnames is just a list of file names
for k, fname in enumerate(list_of_fnames):
    with Image.open(fname) as img:
        x = np.array(img)

    y = process_image_fn1(x)
    z = process_image_fn2(x)

    if k == 0:
        im0 = a[0].imshow(x)
        im1 = a[1].imshow(y)
        im2 = a[2].imshow(z)
    else:
        im0.set_array(x)
        im1.set_array(y)
        im2.set_array(z)

    output_filename = f'results_{k}.png'
    plt.savefig(output_filename, dpi=320, format='png', transparent=False, bbox_inches='tight', pad_inches=0)
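For reference, a self-contained, runnable version of the second approach might look like the sketch below. It is an illustration under assumptions, not the answer's exact code: it uses the Agg backend, random arrays instead of files, hypothetical stand-ins for process_image_fn1/process_image_fn2, a smaller figsize and dpi, a hypothetical save_all helper, and a temporary output directory. It also shows one caveat of set_array(): for a 2-D (colormapped) image, the color limits picked for the first frame are kept, so they are rescaled explicitly with set_clim(). If the images differed in size, the extent would presumably need updating too (e.g. with set_extent()).

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; only writes files
import matplotlib.pyplot as plt
import numpy as np


# Hypothetical stand-ins for the question's processing functions.
def process_image_fn1(x):
    return x.mean(axis=2)  # 2-D result, drawn with a colormap


def process_image_fn2(x):
    return 1.0 - x  # inverted RGB image


def save_all(images, out_dir):
    """Hypothetical helper: create the figure and AxesImages once, then only
    swap the data on each iteration before saving."""
    fig, a = plt.subplots(1, 3, figsize=(9, 3))
    ims = None
    paths = []
    for k, x in enumerate(images):
        y = process_image_fn1(x)
        z = process_image_fn2(x)
        if ims is None:
            # First iteration: create one AxesImage per axis.
            ims = [ax.imshow(arr) for ax, arr in zip(a, (x, y, z))]
        else:
            # Later iterations: replace the data instead of re-plotting.
            for im, arr in zip(ims, (x, y, z)):
                im.set_array(arr)
            # The 2-D image keeps the first frame's color limits otherwise.
            ims[1].set_clim(y.min(), y.max())
        path = os.path.join(out_dir, f"results_{k}.png")
        fig.savefig(path, dpi=100)
        paths.append(path)
    plt.close(fig)
    return paths


rng = np.random.default_rng(0)
frames = [rng.random((32, 32, 3)) for _ in range(5)]  # fake RGB images in [0, 1)
with tempfile.TemporaryDirectory() as d:
    saved = save_all(frames, d)
    print(len(saved), all(os.path.exists(p) for p in saved))  # 5 True
```

Because the figure and its three AxesImages are created only once, each iteration is reduced to swapping array data and rendering.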