add grid to plt.subplots

Time:07-23

I want to add a grid to each subplot created with plt.subplots, like this: [image]

And this is my code (data is a dict mapping strings to pandas DataFrames):

import matplotlib.pyplot as plt

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
for row in ((ax1, ax2), (ax3, ax4)):
    for axes in row:
        data['DPTP'].plot(ax=axes, x="time", y="fitness",
                          label="DPTP(r = 30)", linestyle='dashed', grid=True)

for i in range(2):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax2, x="time", y="fitness", label=lb, grid=True)
ax1.set_title("a")
ax1.grid(True)

for i in range(2, 4):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax3, x="time", y="fitness", label=lb, grid=True)
ax2.set_title("b")
ax2.grid(True)

for i in range(4, 6):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax3.set_title("c")
ax3.grid(True)

for i in range(6, 10):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax4.set_title("d")
ax4.grid(True)

plt.show()

In each for-loop I pass grid=True, so I expected every subplot to get a grid, but only the first subplot has one: [image]. Even ax.set_title() doesn't seem to work for the subplots. Why?

Answer:

You're somewhat new, so I'll be kind and provide a solution, but you should have provided mock data to work with; that's your job when asking a question here.

So here is the part you should have provided, the construction of data:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Mock data (the dict union operator | below needs Python 3.9+)
time = np.linspace(0.3, 2, 600)
fitness = lambda i: 1 / (time + 0.05 * (i + 1)) - 0.1 * (i + 1)
dataframe = lambda i: pd.DataFrame({"time": time, "fitness": fitness(i)})

data = {"DPTP": dataframe(0)} | {str(i): dataframe(i + 1) for i in range(10)}

Now let's create the axes. I refactored a bit using dicts so the plotting part is shorter and doesn't hard-code which preset goes on which axes:

# Prepare axes and axes metadata
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

ax_titles = {ax1: "a", ax2: "b", ax3: "c", ax4: "d"}
key_to_ax = (
    dict.fromkeys(["0", "1"], ax1)
    | dict.fromkeys(["2", "3"], ax2)
    | dict.fromkeys(["4", "5"], ax3)
    | dict.fromkeys(["6", "7", "8", "9"], ax4)
)
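To see what key_to_ax ends up holding, here is a minimal sketch of the same dict.fromkeys(...) | ... pattern with plain strings standing in for the Axes objects (the "ax1".."ax4" strings are illustrative placeholders, not part of the answer). dict.fromkeys(keys, value) maps every key to the same value, and | merges the resulting dicts (Python 3.9+).

```python
# Same construction as key_to_ax, but with placeholder strings instead of
# real Axes objects so the resulting mapping is easy to print.
key_to_ax = (
    dict.fromkeys(["0", "1"], "ax1")
    | dict.fromkeys(["2", "3"], "ax2")
    | dict.fromkeys(["4", "5"], "ax3")
    | dict.fromkeys(["6", "7", "8", "9"], "ax4")
)

# Each preset key "0".."9" now points at the subplot it should be drawn on.
for key in sorted(key_to_ax, key=int):
    print(key, "->", key_to_ax[key])
```

Because every preset key appears exactly once, the lookup key_to_ax[str(i)] in the plotting loop sends each preset to a single, well-defined subplot, which keeps the data and the titles on matching axes.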

And finally, the plotting part. Here you have two choices, either:

  • pass grid=True as an argument to every df.plot call, or
  • call ax.grid(True) once you're done plotting lines on that ax.

# Plot
for ax in (ax1, ax2, ax3, ax4):
    data["DPTP"].plot(
        ax=ax, x="time", y="fitness", label="DPTP(r = 30)", linestyle="dashed"
    )

for i in range(10):
    data[str(i)].plot(ax=key_to_ax[str(i)], x="time", y="fitness", label=f"Preset {i}")

for ax in (ax1, ax2, ax3, ax4):
    ax.set_title(ax_titles[ax])
    ax.grid(True)  # explicit True: a bare ax.grid() toggles the grid instead

plt.show()

[Image: subplots with grids]
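For comparison, here is a sketch of the first option, passing grid=True to every df.plot call instead of calling ax.grid afterwards. It rebuilds the same mock data and axis mapping so it runs on its own. The Agg backend and the has_grid helper are assumptions added so the script can run headless and check itself; has_grid simply inspects the major-tick gridlines of each axis.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Same mock data as in the answer (dict union | needs Python 3.9+)
time = np.linspace(0.3, 2, 600)


def dataframe(i):
    fitness = 1 / (time + 0.05 * (i + 1)) - 0.1 * (i + 1)
    return pd.DataFrame({"time": time, "fitness": fitness})


data = {"DPTP": dataframe(0)} | {str(i): dataframe(i + 1) for i in range(10)}

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
axes = (ax1, ax2, ax3, ax4)
ax_titles = dict(zip(axes, "abcd"))
key_to_ax = (
    dict.fromkeys(["0", "1"], ax1)
    | dict.fromkeys(["2", "3"], ax2)
    | dict.fromkeys(["4", "5"], ax3)
    | dict.fromkeys(["6", "7", "8", "9"], ax4)
)

# Option 1: pass grid=True to every single df.plot call
for ax in axes:
    data["DPTP"].plot(ax=ax, x="time", y="fitness", label="DPTP(r = 30)",
                      linestyle="dashed", grid=True)

for i in range(10):
    data[str(i)].plot(ax=key_to_ax[str(i)], x="time", y="fitness",
                      label=f"Preset {i}", grid=True)

# Titles are set once all plotting on the axes is finished
for ax in axes:
    ax.set_title(ax_titles[ax])


def has_grid(ax):
    """Hypothetical helper: True if all major x and y gridlines are visible."""
    lines = list(ax.xaxis.get_gridlines()) + list(ax.yaxis.get_gridlines())
    return bool(lines) and all(line.get_visible() for line in lines)


for ax in axes:
    print(ax.get_title(), has_grid(ax))
```

With an interactive backend you would drop the matplotlib.use line and finish with plt.show(), as in the answer. The trade-off is that option 1 must repeat grid=True on every call, whereas option 2 sets the grid once per axes after plotting.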
