Plotting matplotlib subplots with functions

Time:04-05

I am attempting to create a function that serves as a quick visual check for normality, and to automate it for a whole dataframe. I want a grid of subplots with 2 columns and one row per dataframe column, with a histogram on the left and a probability plot on the right. I have written a function for each of these plots and they work fine on their own; the ax argument I added lets each one draw into a specific subplot. But when I call these functions from a final function, intended to apply them to each column of the dataframe, only the first histogram is drawn and the rest of the plots are empty.

I'm not sure where I am going wrong. The code for the functions is below. Note that no errors are raised:

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

#Histogram for normality
def normal_dist_hist(data, ax):
    #Format data for plotting
    #Included ax for subplot coordinate
    if data.isnull().values.any() == True:
        data.dropna(inplace=True)
    if data.dtypes == 'float64':
        data.astype('int64')
    #Plot distribution with Gaussian overlay
    mu, std = stats.norm.fit(data)
    ax.hist(data, bins=50, density=True, alpha=0.6, color='g')
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    p = stats.norm.pdf(x, mu, std)
    ax.plot(x, p, 'k', linewidth=2)
    title = "Fit results: mu = %.2f,  std = %.2f" % (mu, std)
    ax.set_title(title)
    plt.show()

#Probability plot
def normal_test_QQplots(data, ax):
    #added ax argument for specifying subplot coordinate, 
    data.dropna(inplace=True)
    probplt = stats.probplot(data,dist='norm',fit=True,plot=ax)
    plt.show()

def normality_report(df):
    fig, axes = plt.subplots(nrows=len(df.columns), ncols=2,figsize=(12,50))
    ax_y = 0
    for col in df.columns[1:]:
        ax_x = 0
        normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
        ax_x = 1
        normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
        ax_y += 1

Answer:

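Remove the plt.show() calls from your functions normal_dist_hist(...) and normal_test_QQplots(...), and add a single plt.show() at the end of normality_report(...). With plt.show() inside the helpers, the figure is displayed as soon as the first histogram has been drawn; in a notebook with the inline backend, for example, anything drawn on the remaining axes after that point does not appear in the output.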

def normal_dist_hist(data, ax):
    ...
    plt.show() # Remove this

#Probability plot
def normal_test_QQplots(data, ax):
    ...
    plt.show() # Remove this

def normality_report(df):
    ...
    for col in df.columns[1:]:
        ax_x = 0
        normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
        ax_x = 1
        normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
        ax_y += 1
    plt.show() # Add it here.
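
For reference, here is a minimal end-to-end sketch of the same idea, with the helpers only drawing on the axes they are given and plt.show() called once at the end. It reuses the function names from the question; the `show` parameter, the `enumerate` loop in place of the manual ax_y/ax_x counters, the `squeeze=False` argument, and the figure size are my own assumptions, not part of the original code. It also uses data.dropna() rather than dropna(inplace=True), so the dataframe passed in is not modified.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats


def normal_dist_hist(data, ax):
    """Draw a density histogram with a fitted Gaussian overlay on `ax`."""
    data = data.dropna()  # work on a copy instead of mutating the dataframe
    mu, std = stats.norm.fit(data)
    ax.hist(data, bins=50, density=True, alpha=0.6, color="g")
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    ax.plot(x, stats.norm.pdf(x, mu, std), "k", linewidth=2)
    ax.set_title("Fit results: mu = %.2f,  std = %.2f" % (mu, std))
    # no plt.show() here: the helper only draws on the axes it was given


def normal_test_QQplots(data, ax):
    """Draw a normal probability plot on `ax`."""
    stats.probplot(data.dropna(), dist="norm", fit=True, plot=ax)


def normality_report(df, show=True):
    # `show` is a hypothetical switch added for this sketch
    cols = df.columns[1:]  # skips the first column, as in the question
    # squeeze=False keeps `axes` two-dimensional even when there is one row
    fig, axes = plt.subplots(nrows=len(cols), ncols=2,
                             figsize=(12, 4 * len(cols)), squeeze=False)
    for row, col in enumerate(cols):
        normal_dist_hist(df[col], ax=axes[row, 0])
        normal_test_QQplots(df[col], ax=axes[row, 1])
    fig.tight_layout()
    if show:
        plt.show()  # called once, after every subplot has been drawn
    return fig
```

Calling normality_report(df) on a dataframe whose first column is an index or identifier should then give one row of plots per remaining column. Note that the number of rows is taken from df.columns[1:], so the grid has no unused row at the bottom.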