Keras class_weight error dictionary keys/values

Time:10-05

I am having an issue adding class_weight to a Keras model. I calculated the weights manually and put them in a dictionary, which I pass to model.fit as shown below:

    model.fit(train_dataset,
              steps_per_epoch=train_steps,
              validation_data=valid_dataset,
              validation_steps=valid_steps,
              epochs=epochs,
              callbacks=callbacks,
              class_weight={'0': 0.12546960479781682, '1': 1115.3019958392365, '2': 3032.7837992640307,
                            '3': 12372.961843790014, '4': 7941.064776579353, '5': 4929.219609860191,
                            '6': 79056.03926915735, '7': 35.34368125856832})

However, it throws the following error, which I cannot figure out. I am running TF 2.6.

ValueError: Expected `class_weight` to be a dict with keys from 0 to one less than the number of classes, found {'0': 0.12546960479781682, '1': 1115.3019958392365, '2': 3032.7837992640307, '3': 12372.961843790014, '4': 7941.064776579353, '5': 4929.219609860191, '6': 79056.03926915735, '7': 35.34368125856832}

The model has 8 classes, so I cannot see what is going wrong. Any advice is very welcome.

Thanks!

CodePudding user response:

You can check this document, specifically the part on training the model with the class_weight argument.

It shows that class_weight must be a dict whose keys are integers, not strings.
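A minimal sketch of why the string keys are rejected and how to convert them. The `check_class_weight` function is a hypothetical, simplified re-creation of the key check behind this error message (it is not the actual Keras source); the weights are the ones from the question.

```python
# Weights from the question, with the original string keys.
class_weight = {'0': 0.12546960479781682, '1': 1115.3019958392365, '2': 3032.7837992640307,
                '3': 12372.961843790014, '4': 7941.064776579353, '5': 4929.219609860191,
                '6': 79056.03926915735, '7': 35.34368125856832}


def check_class_weight(class_weight):
    # Hypothetical simplified check: the sorted keys must be exactly
    # the integers 0, 1, ..., n - 1. String keys like '0' never equal 0.
    class_ids = sorted(class_weight.keys())
    if class_ids != list(range(len(class_ids))):
        raise ValueError(
            "Expected `class_weight` to be a dict with keys from 0 to one less "
            "than the number of classes, found {}".format(class_weight))


# Convert '0'..'7' to the integers 0..7; the weight values are unchanged.
int_class_weight = {int(k): w for k, w in class_weight.items()}

check_class_weight(int_class_weight)  # passes: keys are 0..7
# check_class_weight(class_weight) would raise ValueError, since '0' != 0
```

Passing `class_weight=int_class_weight` to `model.fit` should then get past the key check.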

CodePudding user response:

class_weight is a dictionary of the form {class_index (an integer): weight (a float)}. In your code the keys are strings; you need to convert them to integers. One thing to be careful about is making sure that each key (an integer representing the class index) is correctly associated with the class list. For example, if your class list is ['apple', 'peach', 'plum'], the class index for apple is 0, for peach it is 1, and for plum it is 2. When you use a generator like

gen = ImageDataGenerator().flow_from_directory(directory)

the class names are derived from the names of the class subdirectories, and the classes are ordered alphanumerically, so the class list might not be in the order you expect. To be safe, use

class_dict = gen.class_indices

If you print it out, it is a dictionary of the form {class_name: class_index}. For the above example it would be

{'apple': 0, 'peach': 1, 'plum': 2}
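One way to keep the weights tied to the generator's own ordering is to compute each weight by class name and store it under the index that `class_indices` assigns to that name. The sketch below is illustrative only: `class_weight_from_counts` is a hypothetical helper, `counts` holds made-up per-class image counts, the hard-coded `class_indices` stands in for what `gen.class_indices` would return, and the weighting formula total / (n_classes * count) is one common "balanced" heuristic, not the only valid choice.

```python
def class_weight_from_counts(class_indices, counts):
    """Return a class_weight dict keyed by integer class index.

    class_indices: class name -> integer index (e.g. gen.class_indices)
    counts: class name -> number of training samples (hypothetical data)
    """
    total = sum(counts.values())
    n_classes = len(class_indices)
    # "Balanced" heuristic: rarer classes get proportionally larger weights.
    return {index: total / (n_classes * counts[name])
            for name, index in class_indices.items()}


# What flow_from_directory would report for subdirectories apple/, peach/, plum/.
class_indices = {'apple': 0, 'peach': 1, 'plum': 2}
# Hypothetical per-class image counts.
counts = {'apple': 600, 'peach': 300, 'plum': 100}

class_weight = class_weight_from_counts(class_indices, counts)
```

The result has integer keys 0 to n - 1 and could be passed as `class_weight=class_weight` to `model.fit`. Because each weight is looked up by class name, it stays attached to the right class even if the alphanumeric order differs from the order you expected.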