I want to draw several plots in one figure.
To plot one set of data, I defined a function that returns ax:
def make_plot(y_true, y_pred, plot_size):
    """Plot predicted against observed values on a newly created figure.

    Args:
        y_true: Observed values, drawn along the x axis.
        y_pred: Predicted values, drawn along the y axis.
        plot_size: Value forwarded to ``plt.figure`` as ``figsize``.

    Returns:
        The only Axes of the figure created by this call.
    """
    # NOTE: each call opens its own figure, so the returned Axes never
    # belongs to a figure that the caller created beforehand.
    fig = plt.figure(figsize=plot_size)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(y_true, y_pred, 'o')
    ax.set_xlabel('Observed', size=14)
    ax.set_ylabel('Predicted', size=14)
    ax.tick_params(labelsize=12)
    return ax
Then I have a for loop over the data that makes a plot for every data set and combines all plots in one figure. The code is as follows:
import matplotlib.pyplot as plt
import math
def make_plot(y_true, y_pred, plot_size):
    """Plot predicted against observed values on a newly created figure.

    Args:
        y_true: Observed values, drawn along the x axis.
        y_pred: Predicted values, drawn along the y axis.
        plot_size: Value forwarded to ``plt.figure`` as ``figsize``.

    Returns:
        The only Axes of the figure created by this call.
    """
    # NOTE: each call opens its own figure, so the returned Axes never
    # belongs to a figure that the caller created beforehand.
    fig = plt.figure(figsize=plot_size)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(y_true, y_pred, 'o')
    ax.set_xlabel('Observed', size=14)
    ax.set_ylabel('Predicted', size=14)
    ax.tick_params(labelsize=12)
    return ax
def plot_all(y_true_all, y_pred_all, fig_save_folder, fig_name,
             plot_size=(4, 4), num_plots_x=2):
    """Lay out one subplot per data set, then save and show the current figure.

    Args:
        y_true_all: Sequence of observed-value sequences, one per plot.
        y_pred_all: Sequence of predicted-value sequences, indexed in step
            with ``y_true_all``.
        fig_save_folder: Folder part of the output path.
        fig_name: File name of the output image, without the ``.png`` suffix.
        plot_size: Per-plot size; scaled by the grid dimensions to size the
            combined figure and also passed on to ``make_plot``.
        num_plots_x: Number of plots per row.
    """
    num_plots_y = math.ceil(len(y_true_all) / num_plots_x)  # No. of plots in y direction
    plt.figure(figsize=(plot_size[0] * num_plots_x, plot_size[1] * num_plots_y))
    for i, y_true in enumerate(y_true_all):
        # Subplot indices are 1-based, hence i + 1.
        ax = plt.subplot(num_plots_y, num_plots_x, i + 1)  # [row, column]
        y_pred = y_pred_all[i]
        # NOTE: make_plot is given plot_size, not the Axes created above; it
        # returns an Axes of its own, which replaces ``ax`` here.
        ax = make_plot(y_true, y_pred, plot_size)
    plt.tight_layout()
    plt.savefig(f'{fig_save_folder}/{fig_name}.png')
    plt.show()
# Two small example data sets: observed values and the matching predictions.
y_true_all = [
    [1, 2, 3],
    [1, 2, 3],
]
y_pred_all = [
    [1.1, 2, 3.1],
    [1, 1.9, 3],
]

# Output location pieces handed to plot_all.
fig_save_folder = './result'
fig_name = 'test'

plot_all(
    y_true_all,
    y_pred_all,
    fig_save_folder,
    fig_name,
    plot_size=(4, 4),
    num_plots_x=2,
)
What I want to get is a figure like below:
However, I get two empty axes and one plot. If you have any idea how to solve this issue, please let me know.
CodePudding user response:
It is easier and better to just pass your function an Axes:
import matplotlib.pyplot as plt
import math
def make_plot(y_true, y_pred, ax):
    """Plot predicted against observed values on the Axes supplied by the caller.

    Args:
        y_true: Observed values, drawn along the x axis.
        y_pred: Predicted values, drawn along the y axis.
        ax: Axes to draw on; it is modified in place.

    Returns:
        The same ``ax`` object that was passed in.
    """
    ax.plot(y_true, y_pred, 'o')
    ax.set_xlabel('Observed', size=14)
    ax.set_ylabel('Predicted', size=14)
    ax.tick_params(labelsize=12)
    return ax
def plot_all(y_true_all, y_pred_all, fig_save_folder, fig_name,
             plot_size=(4, 4), num_plots_x=2):
    """Draw one observed-vs-predicted panel per data set on a single figure.

    The panels are arranged in a grid with ``num_plots_x`` columns. The
    figure is written to ``{fig_save_folder}/{fig_name}.png`` and then shown.

    Args:
        y_true_all: Sequence of observed-value sequences, one per panel.
        y_pred_all: Sequence of predicted-value sequences, indexed in step
            with ``y_true_all``.
        fig_save_folder: Folder part of the output path.
        fig_name: File name of the output image, without the ``.png`` suffix.
        plot_size: Per-panel size; scaled by the grid dimensions to size the
            whole figure.
        num_plots_x: Number of panels per row.
    """
    num_plots_y = math.ceil(len(y_true_all) / num_plots_x)  # No. of plots in y direction
    # Hold on to the Figure so every later call targets it explicitly rather
    # than relying on pyplot's notion of the "current" figure.
    fig = plt.figure(figsize=(plot_size[0] * num_plots_x, plot_size[1] * num_plots_y))
    for i, y_true in enumerate(y_true_all):
        # Subplot indices are 1-based, hence i + 1.
        ax = fig.add_subplot(num_plots_y, num_plots_x, i + 1)  # [row, column]
        make_plot(y_true, y_pred_all[i], ax)
    fig.tight_layout()
    fig.savefig(f'{fig_save_folder}/{fig_name}.png')
    plt.show()
# Two small example data sets: observed values and the matching predictions.
y_true_all = [
    [1, 2, 3],
    [1, 2, 3],
]
y_pred_all = [
    [1.1, 2, 3.1],
    [1, 1.9, 3],
]

# Output location pieces handed to plot_all.
fig_save_folder = './result'
fig_name = 'test'

plot_all(
    y_true_all,
    y_pred_all,
    fig_save_folder,
    fig_name,
    plot_size=(4, 4),
    num_plots_x=2,
)
Output: