I'm trying to plot a confusion matrix that shows one decimal place for every value, but where a value is exactly zero I would like it displayed as 0 rather than 0.0. How can I achieve this? In this minimal example, I would like to keep 4.0 but show 0 instead of 0.0:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_true, y_pred),
                              display_labels=[0, 1])
disp.plot(include_values=True, cmap='Blues', xticks_rotation='vertical',
          values_format='.1f', ax=None, colorbar=True)
plt.show()
I tried to pass a custom function
def custom_formatter(value):
    if value == 0:
        return '{:.0f}'.format(value)
    else:
        return '{:.1f}'.format(value)
to
disp.plot(include_values=True, cmap='Blues', xticks_rotation='vertical',
          values_format=custom_formatter, ax=None, colorbar=True)
but this gives the error
TypeError: format() argument 2 must be str, not function
Thank you so much!
CodePudding user response:
I think the easiest way is to loop over the text objects that disp.plot() creates and replace each "0.0" label with "0":
import numpy as np

# Run this after disp.plot(...) and before plt.show()
for iy, ix in np.ndindex(disp.text_.shape):
    txt = disp.text_[iy, ix]
    if txt.get_text() == "0.0":
        txt.set_text("0")
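
For completeness, here is a minimal end-to-end sketch of the same idea, with one variation: instead of matching the rendered string "0.0", it relabels each cell from the underlying matrix value using a simplified version of the question's custom_formatter. The Agg backend line is only an assumption so the script runs without a display; drop it for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove for interactive plotting
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def custom_formatter(value):
    # Exact zeros print as "0"; everything else keeps one decimal place.
    return "0" if value == 0 else "{:.1f}".format(value)


y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred)

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
# values_format must be a format-spec string, so plot with '.1f' first.
disp.plot(include_values=True, cmap="Blues", xticks_rotation="vertical",
          values_format=".1f", colorbar=True)

# disp.text_ is an array of matplotlib Text objects, one per cell.
# Overwrite each label from the matrix value rather than the rendered string.
for iy, ix in np.ndindex(disp.text_.shape):
    disp.text_[iy, ix].set_text(custom_formatter(cm[iy, ix]))

labels = [[t.get_text() for t in row] for row in disp.text_]
print(labels)
```

Call plt.show() (or plt.savefig) after the loop. Checking the value instead of the string means a small non-zero entry that rounds to 0.0, for example in a normalized matrix, keeps its decimal.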