Hi, I am trying to build an image segmentation model using Keras with a U-Net architecture.
In my tests I saw that the model tends not to distinguish grass from trees very well.
I thought I could use sample_weight, since I am working with a TensorFlow dataset.
So I gave each pixel of each image a weight (so I incremented grass weight and decremented the tree ones).
Is this the right thing to do? I read that sample_weight is used to give each sample a weight, so I don't know whether applying it per pixel is correct.
This is the function I used:
import tensorflow as tf

@tf.function
def put_weights(image, mask, value):
    # value: float tensor of per-pixel weights with the same shape as mask.
    # Each tf.where rescales the weight of the pixels labelled with one class id.
    value = tf.where(tf.equal(mask, 5), tf.multiply(value, 0.4), value)
    # Class 3 is the only class whose weight is increased.
    value = tf.where(tf.equal(mask, 3), tf.multiply(value, 2.0), value)
    value = tf.where(tf.equal(mask, 22), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 21), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 20), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 15), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 17), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 10), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 11), tf.multiply(value, 0.4), value)
    value = tf.where(tf.equal(mask, 19), tf.multiply(value, 0.55), value)
    value = tf.where(tf.equal(mask, 18), tf.multiply(value, 0.8), value)
    return image, mask, value
I map this function over my dataset, so I obtain a dataset containing the images (256x256x3), the ground-truth masks (256x256x1), and the weights (256x256x1).
CodePudding user response:
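Yes, sample_weight is exactly what you need. (class_weight is the per-class alternative, but as far as I know Keras may reject it for targets with three or more dimensions, such as segmentation masks.) Weights are usually added to balance the number of samples per class in a dataset. Even though that is not your case, they can still help improve the accuracy on your hard-to-learn classes.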
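To illustrate why a per-pixel weight helps a hard-to-learn class, here is a minimal pure-Python sketch of a weighted cross-entropy. It assumes the weighted per-pixel terms are simply averaged over the number of pixels, which may differ in detail from what your Keras version does. The function name `weighted_pixel_loss` and the two-pixel example are hypothetical.

```python
import math

def weighted_pixel_loss(true_classes, predicted_probs, weights):
    """Average of per-pixel cross-entropy terms, each scaled by its pixel weight.

    Assumption: the weighted terms are divided by the number of pixels,
    not by the sum of the weights.
    """
    total = 0.0
    for cls, probs, weight in zip(true_classes, predicted_probs, weights):
        total += weight * -math.log(probs[cls])
    return total / len(true_classes)

# Two pixels, both predicted equally badly (probability 0.5 for the true class).
# Pixel 0 is "grass" (class 0), pixel 1 is "tree" (class 1).
true_classes = [0, 1]
predicted_probs = [[0.5, 0.5], [0.5, 0.5]]

unweighted = weighted_pixel_loss(true_classes, predicted_probs, [1.0, 1.0])
weighted = weighted_pixel_loss(true_classes, predicted_probs, [2.0, 0.4])

# With weights 2.0 and 0.4, the grass pixel contributes five times as much
# to the loss as the tree pixel, so its errors are corrected more strongly.
print(unweighted, weighted)
```

The weights do not change the predictions themselves, only how strongly each pixel's error pushes the gradient.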
You said:
So I gave each pixel of each image a weight (so I incremented grass weight and decremented the tree ones).
I agree with increasing the grass weight. For the trees, however, I would choose a smaller positive weight rather than a negative one. But if this works for you, that's fine.
Also if you are looking for a way to improve your code I suggest you have a look at this answer.
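One possible way to shorten the function, sketched here under some assumptions, is to store the per-class weights in a lookup table and index it with the mask using tf.gather. The sketch assumes that class ids run from 0 to 22, that unlisted classes keep a weight of 1.0, and that the original `value` tensor started as all ones. The names `NUM_CLASSES`, `CLASS_WEIGHTS`, `WEIGHT_TABLE`, and `add_pixel_weights` are hypothetical, and the dataset is synthetic.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 23  # assumption: class ids run from 0 to 22

# Per-class factors copied from put_weights; unlisted classes keep 1.0.
CLASS_WEIGHTS = {
    3: 2.0, 18: 0.8, 19: 0.55,
    5: 0.4, 10: 0.4, 11: 0.4, 15: 0.4, 17: 0.4, 20: 0.4, 21: 0.4, 22: 0.4,
}

table = np.ones(NUM_CLASSES, dtype=np.float32)
for class_id, weight in CLASS_WEIGHTS.items():
    table[class_id] = weight
WEIGHT_TABLE = tf.constant(table)

def add_pixel_weights(image, mask):
    # Look up the weight of every pixel from its class id.
    # The result has the same shape as the mask.
    weights = tf.gather(WEIGHT_TABLE, tf.cast(mask, tf.int32))
    return image, mask, weights

# Tiny synthetic dataset standing in for the real one.
images = tf.zeros([2, 256, 256, 3], dtype=tf.float32)
masks = tf.random.uniform([2, 256, 256, 1], maxval=NUM_CLASSES, dtype=tf.int32)
dataset = (
    tf.data.Dataset.from_tensor_slices((images, masks))
    .map(add_pixel_weights)
    .batch(2)
)

for image_batch, mask_batch, weight_batch in dataset.take(1):
    print(image_batch.shape, mask_batch.shape, weight_batch.shape)
```

A dataset that yields (image, mask, weights) tuples like this can be passed to model.fit, which treats the third element as sample_weight.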