The Keras documentation states:
Using sample weighting and class weighting
With the default settings the weight of a sample is decided by its frequency in the dataset. There are two methods to weight the data, independent of sample frequency:
- Class weights
- Sample weights
I have a significantly imbalanced dataset. I was looking into how to adjust for this, and came across a few answers here dealing with that, such as here and here. My plan was therefore to create a dictionary of the relative frequencies and pass it to model.fit()'s class_weight parameter.
However, now that I'm reading the documentation, it seems as though class imbalance is already dealt with? So I don't necessarily have to account for the different class counts after all?
For the record, here are the class counts:
0: 25,811
1: 2,444
2: 5,293
3: 874
4: 709
And here is the dictionary I was going to pass in, built from the counts above:
class_counts = {0: 25811, 1: 2444, 2: 5293, 3: 874, 4: 709}
# weight each class by how much rarer it is than class 0
class_weight = {c: class_counts[0] / n for c, n in class_counts.items()}
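(For completeness, a similar set of weights can be derived directly from the label array with scikit-learn; a minimal sketch, assuming y_train is the 1-D array of integer labels. The "balanced" heuristic gives n_samples / (n_classes * count) per class, i.e. inverse-frequency weights with a different normalization than the class-0 reference above.)

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

classes = np.unique(y_train)  # y_train is a hypothetical name for the label array
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
class_weight = {int(c): w for c, w in zip(classes, weights)}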
CodePudding user response:
In the Model class documentation, model.fit() has the following signature:
fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose='auto',
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False
)
You can see that the class_weight parameter is None by default.
The documentation also mentions that this is an optional parameter:
Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.
Example (a mistake for class_1 is penalized 10 times more than a mistake for class_0):
class_weight = {0: 1,
                1: 10}
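The dictionary is then passed straight to fit(); a minimal sketch, assuming model, x_train and y_train are already defined elsewhere:

model.fit(
    x_train,
    y_train,
    epochs=10,
    batch_size=32,
    class_weight={0: 1, 1: 10},  # errors on class 1 are penalized 10x in the loss
)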
PS: The quoted statement is indeed a little misleading, in the sense that one could infer that, if you have two classes with 90 and 10 samples respectively, the "weight" of each class is its sample count. What it actually means is that, by default, every sample contributes equally to the loss, so the class with 90 samples dominates purely because of its frequency; the model will not pay any special attention to the class with only 10 samples.
In other words, what it tells you is that the plain cross-entropy loss (binary or categorical) will favor the over-represented classes in the absence of specific constraints or parameters from the developer. That is indeed correct, hence the need to tackle the imbalance via a class_weight scheme like yours in this situation.
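To see why this helps, here is a toy NumPy illustration (not Keras internals, and the exact reduction Keras applies may differ) of how per-class weights rescale the cross-entropy terms so that a mistake on the rare class counts for more of the average loss:

import numpy as np

y_true = np.array([0, 0, 0, 1])           # class 0 is over-represented
p_true = np.array([0.9, 0.9, 0.9, 0.2])   # predicted probability of the true class
ce = -np.log(p_true)                      # per-sample cross-entropy terms

unweighted = ce.mean()                    # dominated by the easy class-0 samples
w = np.array([1.0, 1.0, 1.0, 3.0])        # up-weight the single class-1 sample
weighted = np.sum(w * ce) / np.sum(w)     # the class-1 mistake now matters more

print(unweighted, weighted)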