TensorFlow text_classification tutorial:
https://www.tensorflow.org/tutorials/keras/text_classification
There are only two classes in this text_classification example:
Label 0 corresponds to neg
Label 1 corresponds to pos
But the following prediction values are neither 0 nor 1:
examples = [
"The movie was great!",
"The movie was okay.",
"The movie was terrible..."
]
export_model.predict(examples)
array([[0.5921171 ],
[0.41369876],
[0.33293992]], dtype=float32)
The result for "The movie was great!" is 0.5921171, but what does that mean?
Does it mean positive for a value >= 0.5 and negative for a value < 0.5?
I expected the prediction to look something like:
array([[1],
[1],
[0]], dtype=float32)
At the end of that tutorial, there is an exercise:
Exercise: multi-class classification on Stack Overflow questions
A threshold of 0.5 does not work for multi-class classification on Stack Overflow questions, because there are 4 labels (classes) in total in this exercise:
for i in range(len(raw_train_ds.class_names)):
    print("Label: {}, Class Name: {}".format(i, raw_train_ds.class_names[i]))
Label: 0, Class Name: csharp
Label: 1, Class Name: java
Label: 2, Class Name: javascript
Label: 3, Class Name: python
I'm using the same examples array for prediction:
examples = [
"The movie was great!",
"The movie was okay.",
"The movie was terrible..."
]
export_model.predict(examples)
array([[0.52356344, 0.4763114 , 0.54468685, 0.4438951 ],
[0.52287084, 0.48242405, 0.5407451 , 0.4425373 ],
[0.5221944 , 0.4766879 , 0.5448719 , 0.4454918 ]], dtype=float32)
How do I set a threshold for the four-class classification on Stack Overflow questions?
I think it is not a matter of a threshold.
CodePudding user response:
You should define a threshold: whenever a predicted value is greater than the threshold it is considered positive, otherwise it is considered negative. In your example, a threshold of 0.4 gives the predictions [1, 1, 0] that you expect.
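As a rough sketch of the thresholding described above, the scores from the question can be mapped to labels in plain Python. The helper name to_labels and the variable names are made up for illustration. The second half goes beyond what this answer states: for the four-class exercise it assumes the usual convention of picking the class with the highest score per row (argmax) instead of applying a threshold.

```python
# Scores copied from the binary example in the question.
binary_scores = [[0.5921171], [0.41369876], [0.33293992]]

def to_labels(scores, threshold):
    """Hypothetical helper: 1 (pos) if the score is >= threshold, else 0 (neg)."""
    return [1 if row[0] >= threshold else 0 for row in scores]

labels_04 = to_labels(binary_scores, 0.4)
labels_05 = to_labels(binary_scores, 0.5)
print(labels_04)  # [1, 1, 0]
print(labels_05)  # [1, 0, 0]

# Assumption (not part of the answer above): with four classes, take the
# index of the largest score in each row rather than using a threshold.
class_names = ["csharp", "java", "javascript", "python"]
multi_scores = [
    [0.52356344, 0.4763114, 0.54468685, 0.4438951],
    [0.52287084, 0.48242405, 0.5407451, 0.4425373],
    [0.5221944, 0.4766879, 0.5448719, 0.4454918],
]
predicted_ids = [max(range(len(row)), key=row.__getitem__) for row in multi_scores]
predicted_names = [class_names[i] for i in predicted_ids]
print(predicted_ids)    # [2, 2, 2]
print(predicted_names)  # ['javascript', 'javascript', 'javascript']
```

Note that 0.5 gives [1, 0, 0] rather than [1, 1, 0] for these scores, so the choice of threshold matters. On the NumPy array returned by predict, np.argmax(predictions, axis=1) performs the same per-row selection as the list comprehension here.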