Getting error while calculating AUC ROC for keras model predictions

Time:03-10

I have patient data in an array named dat and labels (0 = No Disease, 1 = Disease) in an array named labl. I ran my model on the data and stored the predictions in an array named pre. I want to calculate and plot the AUC ROC, but I get this error:

TypeError: Singleton array array(0., dtype=float32) cannot be considered a valid collection.

This data is a single patient record. When I run the model on more patients, I can calculate the AUC ROC without problems, but I want to compute it for this one patient only.


>>> dat
array([[[114.6 ,  93.1 ,  37.17, 118.3 ,  64.3 ,  22.  ,  45.  ,   0.  ],
        [110.  ,  94.5 ,  37.3 , 136.  ,  59.  ,  17.5 ,  45.  ,   0.  ],
        [104.  ,  95.  ,  37.17, 154.  ,  74.  ,  26.  ,  45.  ,   0.  ],
        [106.  ,  94.  ,  37.17, 124.  ,  64.  ,  17.  ,  45.  ,   0.  ],
        [110.  ,  92.5 ,  37.17, 133.  ,  62.  ,  17.  ,  45.  ,   0.  ],
        [114.  ,  92.5 ,  36.7 , 127.  ,  62.  ,  21.  ,  45.  ,   0.  ],
        [106.  ,  95.  ,  37.17, 124.  ,  64.  ,  19.  ,  45.  ,   0.  ],
        [110.  ,  93.  ,  37.17, 138.  ,  70.  ,  17.  ,  45.  ,   0.  ],
        [114.  ,  90.  ,  37.17, 134.  ,  66.  ,  16.  ,  45.  ,   0.  ],
        [114.  ,  89.  ,  37.17, 116.  ,  60.  ,  20.  ,  45.  ,   0.  ],
        [120.  ,  91.  ,  37.17, 140.  ,  80.  ,  15.  ,  45.  ,   0.  ],
        [120.  ,  90.  ,  37.17, 122.  ,  72.  ,  15.  ,  45.  ,   0.  ],
        [120.  ,  92.  ,  37.17, 106.  ,  64.  ,  16.  ,  45.  ,   0.  ],
        [ 64.  ,  93.  ,  37.17, 100.  ,  53.  ,  20.  ,  45.  ,   0.  ],
        [128.  ,  95.  ,  37.17, 194.  ,  86.  ,  15.  ,  45.  ,   0.  ],
        [126.  ,  93.  ,  37.17,  34.  ,  30.  ,  27.  ,  45.  ,   0.  ],
        [124.  ,  94.5 ,  37.17,  80.  ,  59.  ,  35.  ,  45.  ,   0.  ],
        [127.  ,  97.  ,  37.5 , 102.  ,  69.  ,  35.  ,  45.  ,   0.  ],
        [130.  ,  97.  ,  37.17,  94.  ,  66.  ,  35.  ,  45.  ,   0.  ],
        [130.  ,  90.  ,  37.17,  90.  ,  62.  ,  35.  ,  45.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
        [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ]]],
      dtype=float32)
>>> labl
array([[[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]]], dtype=float32)
>>> pre
array([[[0.24694729],
        [0.42795685],
        [0.5010372 ],
        [0.52086353],
        [0.52870005],
        [0.5377407 ],
        [0.5345124 ],
        [0.5310055 ],
        [0.531648  ],
        [0.5410067 ],
        [0.5446999 ],
        [0.5466636 ],
        [0.5504297 ],
        [0.5236943 ],
        [0.5244271 ],
        [0.5483868 ],
        [0.5533212 ],
        [0.5523378 ],
        [0.5553032 ],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267],
        [0.55902267]]], dtype=float32)

Using the code below, I plotted the predicted mortality over time, but I could not calculate the AUC ROC.

import random

import matplotlib.pyplot as plt
import numpy as np

# Figure out how many encounters we have
numencnt = dat.shape[0]

# Choose a random patient encounter to plot
ix = random.randint(0, numencnt - 1)

# Create two axes stacked vertically (2 rows, 1 column)
f, (ax1, ax2) = plt.subplots(2, 1)

# Plot the observation chart (8 features over time) for the chosen encounter
ax1.pcolor(np.transpose(dat[ix, 1:72, :]))
ax1.set_ylim(0, 8)

# Plot the patient survivability prediction over time
ax2.set_ylabel("mortality")
ax2.set_xlabel("time/observation")
ax2.plot(pre[ix, 1:72]);

[Figure: the resulting plot (image not available)]

This is where I got the error:

from sklearn.metrics import roc_curve, auc

# get 0/1 binary label for each patient encounter
label = labl[:, 0, :].squeeze();

# get the last prediction in [0,1] for the patient
prediction = pre[:, -1, :].squeeze()

# compute ROC curve for predictions
rnn_roc = roc_curve(label,prediction)

# compute the area under the curve of prediction ROC
rnn_auc = auc(rnn_roc[0], rnn_roc[1])

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_129/3666067037.py in
      8 
      9 # compute ROC curve for predictions
---> 10 rnn_roc = roc_curve(label,prediction)
     11 
     12 # compute the area under the curve of prediction ROC

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/metrics/_ranking.py in roc_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
    960 
    961     """
--> 962     fps, tps, thresholds = _binary_clf_curve(
    963         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
    964     )

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    731         raise ValueError("{0} format is not supported".format(y_type))
    732 
--> 733     check_consistent_length(y_true, y_score, sample_weight)
    734     y_true = column_or_1d(y_true)
    735     y_score = column_or_1d(y_score)

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    327     """
    328 
--> 329     lengths = [_num_samples(X) for X in arrays if X is not None]
    330     uniques = np.unique(lengths)
    331     if len(uniques) > 1:

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in (.0)
    327     """
    328 
--> 329     lengths = [_num_samples(X) for X in arrays if X is not None]
    330     uniques = np.unique(lengths)
    331     if len(uniques) > 1:

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in _num_samples(x)
    267     if hasattr(x, "shape") and x.shape is not None:
    268         if len(x.shape) == 0:
--> 269             raise TypeError(
    270                 "Singleton array %r cannot be considered a valid collection." % x
    271             )

TypeError: Singleton array array(0., dtype=float32) cannot be considered a valid collection.

# plot rocs & display AUCs
plt.figure(figsize=(7, 5))
line_kwargs = {'linewidth': 4, 'alpha': 0.8}
plt.plot(rnn_roc[0], rnn_roc[1], label='LSTM: %0.3f' % rnn_auc, color='#6AA84F', **line_kwargs)
plt.legend(loc='lower right', fontsize=20)
plt.xlim((-0.05, 1.05))
plt.ylim((-0.05, 1.05))
plt.xticks([0, 0.25, 0.5, 0.75, 1.0], fontsize=14)
plt.yticks([0, 0.25, 0.5, 0.75, 1.0], fontsize=14)
plt.xlabel("False Positive Rate", fontsize=18)
plt.ylabel("True Positive Rate", fontsize=18)
plt.title("ROC Curve", fontsize=24)
plt.grid(alpha=0.25)
plt.tight_layout()

Answer:

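The issue is the indexing you do before squeeze, not squeeze itself. Your arrays hold a single encounter, so dat has shape (1, 72, 8) and labl and pre have shape (1, 72, 1). labl[:, 0, :] selects only time step 0, which gives shape (1, 1); squeeze() then removes every axis of length 1 (it does not flatten in general), leaving a 0-d array holding one value. roc_curve needs a collection of samples, hence the "Singleton array" error. pre[:, -1, :] has the same problem.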
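A minimal sketch of the shape problem, assuming synthetic arrays that only mimic the shapes of labl in the question (one encounter, 72 time steps, one value per step); the values are made up. It shows that indexing a single time step first leaves a 0-d array, while squeezing the full array keeps all 72 steps, and that squeeze() removes length-1 axes rather than flattening:

```python
import numpy as np

# Synthetic stand-in with the same shape as labl in the question:
# (encounters, time steps, 1) with a single encounter. Values are made up.
labl = np.zeros((1, 72, 1), dtype=np.float32)
labl[0, 10:20, 0] = 1.0

# Indexing time step 0 keeps shape (1, 1); squeeze() then drops both axes.
first_label = labl[:, 0, :].squeeze()
print(first_label.shape, first_label.ndim)  # () 0

# Squeezing the full array drops only the two length-1 axes.
all_labels = labl.squeeze()
print(all_labels.shape)  # (72,)

# squeeze() does not flatten: with two encounters, the time axis survives.
two_encounters = np.zeros((2, 72, 1), dtype=np.float32)
print(two_encounters.squeeze().shape)  # (2, 72)
```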

Instead, squeeze the full arrays so that each of the 72 time steps becomes one sample:

# get the 0/1 label for each time step of the encounter
label = labl.squeeze()

# get the prediction in [0,1] for each time step
prediction = pre.squeeze()
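A self-contained sketch of the fix, assuming synthetic data shaped like the question's arrays (one encounter, 72 steps) with made-up scores. Note what it measures: after squeezing, roc_curve treats each time step as a separate sample, so the AUC describes how well the per-step predictions separate the steps labeled 1 from those labeled 0 within this one encounter. It is only defined because the labels contain both classes; a single per-patient label and score cannot form a ROC curve.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Synthetic stand-ins shaped like the question's arrays: (1, 72, 1).
# Labels mimic the pattern above (steps 10-19 are 1); scores are made up.
labl = np.zeros((1, 72, 1), dtype=np.float32)
labl[0, 10:20, 0] = 1.0
rng = np.random.default_rng(0)
pre = rng.uniform(0.0, 0.7, size=(1, 72, 1)).astype(np.float32)
pre[0, 10:20, 0] += 0.3  # nudge positive steps upward; scores stay in [0, 1]

# Drop the length-1 axes: one label and one score per time step.
label = labl.squeeze()
prediction = pre.squeeze()

# ROC curve over the 72 time steps, then the area under it.
fpr, tpr, thresholds = roc_curve(label, prediction)
rnn_auc = auc(fpr, tpr)

# roc_auc_score computes the same area directly from labels and scores.
check = roc_auc_score(label, prediction)
print(f"AUC: {rnn_auc:.3f} (roc_auc_score: {check:.3f})")
```

If a slice ever contains only one class, roc_auc_score raises a ValueError instead of returning a number, which makes it a convenient guard before plotting.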