Multi-index for looping sub plotting data in python

Time:09-26

I am trying to compare the yearly P/E (P/E_x) of selected stocks with their industry mean P/E (P/E_y).

I feel like the best way to analyze this is to visualize the data with line subplots or individual line graphs.

Original DataFrame:

Input:
stock_merge = pd.read_csv("industry mean.csv")
stock_merge

Output:
    Ticker  Year    Industry            P/E_x       P/E_y
0   NVDA    2019    Semiconductors      20.616292   15.79
1   NVDA    2020    Semiconductors      53.349938   15.79
2   NVDA    2021    Semiconductors      76.028282   15.79
3   NVDA    2022    Semiconductors      62.528408   15.79
4   AVGO    2018    Semiconductors      6.287096    15.79
5   AVGO    2019    Semiconductors      40.731857   15.79
6   AVGO    2020    Semiconductors      45.212246   15.79
7   AVGO    2021    Semiconductors      30.819690   15.79
...                ...              ...         ...
400 EFX     2018    Consulting Services 35.487911   35.56
401 EFX     2019    Consulting Services -43.694808  35.56
402 EFX     2020    Consulting Services 44.853370   35.56
403 EFX     2021    Consulting Services 47.910847   35.56
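To experiment without the CSV, the rows printed above can be typed into a small DataFrame. This is only a hypothetical stand-in built from the visible excerpt; the rows hidden behind "..." are not reconstructed.

```python
import pandas as pd

# Hypothetical stand-in for "industry mean.csv": only the rows printed above,
# typed in by hand (the rows hidden behind "..." are not reconstructed)
stock_merge = pd.DataFrame({
    "Ticker": ["NVDA"] * 4 + ["AVGO"] * 4 + ["EFX"] * 4,
    "Year": [2019, 2020, 2021, 2022, 2018, 2019, 2020, 2021,
             2018, 2019, 2020, 2021],
    "Industry": ["Semiconductors"] * 8 + ["Consulting Services"] * 4,
    "P/E_x": [20.616292, 53.349938, 76.028282, 62.528408,
              6.287096, 40.731857, 45.212246, 30.819690,
              35.487911, -43.694808, 44.853370, 47.910847],
    "P/E_y": [15.79] * 8 + [35.56] * 4,
})

# In this excerpt P/E_y holds a single value per industry
print(stock_merge.groupby("Industry")["P/E_y"].unique())
```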

I tried using .groupby() to loop over each industry in the DataFrame and then plot it.

all = stock_merge.groupby(['Industry', 'Ticker', 'Year']).mean()
all

Output:

                                            P/E_x       P/E_y
Industry                 Ticker   Year      
Aerospace & Defense      BA       2018      17.606935   26.44
                                  2019     -299.239806  26.44
                                  2020     -10.595709   26.44
                                  2021     -28.156965   26.44
                         HII      2018      8.511068    26.44
...                     ...                 ...         ...

Travel Services          RCL      2021     -3.724618    62.84
Trucking                 ODFL     2018     15.484180    12.30
                                  2019     23.518331    12.30
                                  2020     33.315306    12.30
                                  2021     39.847872    12.30
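In the grouped result above, Industry, Ticker and Year are no longer columns but levels of a MultiIndex, so column-style access to them fails. A small check on a hypothetical four-row excerpt shows the difference and two ways around it: reading the index level directly, or calling `reset_index()` to turn the keys back into columns. The sketch names the result `grouped` rather than `all`, which avoids shadowing Python's built-in `all()`.

```python
import pandas as pd

# Small hypothetical excerpt of the question's data
stock_merge = pd.DataFrame({
    "Ticker": ["NVDA", "NVDA", "EFX", "EFX"],
    "Year": [2019, 2020, 2018, 2019],
    "Industry": ["Semiconductors", "Semiconductors",
                 "Consulting Services", "Consulting Services"],
    "P/E_x": [20.616292, 53.349938, 35.487911, -43.694808],
    "P/E_y": [15.79, 15.79, 35.56, 35.56],
})

grouped = stock_merge.groupby(["Industry", "Ticker", "Year"]).mean()

# The three keys moved into the index, so they are not columns any more
print(list(grouped.columns))      # ['P/E_x', 'P/E_y']
print(list(grouped.index.names))  # ['Industry', 'Ticker', 'Year']
try:
    grouped["Industry"]
except KeyError:
    print("KeyError: 'Industry' is an index level, not a column")

# Option 1: read the industries from the index level
industries = grouped.index.get_level_values("Industry").unique()
print(list(industries))

# Option 2: move the keys back into ordinary columns
flat = grouped.reset_index()
print(list(flat.columns))
```

Passing `as_index=False` to `groupby` is another way to keep the keys as columns from the start.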

This is what I tried:

all_industries = all['Industry'].unique()
feature = enumerate(all_industries)

plt.figure(figsize = (30,20))

for i in enumerate(feature):

  plt.subplot(6, 3, i[0] + 1)
  sns.lineplot(x='P/E_y', y=i[1], hue = 'Ticker', data=all)

I received nothing but errors and empty subplots.
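Part of the trouble is the loop header. A quick check with two hypothetical industry names shows what `for i in enumerate(feature)` iterates over when `feature` is already an `enumerate` object:

```python
# Hypothetical industry names standing in for all_industries
all_industries = ["Semiconductors", "Trucking"]

# Wrapping enumerate() twice gives (position, (position, name)) pairs
nested = list(enumerate(enumerate(all_industries)))
print(nested)  # [(0, (0, 'Semiconductors')), (1, (1, 'Trucking'))]

# So i[1] is a tuple, not a column name, and y=i[1] cannot select a column.
# A single enumerate with start=1 yields 1-based subplot positions directly.
fixed = list(enumerate(all_industries, start=1))
print(fixed)  # [(1, 'Semiconductors'), (2, 'Trucking')]
```

Even with a single `enumerate`, the industry name is not a column of the data, so it would not work as `y`. For the plot described below, the columns to plot are `Year` on the x-axis and `P/E_x` on the y-axis, using only the rows of the current industry.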

This is what I'm trying to achieve:

  1. There should be one plot per industry (preferably as subplots of a single figure).
  2. Each plot should contain every ticker in that industry, with the year on the x-axis and the P/E on the y-axis.

For example, if there are 5 stocks in the Semiconductors industry, that line graph should show 6 lines: one line per stock showing its P/E_x from 2018 to 2022, plus 1 line for P/E_y.
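The target plot corresponds to one Year-by-Ticker table per industry: one P/E_x column per stock plus a P/E_y column. A minimal sketch using `DataFrame.pivot`, assuming each (Year, Ticker) pair occurs at most once per industry; `industry_table` is a hypothetical helper name and the data is the Semiconductors excerpt from the question.

```python
import pandas as pd

# Hypothetical excerpt of the question's data (Semiconductors rows only)
stock_merge = pd.DataFrame({
    "Ticker": ["NVDA"] * 4 + ["AVGO"] * 4,
    "Year": [2019, 2020, 2021, 2022, 2018, 2019, 2020, 2021],
    "Industry": ["Semiconductors"] * 8,
    "P/E_x": [20.616292, 53.349938, 76.028282, 62.528408,
              6.287096, 40.731857, 45.212246, 30.819690],
    "P/E_y": [15.79] * 8,
})


def industry_table(df, industry):
    """Hypothetical helper: Year-by-Ticker P/E_x table for one industry plus its P/E_y."""
    sub = df[df["Industry"] == industry]
    # One column per ticker; years a ticker lacks become NaN
    wide = sub.pivot(index="Year", columns="Ticker", values="P/E_x")
    # The industry mean is a single value per industry in this data
    wide["P/E_y"] = sub["P/E_y"].iat[0]
    return wide


table = industry_table(stock_merge, "Semiconductors")
print(table)
# table.plot() would draw one line per column: one per ticker, plus P/E_y
```

If the real data can contain duplicate (Year, Ticker) rows, `pivot` raises an error and `pivot_table` with an aggregation would be needed instead.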

How do I plot this data?

Answer:

I think you are on the right track with .groupby().

You may find it easier to iterate through the grouped object using its .groups property, which returns a dictionary mapping each group name to the index labels of that group's rows, such as

{'Consulting Services': [8, 9], 'Semiconductors': [0, 1, 2, 3, 4, 5, 6, 7]}
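That dictionary can be reproduced with a hypothetical 10-row frame: the eight Semiconductors rows from the question followed by two EFX rows. `.groups` returns a dict-like object whose values are pandas `Index` objects; they are converted to plain lists here only to make the result easy to read.

```python
import pandas as pd

# Hypothetical 10-row frame: the question's eight Semiconductors rows
# followed by two EFX (Consulting Services) rows
stock_merge = pd.DataFrame({
    "Ticker": ["NVDA"] * 4 + ["AVGO"] * 4 + ["EFX"] * 2,
    "Year": [2019, 2020, 2021, 2022, 2018, 2019, 2020, 2021, 2018, 2019],
    "Industry": ["Semiconductors"] * 8 + ["Consulting Services"] * 2,
    "P/E_x": [20.616292, 53.349938, 76.028282, 62.528408,
              6.287096, 40.731857, 45.212246, 30.819690,
              35.487911, -43.694808],
    "P/E_y": [15.79] * 8 + [35.56] * 2,
})

industries = stock_merge.groupby("Industry")
# .groups maps each group key to the index labels of that group's rows
groups = {name: labels.tolist() for name, labels in industries.groups.items()}
print(groups)
```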

By iterating over the keys of this dictionary, you should be able to build one plot per industry. Within each iteration you can split the industry's rows further, one group per ticker, to get the individual lines. Hopefully the following snippet will help to put you on the right track:

import pandas as pd
from IPython.display import display  # built in to Jupyter; imported here for plain scripts

stock_merge = pd.read_csv("data.csv")
industries = stock_merge.groupby("Industry")

for k in industries.groups:
    industry = industries.get_group(k)  # all rows for industry k
    p_ey = industry["P/E_y"].iat[0]  # industry mean, repeated on every row
    print(k, p_ey)
    # Create subplot, add P/E_y line
    tickers = industry.groupby("Ticker")
    for t in tickers.groups:
        ticker_data = tickers.get_group(t)  # all years for ticker t
        # Use ticker_data to generate an individual line on the plot
        display(ticker_data)
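Filling in the two placeholder comments, one possible version draws every industry on its own panel of a single figure. It is a sketch under assumptions: it uses plain matplotlib (`plt.subplots`, `ax.plot`, `ax.axhline`) instead of seaborn, a three-column grid like the `plt.subplot(6, 3, ...)` call in the question, and a small typed-in excerpt of the data in place of the CSV. Because P/E_y is constant within an industry in this data, it is drawn as a horizontal reference line.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs as a script; drop this in Jupyter
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import MaxNLocator

# Hypothetical excerpt standing in for the CSV (rows from the question only)
stock_merge = pd.DataFrame({
    "Ticker": ["NVDA"] * 4 + ["AVGO"] * 4 + ["EFX"] * 4,
    "Year": [2019, 2020, 2021, 2022, 2018, 2019, 2020, 2021,
             2018, 2019, 2020, 2021],
    "Industry": ["Semiconductors"] * 8 + ["Consulting Services"] * 4,
    "P/E_x": [20.616292, 53.349938, 76.028282, 62.528408,
              6.287096, 40.731857, 45.212246, 30.819690,
              35.487911, -43.694808, 44.853370, 47.910847],
    "P/E_y": [15.79] * 8 + [35.56] * 4,
})

industries = stock_merge.groupby("Industry")
n = len(industries.groups)
ncols = 3
nrows = math.ceil(n / ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False)
axes = axes.ravel()  # flatten the grid so it can be zipped with the industries

for ax, k in zip(axes, industries.groups):
    industry = industries.get_group(k)
    p_ey = industry["P/E_y"].iat[0]
    tickers = industry.groupby("Ticker")
    for t in tickers.groups:
        ticker_data = tickers.get_group(t).sort_values("Year")
        # One line per stock: its P/E_x by year
        ax.plot(ticker_data["Year"], ticker_data["P/E_x"], marker="o", label=t)
    # One more line for the industry mean
    ax.axhline(p_ey, color="black", linestyle="--", label="P/E_y")
    ax.set_title(k)
    ax.set_xlabel("Year")
    ax.set_ylabel("P/E")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole years only
    ax.legend()

# Hide grid cells left over when n is not a multiple of ncols
for ax in axes[n:]:
    ax.set_visible(False)

fig.tight_layout()
# fig.savefig("pe_by_industry.png")  # hypothetical output filename
```

With the full data, `nrows` grows with the number of industries, so the grid does not have to be hard-coded as 6 by 3.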