Home > Mobile >  How to get rid of the double printed axis in matplotlib subplot
How to get rid of the double printed axis in matplotlib subplot


I'm trying to plot the top items in subplots, but with the code I used I get double-printed axes. How can I prevent that from happening? Thanks for your help.

Below you can see the code and the resulting graph:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Build a random sample data set: n sales records, each with a category, an
# item, a net-sales amount and a distinct date in 2021.
# NOTE(review): the original snippet never defines n; 100 is an illustrative
# value only. It must not exceed 365, because the dates are drawn without
# replacement from a 365-day range.
n = 100
Cat = ['A', 'B', 'C']
It = ['D', 'E', 'F', 'G', 'H', 'I', 'J']
df = pd.DataFrame({
    'Category': np.random.choice(Cat, n),
    'Item': np.random.choice(It, n),
    'Net sales': np.random.randint(100, 500, n),
    'Date': np.random.choice(
        pd.date_range('1/1/2021', periods=365, freq='D'), n, replace=False),
})

# Grouping products by sales. The column is selected before summing so that
# only 'Net sales' is aggregated (not the 'Category' and 'Date' columns).
prod_sales = df.groupby('Item')[['Net sales']].sum()

# Sorting the dataframe in descending order
prod_sales.sort_values(by=['Net sales'], inplace=True, ascending=False)

# fig, ax = plt.subplots(figsize=(20,10))
# NOTE: plt.subplots(5, 2) already creates a 5x2 grid of axes on the figure.
# The fig.add_subplot(2, 2, i) calls below then add further axes to the same
# figure, on top of that grid -- this overlay is the "double printed axis"
# the question asks about.
fig, ax = plt.subplots(5, 2, figsize=(20, 10))
i = 0  # 1-based position of the next subplot in the 2x2 layout
for section, group in df.groupby('Item'):
    # Only plot the four best-selling items.
    if any(item in section for item in prod_sales[:4].index):
        i = i + 1
        ax = fig.add_subplot(2, 2, i)
        group.plot(x='Date', y='Net sales', ax=ax, label=section)

enter image description here

CodePudding user response:

This is my solution, not the most tidy, many rows, but does the job. Could be further improved or shortened:

Tot = df['Item'].nunique()  # number of subplots
Cols = 3  # number of columns in the subplot grid

# Number of grid rows needed to hold Tot subplots: ceiling of Tot / Cols.
# One extra row is added only when the last row is partially filled.
Rows = Tot // Cols
if Tot % Cols:
    Rows += 1
Position = range(1, Tot + 1)  # add_subplot positions are 1-based

# Create main figure
df = df.sort_values('Date')  # so each line is drawn in chronological order
fig = plt.figure(1, figsize=(20, 8))
for k, item in enumerate(df['Item'].unique()):
    # Filter once per item instead of once per plotted column.
    item_rows = df[df['Item'] == item]
    ax = fig.add_subplot(Rows, Cols, Position[k])
    ax.plot(item_rows['Date'], item_rows['Net sales'])


enter image description here

CodePudding user response:

enter image description here

I worked on the input above and used the code below on my real data, and it works for me. Thanks for all the input.

# NOTE(review): concatenated_df (the author's real data) and n (presumably the
# number of top categories to plot, an int) are not defined in this snippet --
# they must be supplied by the surrounding code.

# Net sales per category, resampled into monthly bins and ordered by date.
df = (
    concatenated_df
    .groupby('Category')
    .resample('M', label='right', closed='left', on='Date')
    .sum()
    .reset_index()
    .sort_values(by='Date')
)

# Total net sales per category, largest first. The column is selected before
# summing so that only 'Net sales' is aggregated (not the 'Date' column).
top_cat = df.groupby('Category')['Net sales'].sum().nlargest(30)

# Two plots per row. Ceiling division guarantees enough axes for an odd n;
# round(n / 2) rounds halves to the nearest even number (round(2.5) == 2), which
# would leave one category without an axis and silently drop it in zip().
n_rows = (n + 1) // 2
fig, axs = plt.subplots(n_rows, 2, figsize=(20, n * 2))
fig.subplots_adjust(hspace=0.5, wspace=0.1)
axs = axs.ravel()  # flatten the 2-D grid so it can be zipped with the groups

# for ax, (section, group) in zip(axs, df.groupby('Category')):
selected = df[df.Category.isin(top_cat[:n].index)]
for ax, (section, group) in zip(axs, selected.groupby('Category')):
    group.plot(x='Date', y='Net sales', ax=ax, label=section)
    ax.set_title(section)
    ax.set_ylim([0, 120000])  # common y-range so the subplots are comparable

  • Related