I want to add a grid to each subplot created with `plt.subplots`, like this:
And this is my code (`data` is a dict mapping strings to pandas DataFrames):
# Question code: a 2x2 grid of subplots, a dashed DPTP curve on each subplot,
# then groups of presets, each followed by a title and grid for one subplot.
# NOTE(review): indentation was lost when this snippet was extracted, so the
# original loop nesting cannot be recovered from the text alone.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
# Draw the dashed DPTP reference curve on every subplot.
for row in ((ax1, ax2), (ax3, ax4)):
for axes in row:
data['DPTP'].plot(ax=axes, x="time", y="fitness",
label="DPTP(r = 30)", linestyle='dashed', grid=True)
# NOTE(review): presets 0-1 are drawn on ax2 while this group sets ax1's
# title/grid; likewise presets 2-3 go to ax3 and 4-5 go to ax4. Each group
# appears shifted one subplot to the right, so as written ax1 receives only
# the DPTP curve and ax4 receives presets 4-9.
for i in range(2):
lb = "Preset {}".format(i)
data[str(i)].plot(ax=ax2, x="time", y="fitness", label=lb, grid=True)
ax1.set_title("a")
ax1.grid(True)
# Presets 2-3: drawn on ax3, but this group titles ax2 "b".
for i in range(2, 4):
lb = "Preset {}".format(i)
data[str(i)].plot(ax=ax3, x="time", y="fitness", label=lb, grid=True)
ax2.set_title("b")
ax2.grid(True)
# Presets 4-5: drawn on ax4, but this group titles ax3 "c".
for i in range(4, 6):
lb = "Preset {}".format(i)
data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax3.set_title("c")
ax3.grid(True)
# Presets 6-9: drawn on ax4, and this group titles ax4 "d".
for i in range(6, 10):
lb = "Preset {}".format(i)
data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax4.set_title("d")
ax4.grid(True)
plt.show()
In each for-loop I pass `grid=True`, so I expected every subplot to have a grid, but my figure only has a grid on the first subplot:
Even `ax.set_title()` does not work for the subplots.
Why?
Answer:
You're somewhat new, so I'll be kind and provide a solution, but you should have provided mock data to work with; that's your job when asking a question here.
So here is the part you should have provided, the construction of `data`:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Mock data: one DPTP reference curve plus ten preset curves, each a DataFrame
# with "time" and "fitness" columns sampled at the same 600 time points.
time = np.linspace(0.3, 2, 600)


def fitness(i):
    """Return a decreasing mock fitness curve over ``time``, offset by *i*."""
    k = i + 1
    return 1 / (time + 0.05 * k) - 0.1 * k


def dataframe(i):
    """Return a DataFrame with "time" and "fitness" columns for curve *i*."""
    return pd.DataFrame({"time": time, "fitness": fitness(i)})


# "DPTP" uses curve 0; presets "0".."9" use curves 1..10, so no preset
# coincides with the DPTP curve.
data = {"DPTP": dataframe(0)} | {str(i): dataframe(i + 1) for i in range(10)}
Now let's create the axes. I refactored a bit using dicts so that the plotting part is shorter and independent:
# Prepare axes and axes metadata
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
# Subplot titles "a".."d" in row-major order (top-left to bottom-right).
ax_titles = {ax1: "a", ax2: "b", ax3: "c", ax4: "d"}
# Which subplot each preset key is drawn on:
# presets 0-1 -> ax1, 2-3 -> ax2, 4-5 -> ax3, 6-9 -> ax4.
key_to_ax = (
    dict.fromkeys(["0", "1"], ax1)
    | dict.fromkeys(["2", "3"], ax2)
    | dict.fromkeys(["4", "5"], ax3)
    | dict.fromkeys(["6", "7", "8", "9"], ax4)
)
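And finally, the plotting part. Here you have two choices; either:
- pass `grid=True` as an argument to all the `df.plot` calls, or
- call `ax.grid(True)` once you're done plotting new lines on that `ax`. A bare `ax.grid()` with no arguments toggles the grid, so it can turn an already-visible grid off.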
# Plot
# Dashed DPTP reference curve on every subplot.
for ax in (ax1, ax2, ax3, ax4):
    data["DPTP"].plot(
        ax=ax, x="time", y="fitness", label="DPTP(r = 30)", linestyle="dashed"
    )
# Each preset goes onto the subplot assigned to it in key_to_ax.
for key, ax in key_to_ax.items():
    data[key].plot(ax=ax, x="time", y="fitness", label=f"Preset {key}")
# Titles and grids go on last. NOTE(review): each DataFrame.plot call
# presumably applies its own grid setting (from rcParams["axes.grid"] when
# `grid` is not passed), so the grid is enabled only after all plotting.
for ax, title in ax_titles.items():
    ax.set_title(title)
    # Explicit True: a bare ax.grid() toggles visibility and would hide a
    # grid that a style or rcParams had already turned on.
    ax.grid(True)
plt.show()