Choosing artists/handles for legend when plotting from xarray.DataArray

Time:02-18

I am plotting a facet plot from an xarray DataArray, with multiple lines on each subplot. I want a legend on each subplot, but it should contain only certain lines. ax.legend() expects the artists that should appear in the legend, but because I am plotting directly from an xarray DataArray, I don't know how to obtain them.

Here is an example:

import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
import pandas as pd
           
data1 = np.random.randn(4, 4,3)
loc= np.linspace(1,3,3)
type = ["a", "b", "c", "d"]
            
times = pd.date_range("2000-01-01", periods=4)
foo = xr.DataArray(data1, coords=[times, type, loc], dims=["time", "type","loc"])
    
t=foo.plot.line(x="time", col="loc", linewidth= 5, col_wrap=3)
for i, ax in enumerate(t.axes.flat):
    ax.legend(('labelc','labeld'))

Here, I would like the legend entries to actually correspond to the lines for types c and d.

ax.legend is supposed to work like this:

ax.legend([line1, line2, line3], ['label1', 'label2', 'label3'])
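As a minimal sketch of that call signature in plain matplotlib (no xarray involved; the names `lines`, `wanted` and `shown` are only illustrative): the handles passed to ax.legend must be artists, such as the Line2D objects returned by ax.plot, rather than the plotted data itself. Keeping those objects around makes it possible to list only a subset of them.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(4)
fig, ax = plt.subplots()

# ax.plot returns a list of Line2D artists; keep one per series name
lines = {name: ax.plot(x, np.random.randn(4))[0] for name in ["a", "b", "c", "d"]}

# pass only the wanted artists, together with matching labels
wanted = ["c", "d"]
legend = ax.legend([lines[name] for name in wanted], wanted)

# the legend now lists two of the four plotted lines
shown = [text.get_text() for text in legend.get_texts()]
```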

I have already tried the following, but could not get it to work:

for i, ax in enumerate(t.axes.flat):
    ax.legend(foo[:,2:,i],('labelc','labeld'))

Answer:

We can retrieve the figure-level legend entries and filter them for the desired entries:

import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
import pandas as pd
           
data1 = np.random.randn(4, 4, 3)
locs = np.linspace(1, 3, 3)
types = ["a", "b", "c", "d"]
            
times = pd.date_range("2000-01-01", periods=4)
foo = xr.DataArray(data1, coords=[times, types, locs], dims=["time", "type","loc"])
    
t=foo.plot.line(x="time", col="loc", linewidth= 5, col_wrap=3)

# labels of the lines that should appear in each subplot legend
to_keep = ["b", "d"]
# retrieve the figure-level labels and handles (private FacetGrid attributes)
labels = t._hue_var.to_numpy()
handles = t._mappables[-1]

# map each label to its handle, dropping unwanted entries
label_dic = {l: h for l, h in zip(labels, handles) if l in to_keep}

# add the filtered legend to every subplot
for ax in t.axes.flat:
    ax.legend(label_dic.values(), label_dic.keys(), loc="best")
plt.show()


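The code that retrieves the legend entries is adapted from the source of the FacetGrid method add_legend(). Note that _hue_var and _mappables are private attributes, so they may change between xarray versions.

Please also note that I renamed your variable type to types, because it shadowed Python's built-in type().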
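A possible variant that avoids the private attributes `_hue_var` and `_mappables` is to read the lines back from each subplot with matplotlib's `ax.get_lines()`. This is only a sketch, and it rests on an assumption: that each subplot draws its lines in the same order as the "type" coordinate, so that lines and labels can be paired by position. The names `hue_values` and `legend_labels` are illustrative; `add_legend=False` suppresses the figure-level legend, since the per-subplot legends replace it.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import numpy as np
import pandas as pd
import xarray as xr

data = np.random.randn(4, 4, 3)
times = pd.date_range("2000-01-01", periods=4)
types = ["a", "b", "c", "d"]
locs = np.linspace(1, 3, 3)
foo = xr.DataArray(data, coords=[times, types, locs], dims=["time", "type", "loc"])

t = foo.plot.line(x="time", col="loc", col_wrap=3, add_legend=False)

to_keep = ["c", "d"]
hue_values = [str(v) for v in foo["type"].values]

legend_labels = []
for ax in t.fig.axes:
    lines = ax.get_lines()
    # skip any axes that do not hold exactly one line per "type" value
    if len(lines) != len(hue_values):
        continue
    # ASSUMPTION: lines were drawn in the order of the "type" coordinate
    chosen = [(line, label) for line, label in zip(lines, hue_values) if label in to_keep]
    handles, labels = zip(*chosen)
    legend = ax.legend(handles, labels, loc="best")
    legend_labels.append([text.get_text() for text in legend.get_texts()])
```

If the drawing order ever differs from the coordinate order, the labels would be attached to the wrong lines, so it is worth checking the colours against the figure-level legend once before relying on this.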
