This is an example of code to plot a confusion matrix for a multi-class classification problem.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [1, 0, 4]])
labels = ['col1','col2','col3']
df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
df_confusion['TOTAL'] = df_confusion.sum(axis=1)
df_confusion.loc['TOTAL']= df_confusion.sum()
plt.figure(figsize=(24, 10))
sns.set(font_scale = 1.5)
ax = sns.heatmap(df_confusion, annot=True, cmap='Blues', fmt="d")
ax.set_title('Confusion Matrix\n\n',size=22)
ax.set_xlabel('\nPredicted Values',size=20)
ax.set_ylabel('Actual Values ', size=20)
plt.show()
How can I change the colorbar so that the color is based not on the number of elements in each cell, but on the percentage of elements in each cell relative to the total number of actual elements for that class (row)? For example, the third class, col3, would get the darkest color because 4/5 = 80% of its elements are predicted correctly, compared with 50/90 = 55.6% for col1 and 43/82 = 52.4% for col2.
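To make the target values concrete, here is a minimal sketch in plain Python that computes the share of correct predictions per actual class (diagonal cell divided by row total). The variable name `correct_share` is only illustrative.

```python
# Confusion matrix from the question: rows are actual classes, columns are predictions.
cf_matrix = [[50, 2, 38],
             [7, 43, 32],
             [1, 0, 4]]
labels = ['col1', 'col2', 'col3']

# Share of correct predictions per actual class = diagonal cell / row total.
correct_share = {
    name: row[i] / sum(row)
    for i, (name, row) in enumerate(zip(labels, cf_matrix))
}

for name, value in correct_share.items():
    print(f"{name}: {value:.1%}")
# col1: 55.6%
# col2: 52.4%
# col3: 80.0%
```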
Answer:
Since the cmap parameter uses data to apply the gradients, you need to change data to percentages, and then use the annot parameter to overwrite the displayed values with the actual numbers.
So, I think you want something like the code below. N.B. I've set the percentages for df_percentages.TOTAL to 0; otherwise the TOTAL column would turn completely dark blue, because each row total divided by itself is 1.
Anyway, now that you know the logic, I trust you'll know how to adjust the values of df_percentages to your liking.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [1, 0, 4]])
labels = ['col1','col2','col3']
df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
df_confusion['TOTAL'] = df_confusion.sum(axis=1)
df_confusion.loc['TOTAL']= df_confusion.sum()
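# get percentages: divide every row by its TOTAL (the number of actual elements in that class)
df_percentages = df_confusion.div(df_confusion.TOTAL, axis=0)
# zero out the TOTAL column so it does not dominate the color scale
df_percentages.TOTAL = 0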
# =============================================================================
#            col1      col2      col3  TOTAL
# col1   0.555556  0.022222  0.422222      0
# col2   0.085366  0.524390  0.390244      0
# col3   0.200000  0.000000  0.800000      0
# TOTAL  0.327684  0.254237  0.418079      0
# =============================================================================
plt.figure(figsize=(24, 10))
sns.set(font_scale = 1.5)
# colors come from data (the percentages); the displayed numbers come from annot (the counts)
ax = sns.heatmap(data=df_percentages, annot=df_confusion, cmap='Blues', fmt="d",
                 cbar_kws={'label': 'percentages'})
ax.set_title('Confusion Matrix\n\n',size=22)
ax.set_xlabel('\nPredicted Values',size=20)
ax.set_ylabel('Actual Values ', size=20)
plt.show()
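Result: the cells are colored by their row percentage, so the col3/col3 cell (0.8) is the darkest, while the annotations still show the raw counts.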
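As one possible variation on the approach above, the sketch below leaves out the TOTAL row and column, pins the color scale to the 0 to 1 range, and annotates each cell with both the count and its row percentage. The helper names (`row_percentages`, `count_and_percent_labels`, `plot_row_normalized`) are hypothetical and not part of seaborn; the only seaborn-specific detail relied on is that `annot` accepts an array of strings when `fmt=""`.

```python
import numpy as np
import pandas as pd


def row_percentages(counts: pd.DataFrame) -> pd.DataFrame:
    """Divide every cell by the sum of its row (the actual-class total)."""
    return counts.div(counts.sum(axis=1), axis=0)


def count_and_percent_labels(counts: pd.DataFrame, pct: pd.DataFrame) -> np.ndarray:
    """Build one string per cell: the count, a newline, then the row percentage."""
    rows = [
        [f"{c}\n{p:.1%}" for c, p in zip(count_row, pct_row)]
        for count_row, pct_row in zip(counts.to_numpy(), pct.to_numpy())
    ]
    return np.array(rows)


def plot_row_normalized(counts: pd.DataFrame) -> None:
    """Draw a heatmap colored by row percentage and labeled with count + percentage."""
    # Imported here so the helpers above can be used without a plotting backend.
    import matplotlib.pyplot as plt
    import seaborn as sns

    pct = row_percentages(counts)
    cell_labels = count_and_percent_labels(counts, pct)
    # fmt="" tells seaborn to print the annotation strings unchanged.
    ax = sns.heatmap(pct, annot=cell_labels, fmt="", cmap="Blues",
                     vmin=0, vmax=1,
                     cbar_kws={"label": "share of actual class"})
    ax.set_title("Confusion Matrix (colors by row percentage)")
    ax.set_xlabel("Predicted Values")
    ax.set_ylabel("Actual Values")
    plt.show()


cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [1, 0, 4]])
class_names = ["col1", "col2", "col3"]
df_counts = pd.DataFrame(cf_matrix, index=class_names, columns=class_names)

df_pct = row_percentages(df_counts)
cell_labels = count_and_percent_labels(df_counts, df_pct)
```

Calling `plot_row_normalized(df_counts)` draws the figure. Fixing `vmin=0` and `vmax=1` keeps the colorbar comparable across different matrices; drop those arguments if you prefer the scale to stretch to the observed range.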