I have always had some doubts about grid search. In particular, I am not sure how I should conduct a grid search for hyperparameter tuning and then report the model’s generalization performance in a scientific paper.
There are three possible ways:
1) Split the data into 10 folds. For each candidate set of hyperparameters, repeat the following 10 times: pick 9 folds as training data, train a model with that set of hyperparameters on the training data, test it on the remaining fold, and record the test accuracy. Finally, pick the set of hyperparameters achieving the highest mean test accuracy and report that mean test accuracy.
This procedure is what scikit-learn’s GridSearchCV implements, and it is also used in a Cross Validated post (note that the post claims to implement nested cross-validation, but I think it does not). The problem is that the accuracy you finally report is not based on unseen data (as pointed out in https://datascience.stackexchange.com/questions/21877/how-to-use-the-output-of-gridsearch, see the “This goes against the principles of not using test data!!” part). The procedure is also called flawed in another Cross Validated post: https://stats.stackexchange.com/questions/224287/cross-validation-misuse-reporting-performance-for-the-best-hyperparameter-value?noredirect=1&lq=1
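Here is a minimal sketch of 1) in scikit-learn; the SVC estimator, the parameter grid, and the iris data are placeholders I chose for illustration, not something prescribed by the procedure itself:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
param_grid = {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]}

# 10-fold CV over the grid; best_score_ is the highest mean test-fold accuracy.
search = GridSearchCV(SVC(), param_grid, cv=10, scoring="accuracy")
search.fit(X, y)

# This is the number way 1) would report -- note that the same folds were used
# to select the hyperparameters, so it is an optimistically biased estimate.
print(search.best_params_, search.best_score_)
```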
2) Separate out a subset of data as test data. Use the remaining data to pick the best set of hyperparameters as in 1). Then, retrain the model with the best set of hyperparameters on all the data except the test data. Finally, report the accuracy of the retrained model on the test data.
The problem with this method is that the test set is only a small portion of the whole dataset and is used only once, so the reported test accuracy may have high variance.
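A minimal sketch of 2), again with a placeholder estimator, grid, and dataset, and an arbitrary 20% test split of my choosing:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_dev, X_test, y_dev, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

param_grid = {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]}

# Hyperparameter selection uses only the development portion.
search = GridSearchCV(SVC(), param_grid, cv=10, scoring="accuracy")
search.fit(X_dev, y_dev)

# refit=True (the default) retrains on all of X_dev with the best hyperparameters,
# so the accuracy below is measured on data the final model never saw.
print(search.best_params_, search.score(X_test, y_test))
```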
3) Use nested cross-validation: http://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html. Split the data into 10 folds. Repeat the following 10 times: pick 9 folds as a dataset, apply 1) to pick the best set of hyperparameters, retrain a new model on this dataset with the best set of hyperparameters, and test the retrained model on the remaining fold to obtain an accuracy. Finally, report the mean accuracy across the 10 iterations. This involves two cross-validations, often called the inner cross-validation (for picking the best hyperparameters) and the outer cross-validation (for reporting the mean accuracy as the generalization error).
The problem with this method is that each outer iteration may end up with a different set of hyperparameters, so the reported mean accuracy is not guaranteed to average over models sharing the same set of hyperparameters.
Based on my reading of the Cross Validated forum, 3) is the most correct way. If you end up with the same best set of hyperparameters across all 10 outer iterations, that’s perfect: you can simply retrain a final model on the whole dataset with that consistent set of hyperparameters (if you use the scikit-learn script linked above for nested cross-validation, you get the final model by setting refit=True: https://stats.stackexchange.com/a/281824/80635). If you get models with different hyperparameters each time, your model training is not stable, and picking any one of the 10 models is not fair. In that case, either collect more data or debug your model training process until the models become stable across the 10 iterations.
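A minimal sketch of 3), closely following the scikit-learn example linked above; the estimator and grid are again placeholders:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
param_grid = {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]}

inner_cv = KFold(n_splits=10, shuffle=True, random_state=0)  # picks hyperparameters
outer_cv = KFold(n_splits=10, shuffle=True, random_state=1)  # estimates generalization

clf = GridSearchCV(SVC(), param_grid, cv=inner_cv, scoring="accuracy")
nested_scores = cross_val_score(clf, X, y, cv=outer_cv, scoring="accuracy")

# Each outer fold may have selected different hyperparameters; the mean below
# estimates the performance of the whole "grid search + refit" procedure.
print(nested_scores.mean(), nested_scores.std())

# For the final model, run the grid search once on all the data
# (refit=True is the default), as in the stats.stackexchange answer linked above.
final_model = clf.fit(X, y).best_estimator_
```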
Some useful links to help understand:
from https://stats.stackexchange.com/a/72324/80635:
https://www.elderresearch.com/blog/nested-cross-validation
https://stats.stackexchange.com/questions/65128/nested-cross-validation-for-model-selection
Now, I’d like to share the procedure I described in one of my paper submissions. It differs from 3) in that, within each outer cross-validation iteration, hyperparameter selection is based on a single shot of training and validation (i.e., there is no inner cross-validation).
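A sketch of that variant, under assumptions of my own (an 80/20 training/validation split inside each outer fold, and the same placeholder SVC and grid as before):

```python
import numpy as np
from itertools import product
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
grid = list(product([0.1, 1, 10], [0.01, 0.1, 1]))  # candidate (C, gamma) pairs

outer_scores = []
for train_idx, test_idx in KFold(n_splits=10, shuffle=True, random_state=0).split(X):
    X_outer, y_outer = X[train_idx], y[train_idx]
    # One-shot hyperparameter selection on a single train/validation split.
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_outer, y_outer, test_size=0.2, random_state=0)
    val_acc = [SVC(C=C, gamma=g).fit(X_tr, y_tr).score(X_val, y_val) for C, g in grid]
    best_C, best_gamma = grid[int(np.argmax(val_acc))]
    # Retrain on the full 9 folds with the selected setting, test on the held-out fold.
    model = SVC(C=best_C, gamma=best_gamma).fit(X_outer, y_outer)
    outer_scores.append(model.score(X[test_idx], y[test_idx]))

print(np.mean(outer_scores), np.std(outer_scores))
```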
I have also seen many papers simply split the whole dataset into training/validation/test sets, which removes the outer cross-validation as well. I think people won’t complain about this if the dataset is very large.
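For completeness, a sketch of that simple training/validation/test protocol using PredefinedSplit, so that GridSearchCV evaluates each setting on one fixed validation set; the split sizes, estimator, and grid are my own choices for illustration:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, PredefinedSplit, train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
# Roughly 60/20/20 train/validation/test.
X_dev, X_test, y_dev, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X_dev, y_dev, test_size=0.25, random_state=0)

# -1 marks rows that are always in training; 0 marks the single validation fold.
test_fold = np.concatenate([np.full(len(X_tr), -1), np.zeros(len(X_val), dtype=int)])
X_grid, y_grid = np.concatenate([X_tr, X_val]), np.concatenate([y_tr, y_val])

search = GridSearchCV(SVC(), {"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]},
                      cv=PredefinedSplit(test_fold), scoring="accuracy")
search.fit(X_grid, y_grid)  # refit=True retrains on training + validation data

print(search.best_params_, search.score(X_test, y_test))
```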
Lastly, I want to share a famous paper, in case you have not read it: Random Search for Hyper-Parameter Optimization, which claims that random search can do as well as grid search on many practical problems. An illustration can be found here: https://stats.stackexchange.com/questions/160479/practical-hyperparameter-optimization-random-vs-grid-search
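In scikit-learn, random search is a near drop-in replacement for grid search via RandomizedSearchCV; the distributions and the budget of 20 samples below are arbitrary choices of mine, and the best_score_ it reports carries the same optimistic bias discussed for 1):

```python
from scipy.stats import loguniform  # available in scipy >= 1.4
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
param_distributions = {"C": loguniform(1e-2, 1e2), "gamma": loguniform(1e-3, 1e1)}

# n_iter fixes the budget: 20 randomly sampled settings instead of an exhaustive grid.
search = RandomizedSearchCV(SVC(), param_distributions, n_iter=20, cv=10,
                            scoring="accuracy", random_state=0)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```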