I am attemption to create a function to serve as a quick visual assessment for a normal distribution and to automate this for a whole dataframe. I want to create a no. of cols x 2 subplot (2 columns, each column of a dataframe a row) with the left plot being a histogram and the right a probability plot. I have written functions for each of these plots and they work fine, and the ax argument I have added can successfully plot them in a specific subplot coordinate. When I try to call these functions in a final function, intended to apply these to each column in a dataframe only the first histogram is returned and the rest of the plots empty.
Not sure where I am going wrong. See code for functions below. Note, no errors are returned:
#Histogram for normality
def normal_dist_hist(data, ax):
#Format data for plotting
#Included ax for subplot coordinate
if data.isnull().values.any() == True:
data.dropna(inplace=True)
if data.dtypes == 'float64':
data.astype('int64')
#Plot distribution with Gaussian overlay
mu, std = stats.norm.fit(data)
ax.hist(data, bins=50, density=True, alpha=0.6, color='g')
xmin, xmax = ax.get_xlim()
x = np.linspace(xmin, xmax, 100)
p = stats.norm.pdf(x, mu, std)
ax.plot(x, p, 'k', linewidth=2)
title = "Fit results: mu = %.2f, std = %.2f" % (mu, std)
ax.set_title(title)
plt.show()
#Probability plot
def normal_test_QQplots(data, ax):
#added ax argument for specifying subplot coordinate,
data.dropna(inplace=True)
probplt = stats.probplot(data,dist='norm',fit=True,plot=ax)
plt.show()
def normality_report(df):
fig, axes = plt.subplots(nrows=len(df.columns), ncols=2,figsize=(12,50))
ax_y = 0
for col in df.columns[1:]:
ax_x = 0
normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
ax_x = 1
normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
ax_y = 1
CodePudding user response:
Remove the plt.show()
from your methods normal_dist_hist(...)
and normal_test_QQplots(...)
. Add plt.show()
at the end of your normality_report(...)
.
def normal_dist_hist(data, ax):
...
plt.show() # Remove this
#Probability plot
def normal_test_QQplots(data, ax):
...
plt.show() # Remove this
def normality_report(df):
...
for col in df.columns[1:]:
ax_x = 0
normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
ax_x = 1
normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
ax_y = 1
plt.show() # Add it here.