Python: adjust the color scale intensity in a confusion matrix plot

Time:12-12

I have a number of confusion matrix plots whose numbers do not sum to the same total (the numbers are out of 100 for a benchmark); see the attached example image of a confusion matrix. I do not want 22 and 32 to get the same color intensity; I want every plot on the same scale, with intensity levels from 0 to 100. How can I adjust the scale in Python, given the following code:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# y_true and y_pred hold the benchmark labels
cm = confusion_matrix(y_true, y_pred, labels=["Up", "Down"])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Up", "Down"])
disp.plot(cmap="OrRd")

Answer:

After plotting, fix the color limits of the image to 0..100:

disp.ax_.get_images()[0].set_clim(0, 100)
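A minimal, self-contained sketch of why this works, assuming two made-up benchmark matrices (the numbers are hypothetical, chosen so the largest cells are 22 and 32). By default the image autoscales its color limits to each matrix's own minimum and maximum, so 22 in one plot and 32 in the other both get the darkest color; `set_clim(0, 100)` pins every plot to the same 0 to 100 range. `disp.im_` is the same image object as `disp.ax_.get_images()[0]`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

# Hypothetical benchmark results whose largest cells differ (22 vs 32)
cm_a = np.array([[22, 5], [3, 10]])
cm_b = np.array([[32, 4], [6, 18]])

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
displays = []
for cm, ax in zip([cm_a, cm_b], axes):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Up", "Down"])
    disp.plot(ax=ax, cmap="OrRd")
    # Pin the color limits so every plot shares the same 0..100 scale;
    # the color bar follows the image's norm and updates too
    disp.im_.set_clim(0, 100)
    displays.append(disp)

# With a shared scale, 22 and 32 map to different normalized intensities
norm_a = displays[0].im_.norm
norm_b = displays[1].im_.norm
print(norm_a(22), norm_b(32))
```

Recent scikit-learn releases also accept `disp.plot(cmap="OrRd", im_kw={"vmin": 0, "vmax": 100})`, which sets the limits when the image is created; check that your installed version's `plot()` has the `im_kw` parameter before relying on it.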

And the full code:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# y_true and y_pred hold your benchmark labels
cm = confusion_matrix(y_true, y_pred, labels=["Up", "Down"])

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Up", "Down"])
disp.plot(cmap="OrRd")

# Fix the color scale to 0..100 so equal counts get equal colors across plots
disp.ax_.get_images()[0].set_clim(0, 100)

plt.show()
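Since you have a number of these plots, a small helper keeps the scale consistent across all of them. This is a sketch only: `plot_benchmark_cm` is a hypothetical name, and the random `y_true`/`y_pred` labels stand in for your real benchmark data (each run has 100 labels, so each matrix sums to 100).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

LABELS = ["Up", "Down"]


def plot_benchmark_cm(y_true, y_pred, ax, vmax=100):
    """Hypothetical helper: plot one confusion matrix on a fixed 0..vmax color scale."""
    cm = confusion_matrix(y_true, y_pred, labels=LABELS)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS)
    disp.plot(ax=ax, cmap="OrRd")
    # Pin the color scale so every plot uses the same intensity levels
    disp.ax_.get_images()[0].set_clim(0, vmax)
    return disp


# Stand-in benchmark data: three runs of 100 random "Up"/"Down" labels each
rng = np.random.default_rng(0)
runs = [(rng.choice(LABELS, size=100), rng.choice(LABELS, size=100)) for _ in range(3)]

fig, axes = plt.subplots(1, len(runs), figsize=(12, 4))
displays = [plot_benchmark_cm(yt, yp, ax) for (yt, yp), ax in zip(runs, axes)]
# In an interactive session, call plt.show() here.
```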

Answer:

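Alternatively, draw the matrix yourself with matshow() and a fixed color range:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# y_true and y_pred hold your benchmark labels
cm = confusion_matrix(y_true, y_pred, labels=["Up", "Down"])

# Create a fresh figure and axes rather than drawing over disp.ax_
fig, ax = plt.subplots()

# Plot the confusion matrix with matshow(), fixing the color intensity scale to 0..100
im = ax.matshow(cm, cmap="OrRd", vmin=0, vmax=100)

# Add a color bar; colorbar() is a Figure method, not an Axes method
fig.colorbar(im, ax=ax)

# Show the plot
plt.show()

This plots the confusion matrix with the 'OrRd' colormap on a linear color scale fixed to 0 to 100, so the same count gets the same color in every plot, and adds a color bar showing how values map to colors. Unlike ConfusionMatrixDisplay, matshow() does not write the counts into the cells or label the axes, so you have to add those yourself.

If you prefer a logarithmic scale, pass the limits to the norm instead of alongside it, for example norm=matplotlib.colors.LogNorm(vmin=1, vmax=100) (this needs import matplotlib.colors). A log scale cannot start at 0, and recent Matplotlib versions raise an error when norm is combined with vmin/vmax.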
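The log-scale variant as a runnable sketch. It assumes a hypothetical hard-coded 2x2 matrix in place of your real `cm`. The limits go inside `LogNorm` itself, because recent Matplotlib versions reject `norm=` combined with `vmin`/`vmax`, and the scale starts at 1 because a log scale cannot include 0. Cells equal to 0 are masked by `LogNorm` and left blank.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in an interactive session
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

# Hypothetical benchmark matrix (rows: true Up/Down, columns: predicted Up/Down)
cm = np.array([[32, 0], [18, 50]])

fig, ax = plt.subplots()
# Put the limits inside the norm; LogNorm needs vmin > 0, so the scale starts at 1
im = ax.matshow(cm, cmap="OrRd", norm=LogNorm(vmin=1, vmax=100))
# The zero cell is masked by LogNorm and drawn with the colormap's "bad" color
fig.colorbar(im, ax=ax)

# Label the axes the way ConfusionMatrixDisplay would
ax.set_xticks([0, 1])
ax.set_xticklabels(["Up", "Down"])
ax.set_yticks([0, 1])
ax.set_yticklabels(["Up", "Down"])
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
```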