So I have this code:
def scatter(df, column_name):
    """Scatter-plot pairs of the given columns of ``df`` against each other.

    Behaviour as written: for every index pair (i, j) that survives the
    hard-coded skip conditions below, the same scatter plot is drawn into
    all six cells of a 3x2 subplot grid, and then the figure is shown and
    cleared. With four columns this yields six separate figures, each
    repeating a single plot six times.

    Args:
        df: Table indexable by column name (``df[name]``).
        column_name: Iterable of column names to plot against each other.
    """
    values = {data: list(df[data]) for data in column_name}
    data = list(values.values())
    labels = list(values.keys())
    for i in range(len(data)):
        for j in range(len(data)):
            # The chain below skips the diagonal and, for i up to 3, every
            # pair with j < i, so each unordered pair is visited only once.
            if i == j:
                continue
            elif (i == 1) & (j == 0):
                continue
            elif (i == 2) & ((j == 0)|(j == 1)):
                continue
            elif (i == 3) & ((j == 0)|(j == 1)|(j == 2)):
                continue
            else:
                # NOTE: this inner loop fills all six grid cells with the
                # *same* (i, j) plot; k only selects the cell.
                for k in range(6):
                    ax = plt.subplot(3, 2, k + 1)
                    plt.scatter(data[i], data[j])
                    plt.xlabel(labels[i])
                    plt.ylabel(labels[j])
                    plt.title('{} vs {}'.format(labels[i], labels[j]))
                # Shown and cleared once per (i, j) pair, hence one figure
                # per pair.
                plt.show()
                plt.clf()
# Plot the four listed columns against each other. `roller_coasters` is not
# defined in this snippet -- presumably a pandas DataFrame loaded elsewhere.
scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])
However, it produces 6 figures instead of 1, and each figure shows the same graph repeated 6 times.
Please help me solve this.
CodePudding user response:
Each time you enter the `else` part of your loop, you create 6 subplots for that
given `i`, `j` combination. E.g. for `i=0; j=1`, the loop over `k` creates six
subplots, but only for that specific `i` and `j`. Once they have been created,
the figure is shown and then cleared again (`plt.clf()`). For the following
combination, `i=0; j=2`, a new set of 6 subplots is created, and so on.
You can simplify things by letting the loop over `j` start at `i + 1`, so no
tests are needed. Also, the index of the subplot that will be created next can
be kept in a variable `k` that is incremented each time a subplot has been added.
Here is some example code:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
def scatter(df, column_names):
    """Draw a single figure with one scatter subplot per pair of columns.

    Every unordered pair of entries in ``column_names`` gets its own subplot
    in a two-column grid; the figure is shown once after all subplots have
    been added.

    Args:
        df: Table indexable by column name (``df[name]``), e.g. a pandas
            DataFrame.
        column_names: Sequence of column names to plot against each other.
    """
    plt.figure(figsize=(10, 12))  # set a size for the surrounding figure
    n = len(column_names)
    total = n * (n - 1) // 2  # number of unordered column pairs
    ncols = 2
    nrows = (total + (ncols - 1)) // ncols  # ceiling of total / ncols
    k = 1  # 1-based position of the next subplot in the grid
    for i in range(n):
        col_i = column_names[i]
        # Starting j past i visits each pair exactly once, so no skip tests
        # are needed.
        for j in range(i + 1, n):
            col_j = column_names[j]
            plt.subplot(nrows, ncols, k)
            plt.scatter(df[col_i], df[col_j])
            plt.xlabel(col_i)
            plt.ylabel(col_j)
            plt.title(f'{col_i} vs {col_j}')
            k += 1  # advance to the next grid cell
    plt.tight_layout()  # fit labels and ticks nicely together
    plt.show()  # only called once, at the end of the function
# Demo data: 20 rows of random values, one column per name, then plot them.
columns = ['speed', 'height', 'length', 'num_inversions']
roller_coasters = pd.DataFrame(
    data=np.random.rand(20, len(columns)),
    columns=columns,
)
scatter(roller_coasters, columns)