I have the following code and I am trying to remove the vertical space between the 3 subplots.
I have tried both `gridspec.GridSpec` and `fig.subplots_adjust`, and both fail.
The individual colorbars should stay as they are now.
Any ideas?
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Random data: three 1x10 rows, one per subplot (drawn in order x1, x2, x3).
x1, x2, x3 = (np.random.rand(1, 10) for _ in range(3))
nrow, ncol = 3, 1
fig, axes = plt.subplots(nrow, ncol, figsize=(12, 8))
# NOTE(review): this hspace has little visible effect here, presumably because
# imshow's default equal aspect shrinks each axes inside its grid slot, so the
# blank area is inside the slots rather than between them.
fig.subplots_adjust(wspace=0.01, hspace=0.01)
axes = axes.flatten()
cmaps = ['Greens_r', 'Reds', 'jet']
x_all = [x1, x2, x3]
# Draw each row with its own colormap and attach a per-axes colorbar.
for ax, data, cmap in zip(axes, x_all, cmaps):
    im = ax.imshow(data, cmap=cmap)
    # Empty tick lists remove both the ticks and their labels.
    ax.set_xticks([])
    ax.set_yticks([])
    # Make an axis for the colorbar on the right side of this image.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
Answer:
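You can first draw your figure as is and then move the top and bottom axes towards the middle one. Shift each by the size of its gap to the middle axes, minus the small gap (`space`) you want to keep.
The gaps appear because `imshow` keeps an equal aspect ratio by default, so each axes shrinks inside its grid slot and `hspace` alone cannot remove the blank area.
This is easiest if you follow Jody's advice to use `inset_axes` for the colorbars, because an inset colorbar moves together with its parent axes.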
import matplotlib.pyplot as plt
import matplotlib.transforms as mt
import numpy as np
# random data
np.random.seed(42)
x_all = (np.random.rand(1, 10), np.random.rand(1, 10), np.random.rand(1, 10))
cmaps = ['Greens_r', 'Reds', 'jet']
# Gap to keep between rows (figure fraction, used further below); also reused
# as the colorbar's horizontal offset from its image (axes fraction).
space = 0.01
fig, axes = plt.subplots(3, 1, figsize=(12, 6))
for ax, x, cmap in zip(axes, x_all, cmaps):
    ax.axis('off')
    im = ax.imshow(x, cmap=cmap)
    # Inset colorbar just right of the image: [x0, y0, width, height] in
    # axes-fraction coordinates. Being an inset, it follows `ax` when the
    # axes is repositioned later.
    cax = ax.inset_axes([1 + space, 0, 0.05, 1])
    fig.colorbar(im, ax=ax, cax=cax)
# Draw once so imshow's aspect ratio has been applied; get_position() then
# returns the shrunk boxes that are actually displayed.
fig.canvas.draw()
middle = axes[1].get_position()
top = axes[0].get_position()
bottom = axes[2].get_position()
# Move the top axes down so its bottom edge sits `space` above the middle axes.
axes[0].set_position(top.translated(0, middle.y1 - top.y0 + space))
# Move the bottom axes up so its top edge sits `space` below the middle axes.
axes[2].set_position(bottom.translated(0, middle.y0 - bottom.y1 - space))