I need to recreate this table structure in matplotlib:
import matplotlib.pyplot as plt
fig, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=1, nrows=4)
# fig.tight_layout()
ax1.axis("off")
ax2.axis("off")
ax3.axis("off")
ax4.axis("off")
ax1.table(cellText=[["A", "B", "C", "D", "E", "F"]])
ax2.table(cellText=[[1, 1, 1, 1, 1, 1],
                    [2, 2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3, 3],
                    [4, 4, 4, 4, 4, 4],
                    [5, 5, 5, 5, 5, 5],
                    [6, 6, 6, 6, 6, 6],
                    [7, 7, 7, 7, 7, 7],
                    [8, 8, 8, 8, 8, 8],
                    [9, 9, 9, 9, 9, 9]])
ax3.table(cellText=[[1, 2, 3, 4, 5, 6],
                    [1, 2, 3, 4, 5, 6]])
ax4.table(cellText=[[1, 2, 3, 4, 5, 6]])
plt.show()
Here is what I have so far, but the tables overlap and the result is not usable. So my question is: how can I achieve this table structure in matplotlib?
Answer:
The problem you're running into is that plt.subplots gives every Axes the same height by default, which implicitly assumes that all of your tables are the same size. Since they are not, some of your tables end up larger than the Axes objects they are drawn in.
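One way to see this is to measure the Axes heights that a plain plt.subplots(nrows=4) call produces and compare them with the rendered height of the nine-row table. This is a minimal sketch, assuming the non-interactive Agg backend and matplotlib's default figure size, DPI and table font size; the nine-row table is drawn with loc="center" here only so that its height can be compared with its Axes in isolation.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen so the sketch runs headless
import matplotlib.pyplot as plt

nine_rows = [[i] * 6 for i in range(1, 10)]  # same data as the question's ax2 table

fig, axes = plt.subplots(ncols=1, nrows=4)  # no height_ratios, so all heights are equal
tbl = axes[1].table(cellText=nine_rows, loc="center")
for ax in axes:
    ax.axis("off")

fig.canvas.draw()  # table cells get their final positions when the figure is drawn
renderer = fig.canvas.get_renderer()

# Heights in display pixels
axes_heights = [ax.get_window_extent(renderer).height for ax in axes]
table_height = tbl.get_window_extent(renderer).height

print("Axes heights (px):", [round(h, 1) for h in axes_heights])
print("9-row table height (px):", round(table_height, 1))
```

With the defaults, all four Axes come out the same height and the nine-row table is taller than the Axes it belongs to, so it spills into its neighbours.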
In your case you know your tables in advance, so you can estimate their heights from their number of rows. You can pass this information to plt.subplots so that it sizes the returned Axes proportionally, and each table should then fit into its own Axes.
Here I predefine all of the tables in one overarching list (mainly for convenience). I can then count the rows in each table and pass those counts as the height_ratios gridspec argument, so that each Axes gets a height proportional to the number of rows in its table.
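As a quick check of what height_ratios does, the sketch below builds only the empty Axes for the four row counts in the question and divides each Axes height by its row count. It assumes the Agg backend; row_counts and per_row are illustrative names, not part of matplotlib.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

row_counts = [1, 9, 2, 1]  # rows in each of the four tables from the question

fig, axes = plt.subplots(
    ncols=1, nrows=len(row_counts),
    gridspec_kw={"height_ratios": row_counts},
)

# Axes heights in figure-fraction units; dividing by the row count
# should give the same height per table row for every Axes.
heights = [ax.get_position().height for ax in axes]
per_row = [h / n for h, n in zip(heights, row_counts)]
print([round(p, 4) for p in per_row])
```

In matplotlib 3.6 and later, plt.subplots also accepts height_ratios=row_counts directly; gridspec_kw works in older versions as well.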
Then I can zip the returned axes with the tables and actually insert the values.
Lastly, you should pass loc='center' so that each table is actually drawn inside its Axes. Axes.table defaults to loc='bottom', which places the table just below its Axes; other values can put it above or off to the side.
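To see the effect of loc on its own, the sketch below draws the same one-row table with the default loc and with loc="center" and compares the vertical extents in display pixels. It assumes the Agg backend; vertical_span is a hypothetical helper written for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, (ax_default, ax_center) = plt.subplots(ncols=1, nrows=2)
row = [["A", "B", "C", "D", "E", "F"]]

tbl_default = ax_default.table(cellText=row)             # loc defaults to 'bottom'
tbl_center = ax_center.table(cellText=row, loc="center")

fig.canvas.draw()  # position the table cells
renderer = fig.canvas.get_renderer()

def vertical_span(artist):
    """Hypothetical helper: (bottom, top) of an artist in display pixels."""
    bbox = artist.get_window_extent(renderer)
    return round(bbox.y0, 1), round(bbox.y1, 1)

print("default: table", vertical_span(tbl_default), "axes", vertical_span(ax_default))
print("center:  table", vertical_span(tbl_center), "axes", vertical_span(ax_center))
```

The default table sits entirely below the bottom edge of its Axes, while the centered one lies within its Axes.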
import matplotlib.pyplot as plt

# All tables up front, so their row counts can drive the layout
tables = [
    [["A", "B", "C", "D", "E", "F"]],
    [[1, 1, 1, 1, 1, 1],
     [2, 2, 2, 2, 2, 2],
     [3, 3, 3, 3, 3, 3],
     [4, 4, 4, 4, 4, 4],
     [5, 5, 5, 5, 5, 5],
     [6, 6, 6, 6, 6, 6],
     [7, 7, 7, 7, 7, 7],
     [8, 8, 8, 8, 8, 8],
     [9, 9, 9, 9, 9, 9]],
    [[1, 2, 3, 4, 5, 6],
     [1, 2, 3, 4, 5, 6]],
    [[1, 2, 3, 4, 5, 6]],
]

# One height ratio per table: its number of rows
table_rows = [len(tbl) for tbl in tables]

fig, axes = plt.subplots(
    ncols=1, nrows=len(tables),
    gridspec_kw={'height_ratios': table_rows}
)

for tbl, ax in zip(tables, axes.flat):
    # loc='center' draws the table inside its Axes (the default is 'bottom')
    ax.table(cellText=tbl, loc='center')
    ax.axis('off')

plt.show()
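If you want to confirm that each table really fits inside its own Axes, a sketch like the following rebuilds the same layout off-screen and compares each table's rendered box with its Axes box. It assumes the Agg backend and matplotlib's default figure size and DPI (a smaller figure or larger font could still overflow); fits_vertically is a hypothetical helper, not a matplotlib function.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

tables = [
    [["A", "B", "C", "D", "E", "F"]],
    [[i] * 6 for i in range(1, 10)],          # rows of 1s through 9s
    [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
    [[1, 2, 3, 4, 5, 6]],
]

fig, axes = plt.subplots(
    ncols=1, nrows=len(tables),
    gridspec_kw={"height_ratios": [len(t) for t in tables]},
)
drawn = []  # the Table artists, kept so they can be measured
for tbl, ax in zip(tables, axes.flat):
    drawn.append(ax.table(cellText=tbl, loc="center"))
    ax.axis("off")

fig.canvas.draw()
renderer = fig.canvas.get_renderer()

def fits_vertically(table_artist, ax, tol=1.0):
    """Hypothetical helper: True if the table's box lies within its Axes (pixels)."""
    t = table_artist.get_window_extent(renderer)
    a = ax.get_window_extent(renderer)
    return t.y0 >= a.y0 - tol and t.y1 <= a.y1 + tol

print([fits_vertically(t, ax) for t, ax in zip(drawn, axes.flat)])
```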