I have a conceptual question about K-fold cross-validation. In general, we train a model on training data and validate it with test data, and we assume the model is blind to the test data, which is why we can evaluate whether it has really learned or not. Now with k-fold, the final model has actually seen (indirectly, though) all of the data, so why is it still valid? It has already seen all the data, and we do not know how it predicts on unseen data. So my question is: given this, how do we know the method is valid? Thanks.
CodePudding user response:
In K-Fold Cross Validation, you actually train K different models. Let's say we are doing 5-Fold CV and the size of the dataset is 100 samples. Then we shuffle the data once and split it into 5 disjoint folds of 20 samples each. In each round, one fold serves as the 20 left-out test samples and the other 80 samples are the training set. We train on the 80 train samples, then we test the trained model on the 20 left-out test samples. We compute the accuracy and note it. At the end, we will have 5 different models, each evaluated only on samples it did not see during training. Then we can average the accuracies over the folds and report this as the average performance of the model. Coming to your question, you need to think about why we need K-Fold Cross Validation. The answer is, you need to report the performance of your model, right? However, if you just train and evaluate your model with a single split, then there is a possibility that your result is biased by this specific split. I mean, in this split, a rare case may come up, like a strong domain shift between the train and test sets, which is bad for the performance.
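To make the 5-fold, 100-sample case concrete, here is a minimal sketch using only the Python standard library. The synthetic data, the toy threshold "model", and the helper names (`k_fold_indices`, `fit_threshold`, `accuracy`, `cross_validate`) are all made up for illustration; a real project would more likely use a library such as scikit-learn.

```python
import random
from statistics import mean

def k_fold_indices(n_samples, k, seed=0):
    """Shuffle the indices once, then cut them into k disjoint folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Fold i takes every k-th shuffled index starting at position i.
    return [indices[i::k] for i in range(k)]

def fit_threshold(xs, ys):
    """Toy model (an assumption for this sketch): midpoint of the class means."""
    mean0 = mean(x for x, y in zip(xs, ys) if y == 0)
    mean1 = mean(x for x, y in zip(xs, ys) if y == 1)
    return (mean0 + mean1) / 2

def accuracy(threshold, xs, ys):
    """Fraction of samples where 'x > threshold' matches the label."""
    return mean(1.0 if (x > threshold) == (y == 1) else 0.0
                for x, y in zip(xs, ys))

def cross_validate(xs, ys, k=5, seed=0):
    folds = k_fold_indices(len(xs), k, seed)
    scores = []
    for test_idx in folds:
        held_out = set(test_idx)
        train_idx = [i for i in range(len(xs)) if i not in held_out]
        # A fresh model per fold, fitted without the held-out samples.
        model = fit_threshold([xs[i] for i in train_idx],
                              [ys[i] for i in train_idx])
        scores.append(accuracy(model,
                               [xs[i] for i in test_idx],
                               [ys[i] for i in test_idx]))
    return folds, scores

# Synthetic two-class data: 100 samples, class 1 is shifted to the right.
rng = random.Random(42)
ys = [i % 2 for i in range(100)]
xs = [rng.gauss(2.0 * y, 1.0) for y in ys]

folds, scores = cross_validate(xs, ys, k=5)
print("fold sizes:", [len(f) for f in folds])
print("per-fold accuracy:", scores)
print("mean accuracy:", mean(scores))
```

Every sample lands in exactly one test fold, so each of the 5 accuracies comes from a model that never trained on the samples it is scored on.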
CodePudding user response:
TL;DR: Think of your 'test data' more like 'validation data', which you hope represents truly unseen test data. Ideally, if the model performs well on many different validation datasets, it will work well when applied to real-life test data which wasn't used in the training-validation process.
This confusion is justified. You are correct.
This is where the terminology training data, validation data and test data can make things clearer. Models are trained on training data. This is the data directly seen by the model as it goes through the process of updating its parameters and learning. Validation data is data that we use to validate how well the model has actually learned. It is not directly seen by the model, and we use it to judge things like underfitting or overfitting. It is assumed that the validation data is a good representation of the test data. Test data is what we will end up applying our model to in the real world; it has never been seen in any way by the model.
Test and validation data are often used interchangeably, with most people just using training and test terminology.
An example: If you are building a cat detector, you collect images of cats and split these images into training and validation sets. You assume the validation set is an accurate representation of the kinds of cat images people will use your model on in the real world. You train your model on the training data, validate how well it has learned on the validation data, and once you think it has learned well you deploy the model. People will use it on their own images to detect cats. These images are the true test data, which have never been seen by the model, but hopefully your validation set was a good indicator of how your model will perform on these images.
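As a rough sketch of the split described in the cat example, the snippet below divides a list of collected images into training and validation sets using only the standard library. The file names and the helper `train_validation_split` are hypothetical; the true test data (other people's images) is by definition not in this list at all.

```python
import random

def train_validation_split(items, validation_fraction=0.2, seed=0):
    """Shuffle a copy of the items and hold out a fraction for validation."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_validation = round(len(shuffled) * validation_fraction)
    # The first n_validation shuffled items are held out; the rest are for training.
    return shuffled[n_validation:], shuffled[:n_validation]

# Hypothetical file names standing in for the collected cat images.
collected_images = [f"cat_{i:03d}.jpg" for i in range(100)]

train_set, validation_set = train_validation_split(collected_images)
print(len(train_set), "training images,", len(validation_set), "validation images")
```

The model is fitted only on `train_set`; `validation_set` is used just to judge how well it has learned before deployment.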
K-fold cross validation is best used when your validation set may be small, or when you are unsure of how well it represents the test data (e.g. if there are only ginger cats in your validation set, you may not notice that your model fails on other cats in the test data, so you would like to mix the validation set up). By performing k-fold cross validation you can validate your model more times, with different choices of validation set, which hopefully will give a better indication of your model's generalizability.