I am using plt.subplots()
to create a subplot with 6 plots.
I want two different colormaps per column. Row 1 shows some approximate solutions and Row 2 shows the absolute pointwise difference from the true solution. I thought why not add another column for the colormap and remove everything except the colormap from that axis. Was able to remove the x and y ticks only (see the ax[0,3]
).
MATPLOTLIB has a documentation here, but its too specfic. Does anyone knows a shortcut to do the same.
Here is my code:
fig, ax = plt.subplots(2, 4,dpi=300)
#ax = fig.add_subplot(111)#, projection='3d')
n_points = 1000
# Intepolation points
x = np.linspace(lb[0], ub[0], n_points)
y = np.linspace(lb[1], ub[1], n_points)
# Create meshgrid
X, Y = np.meshgrid(x,y)
# PINN predictions
new_data = griddata(nodes, u_pred_5k.flatten(), (X, Y), method='linear')
sc1 = ax[0,0].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal',vmin=0, vmax=1)
new_data = griddata(nodes, u_pred_10k.flatten(), (X, Y), method='linear')
sc1 = ax[0,1].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal',vmin=0, vmax=1)
new_data = griddata(nodes, u_pred_30k.flatten(), (X, Y), method='linear')
sc1 = ax[0,2].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal',vmin=0, vmax=1)
# Absolute error
new_data = griddata(nodes, abs(u_pred_5k-temperature).flatten(), (X, Y), method='linear')
sc2 = ax[1,0].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal')
new_data = griddata(nodes, abs(u_pred_10k-temperature).flatten(), (X, Y), method='linear')
sc2 = ax[1,1].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal')
new_data = griddata(nodes, abs(u_pred_30k-temperature).flatten(), (X, Y), method='linear')
sc2 = ax[1,2].imshow(new_data, interpolation='nearest', cmap='jet',
extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()],
origin='lower', aspect='equal')
#sc = ax.scatter(nodes[:,0], nodes[:,1], c = best_u_pred, cmap=cm.jet, vmin=0, vmax=1)
#plt.colorbar(sc)
sc1.cmap.set_under('k')
sc1.cmap.set_over('k')
ax[0,0].set(xlabel='x', ylabel='y')
ax[0,1].set(xlabel='x', ylabel='y')
ax[0,2].set(xlabel='x', ylabel='y')
ax[1,0].set(xlabel='x', ylabel='y')
ax[1,1].set(xlabel='x', ylabel='y')
ax[1,2].set(xlabel='x', ylabel='y')
ax[0,0].set_title('5k iterations')
ax[0,1].set_title('10k iterations')
ax[0,2].set_title('30k iterations')
fig.colorbar(sc1, ax=ax[0,3])
ax[0,3].get_xaxis().set_visible(False)
ax[0,3].get_yaxis().set_visible(False)
#plt.xlabel('x')
#plt.ylabel('y')
#ax.set_zlabel('z')
#plt.title('Temperature distribution')
#fig.supxlabel('x')
#fig.supylabel('y')
fig.tight_layout()
CodePudding user response:
Here I had to come up with some data, so I change your code a little bit, but the essence remains. What I would have done differently is turning off the axis with ax[0, 3].axis(False)
and ax[1, 3].axis(False)
.
fig, ax = plt.subplots(2, 4)
# Intepolation points
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
# Create meshgrid
X, Y = np.meshgrid(x,y)
Z1 = np.cos(X**2 Y**2)
Z2 = np.sin(X**2 Y**2)
sc = []
c = 0
for i in range(2):
for j in range(3):
sc.append(ax[i, j].contourf(X, Y, Z1 if c < 3 else Z2, cmap=cm.viridis if c < 3 else cm.magma))
ax[i, j].set_aspect("equal")
c = 1
ax[0, 3].axis(False)
ax[1, 3].axis(False)
plt.colorbar(sc[0], ax=ax[0, 3])
plt.colorbar(sc[3], ax=ax[1, 3])
ax[0,0].set(xlabel='x', ylabel='y')
ax[0,1].set(xlabel='x', ylabel='y')
ax[0,2].set(xlabel='x', ylabel='y')
ax[1,0].set(xlabel='x', ylabel='y')
ax[1,1].set(xlabel='x', ylabel='y')
ax[1,2].set(xlabel='x', ylabel='y')
ax[0,0].set_title('5k iterations')
ax[0,1].set_title('10k iterations')
ax[0,2].set_title('30k iterations')
fig.tight_layout()