Home > Mobile > How to get rid of the double-printed axis in a matplotlib subplot
How to get rid of the double-printed axis in a matplotlib subplot

Time:10-23

I'm trying to plot the top items in subplots, but with the code I used I get a double-printed axis. How can I prevent that from happening? Thanks for your help.

Below you can see the code and the resulting graph:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

Cat = ['A', 'B', 'C']
It = ['D', 'E', 'F', 'G', 'H', 'I', 'J']
n = 365

# Synthetic sample data: n rows, each with a random category, a random item,
# a random sales value and a distinct day drawn from 2021.
df = pd.DataFrame({'Category': np.random.choice(Cat, n),
                   'Item': np.random.choice(It, n),
                   'Net sales': np.random.randint(100, 500, n),
                   'Date': np.random.choice(pd.date_range('1/1/2021', periods=365,
                                                          freq='D'),
                                            n, replace=False)})

# Grouping products by sales. Select 'Net sales' before summing so that the
# datetime 'Date' column is not summed too (recent pandas raises on that).
prod_sales = df.groupby('Item')[['Net sales']].sum()

# Sorting the dataframe in descending order
prod_sales.sort_values(by=['Net sales'], inplace=True, ascending=False)

# Names of the four best-selling items; exact membership test below
# (a substring test would also match e.g. 'AB' against 'A').
top_items = prod_sales[:4].index

# NOTE: plt.subplots(5, 2) already creates a 5x2 grid of axes, and
# fig.add_subplot(2, 2, i) below draws a second set of axes on top of it.
# The two overlapping grids are what appear as the doubled axis.
fig, ax = plt.subplots(5, 2, figsize=(20, 10))
i = 0
for section, group in df.groupby('Item'):
    if section in top_items:
        i += 1
        ax = fig.add_subplot(2, 2, i)
        group.plot(x='Date', y='Net sales', ax=ax, label=section)

[Image: resulting graph]

User response:

This is my solution. It is not the tidiest and takes quite a few lines, but it does the job; it could be further improved or shortened:

import math

Tot = df['Item'].nunique()  # number of subplots (one per item)
Cols = 3  # number of columns in the subplot grid

# Enough rows to hold every subplot: ceil(Tot / Cols).
# Tot // Cols + Tot % Cols is not a ceiling and over-counts rows whenever
# Tot % Cols > 1 (e.g. Tot=8, Cols=3 gives 4 rows instead of 3).
Rows = math.ceil(Tot / Cols)

# Sort by date so each item's line is drawn in chronological order.
df = df.sort_values('Date')

# Create main figure and add one subplot per item; subplot positions are
# 1-based, hence start=1.
fig = plt.figure(1, figsize=(20, 8))
for position, item in enumerate(df['Item'].unique(), start=1):
    item_rows = df[df['Item'] == item]  # build the row mask once per item
    ax = fig.add_subplot(Rows, Cols, position)
    ax.set_title(item)
    ax.plot(item_rows['Date'], item_rows['Net sales'])

plt.show()

[Image: resulting graph]

User response:

[Image: resulting graph] I worked from the input above and used the code below on my real data, and it works for me. Thanks for all the input!

import math

# Monthly net sales per category.
# Assumes concatenated_df (defined elsewhere) has 'Category', a datetime
# 'Date' and 'Net sales' columns -- TODO confirm against the real data.
# With month-end ('M') bins, closed='right' gives calendar months
# (Feb 1 .. Feb 28, labelled Feb 28); closed='left' would move each
# month's last day into the following month's bin.
df = (concatenated_df
      .groupby('Category')
      .resample('M', label='right', closed='right', on='Date')
      .sum()
      .reset_index()
      .sort_values(by='Date'))

n = 10  # number of top categories to plot
y_max = 120000  # shared y-axis upper limit so the panels are comparable

# Total net sales per category, keeping the n largest. Summing only the
# 'Net sales' column avoids trying to sum non-numeric columns such as 'Date'.
top_cat = df.groupby('Category')['Net sales'].sum().nlargest(n)

# Two columns with enough rows for n panels. math.ceil rather than round():
# round() uses banker's rounding (round(2.5) == 2), so an odd n would
# silently lose its last panel.
fig, axs = plt.subplots(math.ceil(n / 2), 2, figsize=(20, n * 2))
fig.subplots_adjust(hspace=0.5, wspace=0.1)
axs = axs.ravel()

top_groups = df[df['Category'].isin(top_cat.index)].groupby('Category')
for ax, (section, group) in zip(axs, top_groups):
    group.plot(x='Date', y='Net sales', ax=ax, label=section)
    ax.set_title(section)
    ax.set_ylim([0, y_max])

# Hide grid cells left empty (odd n, or fewer than n categories in the data).
for ax in axs[top_cat.size:]:
    ax.set_visible(False)

  • Related