I am trying to generate 16 subplots. My end goal is a 2 x 8 layout for the final figure (8 subplots per row). My code looks like this:
import uuid

import matplotlib.pyplot as plt
import numpy as np


def visualize_t2t(token_dict, scores):
    fig = plt.figure(figsize=(50, 50))

    for idx, scores in enumerate(scores):
        scores_np = np.array(scores)
        ax = fig.add_subplot(12, 12, idx + 1)
        # append the attention weights
        im = ax.imshow(scores_np, cmap='viridis')

        fontdict = {'fontsize': 3}

        # label both axes with the tokens passed in as token_dict
        ax.set_xticks(range(len(token_dict)))
        ax.set_yticks(range(len(token_dict)))
        ax.set_xticklabels(token_dict, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(token_dict, fontdict=fontdict)
        ax.set_xlabel('{} {}'.format('label_name', idx + 1))

        # one colorbar per subplot
        fig.colorbar(im, fraction=0.046, pad=0.04)

    plt.tight_layout()
    name_f = str(uuid.uuid4())
    plt.savefig(f'{name_f}.pdf',
                bbox_inches='tight',
                dpi=350)
Input data
all_tokens = ['[CLS]',
              'what',
              'type',
              'of',
              'heart',
              'issue',
              'does',
              'the',
              'Person',
              'have',
              '[CLS]']
dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
visualize_t2t(all_tokens, dummy_input)
But the result looks like this:
How can I set the rows and columns here so that 8 subplots are in one row and the remaining 8 in the other?
CodePudding user response:
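Just replace ax = fig.add_subplot(12, 12, idx + 1)
with ax = fig.add_subplot(2, 8, idx + 1).
The arguments to add_subplot are (nrows, ncols, index), so this gives a grid of 2 rows and 8 columns, and the index runs from 1 to 16 row by row.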
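For reference, here is a minimal sketch of the same 2 x 8 layout built with plt.subplots instead of repeated add_subplot calls. The function name plot_attention_grid, the nrows/ncols parameters, the per-panel figure size, and the font size are illustrative assumptions, not part of the original code; the Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def plot_attention_grid(tokens, scores, nrows=2, ncols=8):
    """Draw one heatmap per score matrix on an nrows x ncols grid (hypothetical helper)."""
    if len(scores) > nrows * ncols:
        raise ValueError("grid is too small for the number of score matrices")

    # Illustrative figure size: roughly 4 x 4 inches per panel.
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(4 * ncols, 4 * nrows),
                             squeeze=False)
    flat_axes = axes.ravel()  # row-major order: first 8 axes are the top row

    for idx, (ax, score) in enumerate(zip(flat_axes, scores)):
        im = ax.imshow(np.asarray(score), cmap="viridis")
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, fontsize=6, rotation=90)
        ax.set_yticklabels(tokens, fontsize=6)
        ax.set_xlabel(f"label_name {idx + 1}")
        # Attach the colorbar to this panel explicitly.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide any grid cells that did not receive a heatmap.
    for ax in flat_axes[len(scores):]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig, axes


tokens = ['[CLS]', 'what', 'type', 'of', 'heart', 'issue',
          'does', 'the', 'Person', 'have', '[CLS]']
scores = np.random.uniform(-1, 1, [16, len(tokens), len(tokens)])
fig, axes = plot_attention_grid(tokens, scores)
```

The returned fig can then be written out with fig.savefig, as in the question. Sizing the figure from the grid shape, rather than using a fixed (50, 50), keeps each panel roughly square when the number of rows and columns changes.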