Subplots repeating the same graph 6 times and producing 6 figures instead of one

Time:10-05

So I have this code:

def scatter(df, column_name):
  values = {data: list(df[data]) for data in column_name}

  data = list(values.values())
  labels = list(values.keys())
  
  for i in range(len(data)):
    for j in range(len(data)):
      if i == j:
        continue
      elif (i == 1) & (j == 0):
        continue
      elif (i == 2) & ((j == 0)|(j == 1)):
        continue
      elif (i == 3) & ((j == 0)|(j == 1)|(j == 2)):
        continue
      else:
        for k in range(6):
          ax = plt.subplot(3, 2, k + 1)
          plt.scatter(data[i], data[j])
          plt.xlabel(labels[i])
          plt.ylabel(labels[j])
          plt.title('{} vs {}'.format(labels[i], labels[j]))
        plt.show()
        plt.clf()

scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])

but it produces 6 figures instead of 1 and each figure has the same graph repeated 6 times.

Please help me solve this.

CodePudding user response:

Each time the code reaches the else branch, the loop over k creates six subplots for that single i, j combination. For i=0, j=1, for example, all six subplots show the same pair of columns. The figure is then shown and cleared (plt.show() followed by plt.clf()), and for i=0, j=2 the next set of six identical subplots is created. Since six i, j combinations reach the else branch, you end up with six figures.

You can simplify things by letting the loop over j start at i + 1, so no tests are needed. Also, the index of the next subplot can be a variable k that is incremented each time a subplot has been added.

Here is some example code:

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

def scatter(df, column_names):
    fig = plt.figure(figsize=(10, 12)) # set a size for the surrounding plot
    n = len(column_names)
    total = n * (n - 1) // 2
    ncols = 2
    nrows = (total + (ncols - 1)) // ncols
    k = 1
    for i in range(n):
        col_i = column_names[i]
        for j in range(i + 1, n):
            col_j = column_names[j]
            ax = plt.subplot(nrows, ncols, k)
            plt.scatter(df[col_i], df[col_j])
            plt.xlabel(col_i)
            plt.ylabel(col_j)
            plt.title(f'{col_i} vs {col_j}')
            k += 1
    plt.tight_layout() # fit labels and ticks nicely together
    plt.show() # only called once, at the end of the function

columns = ['speed', 'height', 'length', 'num_inversions']
roller_coasters = pd.DataFrame(np.random.rand(20, len(columns)), columns=columns)
scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])
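
As a variation on the code above, the "j starts at i + 1" pairing could also be written with itertools.combinations, and the grid created up front with plt.subplots. This is only a sketch: the function name scatter_pairs, the ncols parameter, and the demo values are made up for illustration and are not part of the original question.

```python
from itertools import combinations
import math

import matplotlib.pyplot as plt


def scatter_pairs(data, column_names, ncols=2):
    """Draw one scatter subplot per unordered pair of columns on a single figure.

    `data` can be a pandas DataFrame or any mapping from column name to values.
    """
    # combinations() yields each pair once and never pairs a column with itself,
    # which is the same set of pairs as looping j from i + 1.
    pairs = list(combinations(column_names, 2))
    nrows = max(1, math.ceil(len(pairs) / ncols))
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 4 * nrows), squeeze=False)
    flat_axes = axes.ravel()
    for ax, (col_x, col_y) in zip(flat_axes, pairs):
        ax.scatter(data[col_x], data[col_y])
        ax.set_xlabel(col_x)
        ax.set_ylabel(col_y)
        ax.set_title(f"{col_x} vs {col_y}")
    # Hide leftover axes when the number of pairs is not a multiple of ncols.
    for ax in flat_axes[len(pairs):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig


# Placeholder values for illustration only; substitute your own DataFrame.
demo = {
    "speed": [1.0, 2.0, 3.0],
    "height": [3.0, 1.0, 2.0],
    "length": [2.0, 2.5, 1.0],
    "num_inversions": [0, 1, 0],
}
fig = scatter_pairs(demo, list(demo))
```

Call plt.show() once after the function returns to display the single figure. Returning the figure instead of showing it inside the function also lets you save it with fig.savefig if you prefer.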
