I would like to get a plot with fewer values (ticks) on the x-axis. I have this very simple script, in which I set the x-axis ticks with `range`. Furthermore, I would like the background to be white and the contours black rather than grey (see figure). How can I do this?
import matplotlib.pyplot as plt
# Increase the plot size and font size.
# (figsize must be set before the figure is created to take effect.)
plt.rcParams["figure.figsize"] = (60,30)
plt.figure()
plt.xticks(fontsize=40)
plt.yticks(fontsize=40)
plt.grid(True, color='gray', linestyle='dashed', linewidth=0.5, axis='y')
# Plot the learning curve.
plt.plot(df_stats['Training Loss'], color='b', marker='.', linestyle='solid', mec='b', markersize=24, markerfacecolor='white', label="Training", linewidth=7)
plt.plot(df_stats['Valid. Loss'], color='g', marker='.', linestyle='solid', mec='b', markersize=24, markerfacecolor='white',label="Validation", linewidth=7)
# Label the plot.
plt.title("Training & Validation Loss",fontsize=60)
plt.xlabel("Epoch", fontsize=52)
plt.ylabel("Loss", fontsize=52)
plt.legend(fontsize=50)
plt.xticks(list(range(1, 72)))
plt.show()
Answer:
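To show fewer ticks on the x-axis, pass `plt.xticks()` a sparser list of positions instead of one position per epoch. One way to build that list is with `numpy.arange()`, which accepts a step: `np.arange(1, 72, 9)` gives 1, 10, 19, 28, 37, 46, 55, 64.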
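A minimal sketch of the same idea with the object-oriented interface (`ax.set_xticks`) instead of `plt.xticks`. The `epochs` and `loss` arrays are hypothetical stand-ins for the real `df_stats` columns, and the `Agg` backend is selected only so the sketch runs without a display.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

epochs = np.arange(1, 72)   # hypothetical epoch numbers 1..71
loss = 1.0 / epochs         # hypothetical loss values

fig, ax = plt.subplots()
ax.plot(epochs, loss)

# One tick every 9 epochs instead of one tick per epoch
ticks = np.arange(1, 72, 9)
ax.set_xticks(ticks)
print(ticks.tolist())  # [1, 10, 19, 28, 37, 46, 55, 64]
```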
Regarding the background color: the default is white. If for some reason it is not (for example, because a non-default style sheet is active), you can set it to white explicitly with the figure's `facecolor` argument and the axes' `set_facecolor()` method.
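Assuming the black "contours" asked for in the question are the axes spines (the lines framing the plot area), a sketch like the following sets both backgrounds to white and draws the spines in black. The plotted data and the 1.5 linewidth are illustrative assumptions, and `Agg` is used only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(facecolor="white")  # white outer (figure) background
ax.set_facecolor("white")                  # white plot-area background
ax.plot([1, 2, 3], [3, 2, 1])              # hypothetical data

# Draw the border of the plot area (the spines) in black
for spine in ax.spines.values():
    spine.set_edgecolor("black")
    spine.set_linewidth(1.5)
```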
The code below is the updated version with these changes. It uses dummy data in place of the real `df_stats`.
Code
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Dummy data standing in for the real df_stats: one row per epoch
df_stats = pd.DataFrame({'Training Loss': list(range(1, 72))})
df_stats['Valid. Loss'] = df_stats['Training Loss'] * 2.1
# Increase the plot size. rcParams must be set before the figure is created,
# otherwise the new size only applies to figures created afterwards.
plt.rcParams["figure.figsize"] = (60, 30)
# Background color of outer area (this creates the single figure used below)
plt.figure(facecolor='white')
# Background color of the plot area
ax = plt.axes()
ax.set_facecolor("white")
# Increase the font size of the tick labels.
plt.xticks(fontsize=40)
plt.yticks(fontsize=40)
plt.grid(True, color='gray', linestyle='dashed', linewidth=0.5, axis='y')
# Plot the learning curve.
plt.plot(df_stats['Training Loss'], color='b', marker='.', linestyle='solid', mec='b', markersize=24, markerfacecolor='white', label="Training", linewidth=7)
plt.plot(df_stats['Valid. Loss'], color='g', marker='.', linestyle='solid', mec='b', markersize=24, markerfacecolor='white', label="Validation", linewidth=7)
# Label the plot.
plt.title("Training & Validation Loss", fontsize=60)
plt.xlabel("Epoch", fontsize=52)
plt.ylabel("Loss", fontsize=52)
plt.legend(fontsize=50)
# One tick every 9 epochs (1, 10, 19, ..., 64) instead of one per epoch.
plt.xticks(np.arange(1, 72, 9))
plt.show()
Output plot
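Hard-coding `np.arange(1, 72, 9)` ties the ticks to exactly 71 epochs. As an alternative technique (not the one used above), matplotlib's `MaxNLocator` can choose integer tick positions automatically. In this sketch `nbins=8` is an arbitrary choice, the data are hypothetical, and `Agg` is used only so it runs headless.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

epochs = np.arange(1, 72)   # hypothetical epoch numbers
loss = 1.0 / epochs         # hypothetical loss values

fig, ax = plt.subplots()
ax.plot(epochs, loss)

# At most 8 intervals (so at most 9 visible ticks), placed on whole numbers only
ax.xaxis.set_major_locator(MaxNLocator(nbins=8, integer=True))
fig.canvas.draw()  # render once so the tick layout is finalized
```

Because the locator adapts to the data range, the same code keeps a readable number of ticks if the number of epochs changes.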