Guides on choosing the model architecture, the optimizer, and the batch size.
The goal of tuning batch size is to saturate the GPUs, which can be monitored by:
training throughput = (# examples processed per second)
time per step = (batch size) / (training throughput)
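A minimal sketch of how these two quantities could be measured when trying a candidate batch size (the train_step function and batch iterator are assumed to exist elsewhere; a few warm-up steps are skipped so compilation does not skew the numbers):

    import time

    def measure_throughput(train_step, batches, batch_size, warmup_steps=5, measure_steps=20):
        """Return (examples_per_second, seconds_per_step); assumes enough batches are available."""
        start = None
        steps_done = 0
        for i, batch in enumerate(batches):
            if i == warmup_steps:
                start = time.perf_counter()   # start timing only after warm-up / compilation steps
            train_step(batch)
            if i >= warmup_steps:
                steps_done += 1
            if steps_done == measure_steps:
                break
        elapsed = time.perf_counter() - start
        seconds_per_step = elapsed / steps_done
        examples_per_second = batch_size / seconds_per_step
        return examples_per_second, seconds_per_step
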
P.S. When the batch size is changed, most hyper-parameters need to be re-tuned. Among them, the learning rate and the regularization terms are the most important.
A scientific approach to improving model performance
Design the experiment
Scientific hyper-parameters
The experiment is aimed at exploring the effect of the scientific hyper-parameters.
Nuisance hyper-parameters
To compare different approaches, the nuisance hyper-parameters need to be tuned, and the best trials are then compared.
Fixed hyper-parameters
The hyper-parameters that are held fixed to reduce the number of trials needed.
Exploring the hyper-parameter search space
Use Bayesian optimization or quasi-random search, and choose the search boundaries carefully (see the sketch below).
Automate the plotting so that enough plots are generated for inspection.
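A minimal sketch of quasi-random search over a hand-picked search box, assuming SciPy >= 1.7 is available (scipy.stats.qmc); the parameter names and bounds below are illustrative assumptions, not prescribed values:

    import numpy as np
    from scipy.stats import qmc

    # 16 quasi-random (Sobol) points in the unit square; powers of two keep the sequence balanced.
    sampler = qmc.Sobol(d=2, scramble=True, seed=0)
    unit_points = sampler.random(n=16)

    # Search boundaries chosen deliberately: both parameters are searched on a log scale,
    # so the bounds are given in log10 space.
    l_bounds = [np.log10(1e-4), np.log10(1e-6)]   # learning_rate, weight_decay
    u_bounds = [np.log10(1e-1), np.log10(1e-2)]
    scaled = qmc.scale(unit_points, l_bounds, u_bounds)

    trials = [{"learning_rate": 10.0 ** p[0], "weight_decay": 10.0 ** p[1]}
              for p in scaled]
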
Determining the number of steps for each training run
Deciding how long to train when training is not compute-bound
Deciding how long to train when training is compute-bound
"Round 1: Shorter runs to find good model and optimizer hyperparameters."
"Round 2: Very few long runs on good hyperparameter points to get the final model."
Additional guidance for the training pipeline
Optimizing the input pipeline
Saving checkpoints and retrospectively selecting the best checkpoint
Keep the best k checkpoints seen during training.
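A minimal sketch of retrospective checkpoint selection (the class name and the convention that a higher validation metric is better are assumptions): only the k best checkpoints seen so far are kept on disk.

    import heapq
    import os

    class BestKCheckpoints:
        """Keep the k best checkpoints by validation metric (higher is better)."""

        def __init__(self, k):
            self.k = k
            self._heap = []   # min-heap of (metric, path); worst kept checkpoint sits on top

        def update(self, metric, path):
            """Register a newly saved checkpoint; delete the worst one if more than k are kept."""
            heapq.heappush(self._heap, (metric, path))
            if len(self._heap) > self.k:
                _, worst_path = heapq.heappop(self._heap)
                if os.path.exists(worst_path):
                    os.remove(worst_path)
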
FAQs
How should Adam’s hyper-parameters be tuned?
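A hedged sketch of a search space for Adam: the base learning rate is always tuned, while beta_1, epsilon and beta_2 are usually only added when the trial budget is large; the ranges below are illustrative assumptions.

    # Always tune the (base) learning rate; widen the space only with a larger trial budget.
    adam_search_space = {
        "learning_rate": ("log_uniform", 1e-5, 1e-1),
        # With more trials, consider also searching over:
        # "beta_1":  ("uniform", 0.8, 0.99),
        # "epsilon": ("log_uniform", 1e-10, 1e-3),
        # "beta_2":  ("uniform", 0.9, 0.999),
    }
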
Why use quasi-random search instead of more sophisticated black box optimization algorithms during the exploration phase of tuning?
Unstable training
Learning rate warmup
"Our goal is to find the shortest number of warmup_steps that allows us to access peak learning rates that are much higher than unstable_base_learning_rate." The default is 10x unstable_base_learning_rate.
Gradient clipping
"Choose a gradient clipping threshold based on the 90th percentile of gradient norms."
Issue with Batch Normalization in residual blocks: use x + f(Norm(x)). Norm(x + f(x)) is known to cause issues.
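A minimal sketch (PyTorch as an assumption) of the recommended placement, x + f(Norm(x)), inside a residual block:

    import torch.nn as nn

    class PreNormResidualBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.norm = nn.BatchNorm1d(dim)   # normalization applied inside the residual branch
            self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

        def forward(self, x):
            return x + self.f(self.norm(x))   # x + f(Norm(x)), not Norm(x + f(x))
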