How to generate proper legends for scatter plot in python

Time:08-24

I am trying to prepare a combined box plot and scatter plot for 8 data points in Python. I use the following code:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

x = [24.4, 6.7, 19.7, 16.0, 25.1, 19.5, 10, 22.1]

f, ax = plt.subplots()
ax.boxplot(x, vert=False, showmeans=True, showfliers=False)
x0 = np.random.normal(1, 0.05, len(x))
c = ['r', 'b', 'c', 'm', 'y', 'g', 'm', 'k']
lab = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
ax.scatter(x, x0, c=c, s=60, alpha=0.2)
ax.legend(labels=lab, loc="upper left", ncol=8)

It generates an image like the following: [image]

The legend does not show the circular markers in different colors that I expected. In addition, the marker colors in the plot look faint and washed out.

So how can I generate a proper legend with the correct symbols, and how can I make the marker colors brighter and sharper?

I will deeply appreciate it if anyone can help.

Best regards

CodePudding user response:

To make the colours brighter, just raise the alpha value. For the legend, the order of plotting matters: it is better to draw the boxplot after the scatter points. Also, for each point to get its own entry in the legend, it has to be plotted as a separate scatter call, so I loop over the values of x, x0 and c. Here's the result:

import numpy as np
import matplotlib.pyplot as plt

# init figure
f, ax = plt.subplots()

# values
x = [24.4, 6.7, 19.7, 16.0, 25.1, 19.5, 10, 22.1]
x0 = np.random.normal(1, 0.05, len(x))

# labels and colours
c = ['r', 'b', 'c', 'm', 'y', 'g', 'm', 'k']
lab = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']

# put the plots into a list
plots = []
for i in range(len(x)):
    p = ax.scatter(x[i], x0[i], c=c[i], s=60, alpha=0.5) # raised the alpha to get sharper colors
    plots.append(p)

# plot legends (handles and labels are both passed by keyword)
ax.legend(handles=plots,
          labels=lab,
          scatterpoints=1,
          loc='upper left',
          ncol=8,
          fontsize=8)

# plot the box plot after the scatter points
ax.boxplot(x, vert=False, showmeans=True, showfliers=False)

# save the desired figure
plt.savefig('tt.png')

Output:

plot img
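
As an alternative to looping over the points, the legend can also be built from proxy handles while keeping a single scatter call. This is only a sketch and not part of the answer above: the Agg backend, the fixed random seed, the alpha of 0.8 and the names `handles` and `legend_labels` are assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# same values, colours and labels as in the question
x = [24.4, 6.7, 19.7, 16.0, 25.1, 19.5, 10, 22.1]
c = ['r', 'b', 'c', 'm', 'y', 'g', 'm', 'k']
lab = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']

# fixed seed (an assumption) so the vertical jitter is reproducible
rng = np.random.default_rng(0)
x0 = rng.normal(1, 0.05, len(x))

f, ax = plt.subplots()
ax.boxplot(x, vert=False, showmeans=True, showfliers=False)

# one scatter call for all points; a higher alpha gives stronger colours
ax.scatter(x, x0, c=c, s=60, alpha=0.8)

# proxy handles: empty lines that only carry a marker, a colour and a label
handles = [
    Line2D([], [], marker='o', linestyle='None', color=col,
           markersize=8, alpha=0.8, label=name)
    for col, name in zip(c, lab)
]

# explicit handles mean the legend does not depend on plotting order
legend = ax.legend(handles=handles, loc='upper left', ncol=8, fontsize=8)
legend_labels = [t.get_text() for t in legend.get_texts()]
```

Because the handles are passed explicitly, the boxplot can be drawn before or after the scatter points. Calling `f.savefig(...)` afterwards writes the figure to a file, as in the answer above.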
