Ensure that pandas.crosstab returns a square matrix

Time:07-07

I am currently using pandas.crosstab to generate the confusion matrix of my classifiers after testing. Unfortunately, my classifier sometimes fails and assigns every signal to a single label (or never predicts some of the labels). In that case pandas.crosstab returns a single column (or a non-square matrix) instead of a square matrix, because it only creates columns for labels that actually appear in the predictions.
As an example, my ground truth would be

import pandas

true_data = pandas.Series([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])

and my predicted data is

pred_data = pandas.Series([3, 3, 2, 3, 2, 1, 1, 3, 4, 1])

Applying pandas.crosstab(true_data, pred_data, dropna=False) gives

col_0  1  2  3  4
row_0
1      0  0  2  0
2      0  1  1  0
3      1  1  0  0
4      1  0  1  0
5      1  0  0  1

Is there a way to get

col_0  1  2  3  4  5
row_0
1      0  0  2  0  0
2      0  1  1  0  0
3      1  1  0  0  0
4      1  0  1  0  0
5      1  0  0  1  0

instead, i.e. leaving the matrix square and filling the missing labels with 0?
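For reference, dropna=False does not help on its own here: with plain integer Series, pandas.crosstab only creates rows and columns for values that actually occur in the data, and label 5 never occurs among the predictions. If both Series are first cast to a shared categorical dtype, dropna=False keeps the unused categories as all-zero rows or columns. A minimal sketch, assuming the full label set is the union of the observed labels (labels and label_dtype are illustrative names):

```python
import pandas as pd

true_data = pd.Series([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
pred_data = pd.Series([3, 3, 2, 3, 2, 1, 1, 3, 4, 1])

# Assumption: every possible label shows up in at least one of the series.
labels = sorted(set(true_data) | set(pred_data))
label_dtype = pd.CategoricalDtype(categories=labels)

# With categorical inputs, dropna=False keeps categories that never occur,
# so predicted label 5 still gets a column (filled with zero counts).
confusion = pd.crosstab(
    true_data.astype(label_dtype),
    pred_data.astype(label_dtype),
    dropna=False,
)
print(confusion)
```

To include classes that occur in neither series, pass the complete class list as categories instead of the union.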

Answer:

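After computing the crosstab, you can reindex the resulting DataFrame along both the index and the columns axis, using the union of the row and column labels and filling the new cells with 0.

df = pandas.crosstab(true_data, pred_data, dropna=False)
i = df.index.union(df.columns)
df = df.reindex(index=i, columns=i, fill_value=0)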

   1  2  3  4  5
1  0  0  2  0  0
2  0  1  1  0  0
3  1  1  0  0  0
4  1  0  1  0  0
5  1  0  0  1  0
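The same reindexing idea can be wrapped in a small reusable function. This is a sketch only: square_crosstab and its labels parameter are hypothetical names, not part of pandas. Passing an explicit label list also covers classes that occur in neither series, which the union of observed labels cannot recover.

```python
import pandas as pd


def square_crosstab(true_labels, pred_labels, labels=None):
    """Hypothetical helper: confusion matrix as a square DataFrame.

    Rows are true labels, columns are predicted labels. If `labels` is
    None, the union of the labels seen in either series is used; pass an
    explicit list to include classes that appear in neither series.
    """
    table = pd.crosstab(true_labels, pred_labels)
    if labels is None:
        labels = table.index.union(table.columns)
    # Add the missing rows/columns and fill them with zero counts.
    return table.reindex(index=labels, columns=labels, fill_value=0)


true_data = pd.Series([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
pred_data = pd.Series([3, 3, 2, 3, 2, 1, 1, 3, 4, 1])

sq = square_crosstab(true_data, pred_data)
print(sq)

# A class 6 that never occurs still gets its own all-zero row and column.
big = square_crosstab(true_data, pred_data, labels=[1, 2, 3, 4, 5, 6])
print(big)
```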

Answer:

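You could create a zero array of the required shape and then copy the crosstab into its top-left corner. Note that the arguments to crosstab are swapped here, so rows are predicted labels and columns are true labels, the transpose of the layout in the question.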

import numpy as np
import pandas as pd

# Rows are predicted labels, columns are true labels
xtab = pd.crosstab(pred_data, true_data, dropna=False).sort_index(axis=0).sort_index(axis=1)
all_unique_values = sorted(set(true_data) | set(pred_data))
z = np.zeros((len(all_unique_values), len(all_unique_values)))
rows, cols = xtab.shape
# Copy the crosstab into the top-left corner of the zero array (by position)
z[:rows, :cols] = xtab
square_xtab = pd.DataFrame(z, columns=all_unique_values, index=all_unique_values)

Output

     1    2    3    4    5
1  0.0  0.0  1.0  1.0  1.0
2  0.0  1.0  1.0  0.0  0.0
3  2.0  1.0  0.0  1.0  0.0
4  0.0  0.0  0.0  0.0  1.0
5  0.0  0.0  0.0  0.0  0.0

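This approach only works when the missing labels sort last, because the crosstab is copied into the zero array by position, not by label. If the mismatch is in the "middle", e.g. pred_data = [1, 2, 4, 5] and true_data = [1, 2, 3, 4], the counts for predicted labels 4 and 5 end up in the rows labeled 3 and 4.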
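One way to make the zero-array approach robust to a mismatch in the middle is to place each crosstab row and column by label rather than by position. The sketch below is a modification of the answer's code, not the answerer's own; pos, row_pos and col_pos are illustrative names, and np.ix_ is used to address the target rows and columns of the zero array.

```python
import numpy as np
import pandas as pd

# A "middle" mismatch: label 3 is never predicted, label 5 never occurs
# in the ground truth.
pred_data = pd.Series([1, 2, 4, 5])
true_data = pd.Series([1, 2, 3, 4])

xtab = pd.crosstab(pred_data, true_data)  # rows: predicted, columns: true
all_unique_values = sorted(set(true_data) | set(pred_data))
pos = {label: i for i, label in enumerate(all_unique_values)}

z = np.zeros((len(all_unique_values), len(all_unique_values)), dtype=int)
# Look up each crosstab row/column by its label instead of its position.
row_pos = [pos[label] for label in xtab.index]
col_pos = [pos[label] for label in xtab.columns]
z[np.ix_(row_pos, col_pos)] = xtab.to_numpy()

square_xtab = pd.DataFrame(z, index=all_unique_values, columns=all_unique_values)
print(square_xtab)
```

Here the count for predicted 4 / true 3 lands in row 4, column 3, and row 3 (never predicted) and column 5 (never true) stay all zero.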
