What do the result numbers mean in TensorFlow text_classification?

Time:02-24

TensorFlow text_classification tutorial:

https://www.tensorflow.org/tutorials/keras/text_classification

There are only two classes in this text_classification example:

Label 0 corresponds to neg
Label 1 corresponds to pos

But the following prediction values are neither 0 nor 1:

examples = [
  "The movie was great!",
  "The movie was okay.",
  "The movie was terrible..."
]

export_model.predict(examples)

array([[0.5921171 ],
       [0.41369876],
       [0.33293992]], dtype=float32)
   

The result of "The movie was great!" is 0.5921171, but what does that mean?

Does it mean positive for value >= 0.5 and negative for value < 0.5?
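A minimal sketch of the conventional reading, assuming the model ends with the sigmoid activation that the tutorial's export_model adds, so each number is the model's estimated probability that the review is positive (label 1). The three scores are copied by hand from the output above, and NumPy is assumed:

```python
import numpy as np

# Scores copied from the export_model.predict output above (shape: [3, 1]).
probs = np.array([[0.5921171], [0.41369876], [0.33293992]], dtype=np.float32)

class_names = ["neg", "pos"]  # label 0 -> neg, label 1 -> pos

# Conventional binary decision rule: predict label 1 when p >= 0.5.
labels = (probs[:, 0] >= 0.5).astype(int)
print(labels)                              # [1 0 0]
print([class_names[i] for i in labels])    # ['pos', 'neg', 'neg']
```

Under this rule "The movie was okay." (0.41369876) falls on the negative side, which is why the result differs from the [1, 1, 0] expected below.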

I expected it to predict something like:

array([[1],
       [1],
       [0]], dtype=float32)

At the end of that link, there is an exercise:

Exercise: multi-class classification on Stack Overflow questions

A 0.5 threshold does not work for the multi-class classification on Stack Overflow questions, because this exercise has four labels (classes) in total.

for i in range(len(raw_train_ds.class_names)):
    print("Label: {}, Class Name: {}".format(i, raw_train_ds.class_names[i]))

Label: 0, Class Name: csharp
Label: 1, Class Name: java
Label: 2, Class Name: javascript
Label: 3, Class Name: python

I'm using the same examples array for prediction:

examples = [
  "The movie was great!",
  "The movie was okay.",
  "The movie was terrible..."
]

export_model.predict(examples)

array([[0.52356344, 0.4763114 , 0.54468685, 0.4438951 ],
       [0.52287084, 0.48242405, 0.5407451 , 0.4425373 ],
       [0.5221944 , 0.4766879 , 0.5448719 , 0.4454918 ]], dtype=float32)

How do I set a threshold for the four-class classification on Stack Overflow questions?

I think this is not related to a threshold.
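For a single-label, multi-class problem like this exercise, a common alternative to a per-class threshold is to pick the class with the highest score (argmax). Sigmoid and softmax both preserve the ordering of the underlying logits within a row, so the argmax is the same whichever activation produced these numbers. A minimal sketch, assuming NumPy and the four-class scores above copied by hand:

```python
import numpy as np

# Scores copied from the four-class export_model.predict output above (shape: [3, 4]).
scores = np.array([
    [0.52356344, 0.4763114, 0.54468685, 0.4438951],
    [0.52287084, 0.48242405, 0.5407451, 0.4425373],
    [0.5221944, 0.4766879, 0.5448719, 0.4454918],
], dtype=np.float32)

# Same order as raw_train_ds.class_names printed above.
class_names = ["csharp", "java", "javascript", "python"]

# Take the index of the highest score in each row instead of thresholding.
predicted = np.argmax(scores, axis=1)
print(predicted)                              # [2 2 2]
print([class_names[i] for i in predicted])    # ['javascript', 'javascript', 'javascript']
```

The margins between classes are small, which is unsurprising given that movie reviews are far outside the Stack Overflow training data.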

Answer:

You should define a threshold: any value greater than it is considered positive, and any other value is considered negative. In your example, a threshold of 0.4 gives the predictions you expect, [1, 1, 0].
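A quick check of this suggestion, assuming NumPy and the three binary scores from the question copied by hand. The hypothetical helper apply_threshold uses a strict "greater than" comparison, as described in the answer:

```python
import numpy as np

# Binary scores copied from the question's export_model.predict output.
probs = np.array([0.5921171, 0.41369876, 0.33293992], dtype=np.float32)

def apply_threshold(p, threshold):
    """Return 1 where the score is greater than the threshold, else 0."""
    return (p > threshold).astype(int)

for t in (0.5, 0.4):
    print(t, apply_threshold(probs, t))
# 0.5 [1 0 0]
# 0.4 [1 1 0]
```

Note that picking a threshold so that three hand-written sentences come out as expected is not a principled choice; a threshold is usually tuned on held-out validation data.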
