I trained a CNN model on 6 different classes (labels are 0-5) and I am getting more than 90% accuracy out of it, so it can classify the classes correctly. What I am actually trying to do is detect anomalies with it. If any data comes in that my model has never seen before, or that is unlike anything it was trained on, it should be classified as an anomaly. I do not have any abnormal data to train my model on, I only have the normal data. So the rule would be: if an incoming data point does not belong to any of the six classes, it is an anomaly. How can I do this?
I thought of a method, but I am not sure whether it works in this scenario. When I predict a single data point, the model gives me a probability score for each of the 6 classes. I take the maximum of these 6 values, and if it is below a threshold, for example 0.70, the observation is classified as an anomaly. That means any data point with less than a 70% probability of belonging to one of the six classes is an anomaly. The code looks like this:
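import numpy as np

THRESHOLD = 0.70  # minimum top-class probability to count as normal

y_pred = s_model.predict(X_test_scaled)  # one row of 6 class probabilities per sample
normal = []
abnormal = []
max_value_list = []
for probs in y_pred:
    max_value = np.max(probs)  # highest class probability for this sample
    max_value_list.append(max_value)
    if max_value <= THRESHOLD:
        abnormal.append(max_value)
        print('Anomaly detected')
    else:
        normal.append(max_value)
print('The total number of abnormal observations is:', len(abnormal))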
Does this method work in my case? Or is there a better way to do it? Any kind of help is appreciated.
CodePudding user response:
Interesting problem, but I don't think your method will work.
When your model's predictive entropy is high, i.e. it is unsure which class to choose for a particular input, that does not necessarily mean the sample is an anomaly. It may just mean that the model is struggling to select the correct normal class.
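To make the point concrete, here is a minimal sketch that computes both the top-class probability and the Shannon entropy for a few softmax outputs. The probability rows are made up for illustration and `confidence_and_entropy` is a hypothetical helper, not part of any library. The second row is a sample the model is torn on between classes 0 and 1: the 0.70 rule flags it, even though it may well be a normal sample.

```python
import numpy as np

def confidence_and_entropy(probs):
    """Return (max probability, Shannon entropy in nats) for each row."""
    probs = np.asarray(probs, dtype=float)
    max_prob = probs.max(axis=1)
    # Clip to avoid log(0); entropy is 0 for a one-hot row and log(K) for a uniform row.
    entropy = -np.sum(probs * np.log(np.clip(probs, 1e-12, 1.0)), axis=1)
    return max_prob, entropy

# Hypothetical softmax outputs for three samples (each row sums to 1).
example_probs = np.array([
    [0.96, 0.01, 0.01, 0.01, 0.005, 0.005],    # confident prediction
    [0.50, 0.48, 0.005, 0.005, 0.005, 0.005],  # torn between classes 0 and 1
    [1 / 6] * 6,                               # maximally uncertain
])

max_prob, entropy = confidence_and_entropy(example_probs)
flagged = max_prob <= 0.70  # the threshold rule from the question

for p, h, f in zip(max_prob, entropy, flagged):
    print(f"max prob = {p:.3f}, entropy = {h:.3f}, flagged = {f}")
```

Both the second and third rows are flagged, so the rule cannot tell a hard-to-separate normal sample from a sample that looks like none of the classes.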
I suggest adding some abnormal samples (for example, random unrelated images, if your samples are images), amounting to between 1% and 10% of your data, and labelling them as a seventh class (label 6, since your existing labels are 0-5). Then retrain your model with those samples included, perhaps giving a larger penalty for misclassifying the seventh class.
When you get unseen samples, classify them with the retrained model. If they are assigned to the seventh class, you know they are anomalies.
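The data preparation for this could be sketched roughly as follows. Everything here is an assumption for illustration: the toy arrays stand in for your real data, `add_anomaly_class` is a hypothetical helper, the extra class is given label 6 (the next free index after 0-5), and the weight of 5.0 is an arbitrary starting point that would need tuning.

```python
import numpy as np

def add_anomaly_class(X, y, X_outlier, anomaly_label=6, anomaly_weight=5.0, seed=0):
    """Append outlier samples as an extra class and build per-class weights."""
    X_aug = np.concatenate([X, X_outlier], axis=0)
    y_outlier = np.full(len(X_outlier), anomaly_label, dtype=y.dtype)
    y_aug = np.concatenate([y, y_outlier])

    # Weight 1.0 for every normal class, a larger weight for the anomaly class.
    class_weight = {int(c): 1.0 for c in np.unique(y)}
    class_weight[anomaly_label] = anomaly_weight

    # Shuffle so the outliers are not all at the end of the training set.
    order = np.random.default_rng(seed).permutation(len(X_aug))
    return X_aug[order], y_aug[order], class_weight

# Toy stand-ins for the real data: 200 normal samples with 8 features, labels 0-5.
rng = np.random.default_rng(42)
X_normal = rng.normal(size=(200, 8))
y_normal = np.arange(200) % 6
# 10 unrelated samples, i.e. 5% of the normal data.
X_outlier = rng.uniform(-10, 10, size=(10, 8))

X_aug, y_aug, class_weight = add_anomaly_class(X_normal, y_normal, X_outlier)
print(X_aug.shape, np.bincount(y_aug), class_weight)
```

If the model is a Keras model, its output layer would need 7 units instead of 6, and the dictionary can be passed as `model.fit(X_aug, y_aug, class_weight=class_weight)`. At prediction time, a sample whose `argmax` is 6 would be treated as an anomaly.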
Hope this helps.