How to set columns and rows in add_subplot?

Time:11-07

I am trying to generate 16 subplots. My end goal is an 8 x 2 layout for the final figure (8 columns, 2 rows). My code looks like this:

import uuid

import matplotlib.pyplot as plt
import numpy as np


def visualize_t2t(token_dict, scores):
    fig = plt.figure(figsize=(50, 50))
    fontdict = {'fontsize': 3}

    for idx, head_scores in enumerate(scores):
        ax = fig.add_subplot(12, 12, idx + 1)
        # draw the attention weights as a heatmap
        im = ax.imshow(np.asarray(head_scores), cmap='viridis')

        ax.set_xticks(range(len(token_dict)))
        ax.set_yticks(range(len(token_dict)))
        ax.set_xticklabels(token_dict, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(token_dict, fontdict=fontdict)
        ax.set_xlabel('{} {}'.format('label_name', idx + 1))

        fig.colorbar(im, fraction=0.046, pad=0.04)

    plt.tight_layout()
    name_f = str(uuid.uuid4())
    plt.savefig(f'{name_f}.pdf', bbox_inches='tight', dpi=350)

Input data

all_tokens = ['[CLS]',
 'what',
 'type',
 'of',
 'heart',
 'issue',
 'does',
 'the',
 'Person',
 'have',
 '[CLS]']

dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
visualize_t2t(all_tokens, dummy_input)

But the result looks like this:

[Image: the resulting figure]
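For reference, fig.add_subplot(nrows, ncols, index) splits the figure into an nrows x ncols grid and numbers the cells from 1, left to right and then top to bottom. The small sketch below computes the zero-based (row, column) cell each of the 16 indices lands in. The helper grid_cell is hypothetical, written only to illustrate the numbering; it is not a matplotlib function. With a 12 x 12 grid, the 16 subplots occupy the 12 cells of the first row and 4 cells of the second, leaving the other 128 cells empty, which explains why each plot comes out so small.

```python
from typing import Tuple


def grid_cell(nrows: int, ncols: int, index: int) -> Tuple[int, int]:
    """Return the zero-based (row, col) of a 1-based subplot index.

    Hypothetical helper mirroring add_subplot's numbering:
    cells are counted from 1, left to right, then top to bottom.
    """
    if not 1 <= index <= nrows * ncols:
        raise ValueError(f"index must be in 1..{nrows * ncols}, got {index}")
    return divmod(index - 1, ncols)


# Where the 16 subplots land in the original 12 x 12 grid
cells_12x12 = [grid_cell(12, 12, i + 1) for i in range(16)]
print(cells_12x12[11], cells_12x12[12], cells_12x12[15])  # (0, 11) (1, 0) (1, 3)

# Where they land in the requested 2 x 8 grid
cells_2x8 = [grid_cell(2, 8, i + 1) for i in range(16)]
print(cells_2x8[7], cells_2x8[8], cells_2x8[15])  # (0, 7) (1, 0) (1, 7)
```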

How can I set the rows and columns here so that 8 subplots are in one row and the remaining 8 are in another?

Answer:

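Just replace ax = fig.add_subplot(12, 12, idx + 1) with ax = fig.add_subplot(2, 8, idx + 1). The first two arguments of add_subplot are the number of rows and columns in the grid, and the third is the 1-based position of the subplot, counted left to right and then top to bottom, so indices 1 to 8 fill the first row and 9 to 16 fill the second.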
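A minimal sketch of the whole function with that change applied. A few other details are assumptions of this sketch, not part of the answer: it selects the non-interactive Agg backend so it runs without a display, uses a wide figsize=(40, 10) so a 2 x 8 grid of square heatmaps is not stretched vertically (tune to taste), passes ax=ax to fig.colorbar so each colorbar is attached to its own subplot, adds a save flag, and returns the figure so it can be inspected.

```python
import uuid

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np


def visualize_t2t(tokens, scores, save=True):
    """Draw one attention heatmap per entry of `scores` on a 2 x 8 grid."""
    # Assumed wide figure size so 2 rows x 8 columns of square plots fit
    fig = plt.figure(figsize=(40, 10))
    fontdict = {'fontsize': 3}

    for idx, head_scores in enumerate(scores):
        # 2 rows, 8 columns; add_subplot positions start at 1
        ax = fig.add_subplot(2, 8, idx + 1)
        im = ax.imshow(np.asarray(head_scores), cmap='viridis')

        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(tokens, fontdict=fontdict)
        ax.set_xlabel('{} {}'.format('label_name', idx + 1))

        # Attach the colorbar to this subplot explicitly
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    if save:
        fig.savefig(f'{uuid.uuid4()}.pdf', bbox_inches='tight', dpi=350)
    return fig


all_tokens = ['[CLS]', 'what', 'type', 'of', 'heart', 'issue',
              'does', 'the', 'Person', 'have', '[CLS]']
dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
# save=True (the default) writes a uniquely named PDF
fig = visualize_t2t(all_tokens, dummy_input, save=False)
```

An equivalent approach is to create the whole grid up front with fig, axes = plt.subplots(2, 8) and then iterate over axes.flat, which avoids computing the index by hand.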
