How to get a distinct label for each group of scatter points in a graph using matplotlib in Python?

Time:07-14

I am using a for loop to assign labels and colors to points. However, I am not getting one legend entry per color. Instead, I get a long run of labels at the top of my graph, because my Excel data has about 500 rows. Here is my code:

for val in ghi:
    if val < 2:
        col.append('navy')
        lab.append(' <2 ')
    elif (val >= 2) and (val < 4):
        col.append('lightblue')
        lab.append(' 2~4 ')
    elif (val >= 4) and (val < 6):
        col.append('orange')
        lab.append(' 4~6 ')
    else:
        col.append('brown')
        lab.append(' >6 ')
plt.scatter(x, y, marker = "D", s = 10, color = col, label = lab)
sns.lineplot(x, y, color ="red")
plt.legend()

CodePudding user response:

The reason you are seeing so many labels is label=lab: that list holds one entry per row. Instead, pass explicit handles and labels to plt.legend(), so that the legend has just four entries, one for each branch of the if/else statement. To keep the colors and labels in a consistent order, the code below also sorts the data by the ghi field.

As no data was provided, the first few lines create random data. Note that I have put the data into a dataframe and sorted it by the ghi field. The legend is then built from handles and labels that you can customize as you need. I hope this is what you are trying to achieve.

Otherwise, I have tried to keep the same names and as much of your code as is, so that it is easier for you to interpret.

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.lines import Line2D 

ghi = np.random.randint(11, size=(20))
x = np.random.rand(20)
y = np.random.rand(20)
col = []
for val in ghi:
    if val < 2:
        col.append('navy')
    elif (val >= 2) and (val < 4):
        col.append('lightblue')
    elif (val >= 4) and (val < 6):
        col.append('orange')
    else:
        col.append('brown')

# Collect everything in a DataFrame and sort by ghi so the rows are grouped by bin
df = pd.DataFrame({'x': x, 'y': y, 'ghi': ghi, 'col': col})
df.sort_values(by=['ghi'], inplace=True)
# One color per point; no label here, the legend is built by hand below
plt.scatter(df.x, df.y, marker="D", s=50, color=df.col)
# Recent seaborn versions require x and y as keyword arguments
sns.lineplot(x=x, y=y, color="red")
# One proxy handle per bin: a diamond marker with no connecting line
myHandles = [Line2D([], [], marker='D', color='navy', linestyle='None'),
             Line2D([], [], marker='D', color='lightblue', linestyle='None'),
             Line2D([], [], marker='D', color='orange', linestyle='None'),
             Line2D([], [], marker='D', color='brown', linestyle='None')]

plt.legend(handles=myHandles, labels=[' <2 ', ' 2~4 ', ' 4~6 ', ' >6 '])

[Figure: output plot]
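
A different way to get one legend entry per color, not the approach used in the answer above, is to call scatter once per bin and pass a single label string each time. The sketch below is only an illustration under assumptions: the names BINS, classify and scatter_by_bin are hypothetical, the bin table simply mirrors the if/elif chain from the question, and the function expects a matplotlib Axes (or any object with a compatible scatter method) as ax.

```python
# Hypothetical bin table: (upper bound, color, label), mirroring the if/elif chain.
BINS = [
    (2, "navy", " <2 "),
    (4, "lightblue", " 2~4 "),
    (6, "orange", " 4~6 "),
    (float("inf"), "brown", " >6 "),
]


def classify(val):
    """Return the (color, label) pair of the first bin whose upper bound exceeds val."""
    for upper, color, label in BINS:
        if val < upper:
            return color, label
    # Only reached for NaN, where every comparison is False
    raise ValueError(f"cannot classify {val!r}")


def scatter_by_bin(ax, x, y, ghi):
    """Draw one scatter call per non-empty bin so each bin gets one legend entry."""
    for _, color, label in BINS:
        # Positions of the points that fall into this bin
        idx = [i for i, v in enumerate(ghi) if classify(v) == (color, label)]
        if not idx:
            continue  # skip empty bins so they do not appear in the legend
        ax.scatter([x[i] for i in idx], [y[i] for i in idx],
                   marker="D", s=50, color=color, label=label)
```

With a real figure, this could be used as: fig, ax = plt.subplots(), then scatter_by_bin(ax, x, y, ghi), then ax.legend(). Because each scatter call carries exactly one label, the legend needs no hand-built handles, and keeping color and label in one table means they cannot drift out of sync.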