Does keras automatically weight classes?

Time:08-27

The Keras documentation states:

Using sample weighting and class weighting

With the default settings the weight of a sample is decided by its frequency in the dataset. There are two methods to weight the data, independent of sample frequency:
- Class weights
- Sample weights

I have a significantly imbalanced dataset. While looking into how to adjust for this, I came across a few answers here dealing with it, such as here and here. My plan was therefore to build a dictionary of weights based on the relative class frequencies and pass it to the class_weight parameter of model.fit().

However, now that I'm reading the documentation, it seems as though class imbalance is already handled. Does that mean I don't need to account for the different class counts after all?

For the record, here are the class counts:
0: 25,811, 1: 2,444, 2: 5,293, 3: 874, 4: 709.

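And here is the dictionary I was going to pass to fit(), weighting each class by the size of class 0 relative to its own size:

counts = {0: 25811, 1: 2444, 2: 5293, 3: 874, 4: 709}  # class counts from above
# Weight of class c = count(class 0) / count(class c), so class 0 gets 1.0
class_weight = {c: counts[0] / n for c, n in counts.items()}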

Answer:

In the Model class documentation, model.fit() has the following signature:

fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose='auto',
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)

Note that class_weight defaults to None, so no class weighting is applied unless you pass a dictionary yourself.

The documentation also mentions that this is an optional parameter:

Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

Example (a mistake for class_1 is penalized 10 times more than a mistake for class_0):

class_weight = {0:1,
                1:10}
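To make "penalized 10 times more" concrete, here is a simplified NumPy sketch of how a class_weight dictionary acts on a sparse categorical cross-entropy: each sample's loss is multiplied by the weight of its class label, then the result is averaged over the batch. The helper weighted_sparse_ce is hypothetical and is not Keras's internal code; the exact loss reduction may differ between Keras versions.

```python
import numpy as np

def weighted_sparse_ce(y_true, y_prob, class_weight):
    # Hypothetical helper: simplified class-weighted sparse categorical cross-entropy
    y_true = np.asarray(y_true, dtype=int)
    y_prob = np.asarray(y_prob, dtype=float)
    # Probability the model assigned to the true class of each sample
    p_true = y_prob[np.arange(len(y_true)), y_true]
    per_sample_loss = -np.log(p_true)
    # Each sample inherits the weight of its class label
    weights = np.array([class_weight[c] for c in y_true.tolist()], dtype=float)
    # Weighted per-sample losses averaged over the batch size
    return float(np.mean(weights * per_sample_loss))

class_weight = {0: 1, 1: 10}

# The same mistake (only 0.1 probability on the true class), once for each class
loss_class0 = weighted_sparse_ce([0], [[0.1, 0.9]], class_weight)
loss_class1 = weighted_sparse_ce([1], [[0.9, 0.1]], class_weight)
print(loss_class0, loss_class1)
```

The misclassified class-1 sample contributes ten times the loss of the equally misclassified class-0 sample, which is exactly what the weight of 10 encodes.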

PS: The documentation's wording is admittedly a bit misleading. With 2 classes of 90 and 10 samples respectively, one might infer that the "weight" of a class is its sample count. What the sentence actually means is that the model will not prioritize class 2 (10 samples) over class 1 (90 samples): by default every sample counts equally, so the over-representation (plain frequency) of class 1 is what dominates.

In other words, the standard cross-entropy loss (binary or categorical) will favor the over-represented class unless the developer adds explicit constraints or parameters to counter it.
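Using the class counts from the question, this plain-Python sketch compares each class's share of the total sample weight without weighting and with the planned count(0)/count(c) dictionary. Assuming comparable per-sample losses, that share approximates how much each class contributes to the overall loss.

```python
# Class counts from the question
counts = {0: 25811, 1: 2444, 2: 5293, 3: 874, 4: 709}
total = sum(counts.values())

# Unweighted: every sample counts once, so a class's share of the total
# weight is simply its share of the samples
unweighted_share = {c: n / total for c, n in counts.items()}

# Weighted with the question's scheme: weight_c = count_0 / count_c
class_weight = {c: counts[0] / n for c, n in counts.items()}
weighted_total = sum(class_weight[c] * n for c, n in counts.items())
weighted_share = {c: class_weight[c] * n / weighted_total for c, n in counts.items()}

for c in counts:
    print(f"class {c}: unweighted {unweighted_share[c]:.1%}, weighted {weighted_share[c]:.1%}")
```

Unweighted, class 0 accounts for roughly 73% of the samples and therefore dominates the loss. With the planned weights, each class's total weight equals count(class 0), so every class contributes exactly one fifth.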

That statement is correct, which is why, in a situation like yours, you do need to tackle the imbalance yourself through class_weight.
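A common variant of the planned dictionary is n_samples / (n_classes * count_c), the formula scikit-learn's compute_class_weight uses for class_weight='balanced'. It keeps the same ratios between classes as count(0)/count(c) but rescales them so the average per-sample weight is 1, which keeps the overall loss magnitude close to the unweighted run (this can matter for optimizers such as plain SGD, whose step size scales with the gradient). A plain-Python sketch with the question's counts:

```python
# Class counts from the question
counts = {0: 25811, 1: 2444, 2: 5293, 3: 874, 4: 709}
n_samples = sum(counts.values())
n_classes = len(counts)

# "Balanced" heuristic: n_samples / (n_classes * count_c).
# Same between-class ratios as count_0 / count_c, rescaled so the
# average weight per sample is exactly 1.
balanced = {c: n_samples / (n_classes * n) for c, n in counts.items()}

mean_weight = sum(balanced[c] * n for c, n in counts.items()) / n_samples
print(balanced)
print(mean_weight)
```

Either dictionary can be passed as class_weight to model.fit(); they differ only by a constant factor.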
