We’ve discovered that the gradient noise scale, a simple statistical metric, predicts the parallelizability of neural network training on a wide range of tasks. Since complex tasks tend to have noisier gradients, increasingly large batch sizes are likely to become useful in the future, removing one potential limit to further growth of AI systems. More broadly, these results show that neural network training need not be considered a mysterious art, but can be rigorized and systematized.
In an increasing number of domains it has been demonstrated that deep learning models can be trained using relatively large batch sizes without sacrificing data efficiency. However, the limits of this massive data parallelism seem to differ from domain to domain, ranging from batches of tens of thousands in ImageNet to batches of millions in RL agents that play the game Dota 2. To our knowledge there is limited conceptual understanding of why these limits to batch size differ or how we might choose the correct batch size in a new domain. In this paper, we demonstrate that a simple and easy-to-measure statistic called the gradient noise scale predicts the largest useful batch size across many domains and applications, including a number of supervised learning datasets (MNIST, SVHN, CIFAR-10, ImageNet, Billion Word), reinforcement learning domains (Atari and Dota), and even generative model training (autoencoders on SVHN). We find that the noise scale increases as the loss decreases over a training run and depends on the model size primarily through improved model performance. Our empirically motivated theory also describes the tradeoff between compute-efficiency and time-efficiency, and provides a rough model of the benefits of adaptive batch-size training.
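To make "easy-to-measure" concrete, here is a minimal sketch of how such a statistic might be estimated. It assumes the simplified definition given in the paper, the trace of the per-example gradient covariance divided by the squared norm of the true gradient, which the excerpt above does not spell out. The function name `simple_noise_scale` and the synthetic gradients are hypothetical, and the bias correction is a plain sample-statistics choice rather than the paper's two-batch-size estimator.

```python
import numpy as np

def simple_noise_scale(per_example_grads):
    """Estimate tr(Sigma) / |G|^2 from an (n_examples, n_params) array.

    Assumption: each row is the gradient of the loss on one example,
    so rows are i.i.d. draws with mean G and covariance Sigma.
    """
    g = np.asarray(per_example_grads, dtype=float)
    n = g.shape[0]
    mean_grad = g.mean(axis=0)
    # Sum of per-coordinate sample variances is an unbiased tr(Sigma).
    trace_cov = g.var(axis=0, ddof=1).sum()
    # |mean_grad|^2 overestimates |G|^2 by tr(Sigma)/n on average,
    # so subtract that term to remove the bias.
    grad_sq = mean_grad @ mean_grad - trace_cov / n
    if grad_sq <= 0:
        # The signal is lost in the noise at this sample size.
        return float("inf")
    return trace_cov / grad_sq

# Synthetic check: true gradient of all ones in 10 dimensions, with
# independent noise of standard deviation 2 on every coordinate.
rng = np.random.default_rng(0)
grads = 1.0 + 2.0 * rng.standard_normal((50_000, 10))
print(simple_noise_scale(grads))
```

In the synthetic case the true value is tr(Σ)/|G|² = (10 × 4)/10 = 4, so the printed estimate should land close to that. On a real model the per-example gradients would come from the training framework, and the estimate would typically be averaged over many batches.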
The gradient noise scale (appropriately averaged over training) explains the vast majority (R² = 80%) of the variation in critical batch size over a range of tasks spanning six orders of magnitude. Batch sizes are measured in either number of images, tokens (for language models), or observations (for games).
…We have found that by measuring the gradient noise scale, a simple statistic that quantifies the signal-to-noise ratio of the network gradients, we can approximately predict the maximum useful batch size. Heuristically, the noise scale measures the variation in the data as seen by the model (at a given stage in training). When the noise scale is small, looking at a lot of data in parallel quickly becomes redundant, whereas when it is large, we can still learn a lot from huge batches of data…We’ve found it helpful to visualize the results of these experiments in terms of a tradeoff between wall time for training and total bulk compute that we use to do the training (proportional to dollar cost). At very small batch sizes, doubling the batch allows us to train in half the time without using extra compute (we run twice as many chips for half as long). At very large batch sizes, more parallelization doesn’t lead to faster training. There is a “bend” in the curve in the middle, and the gradient noise scale predicts where that bend occurs.
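The shape of this tradeoff curve can be sketched numerically. The code below assumes the idealized relation from the paper, in which the number of optimization steps scales as S = S_min · (1 + B_noise / B) and the number of examples processed is E = B · S. That functional form is not stated in the excerpt above, and the function name and parameter values are hypothetical.

```python
def training_cost(batch_size, noise_scale, min_steps=1.0):
    """Return (steps, examples) needed to reach a fixed loss.

    Assumed model: steps = min_steps * (1 + noise_scale / batch_size).
    Steps stand in for wall time, examples for total compute.
    """
    steps = min_steps * (1.0 + noise_scale / batch_size)
    examples = batch_size * steps
    return steps, examples

noise_scale = 1_000.0   # hypothetical gradient noise scale
min_steps = 100.0       # hypothetical steps needed at infinite batch size
min_examples = min_steps * noise_scale  # examples needed at tiny batch size

for batch_size in (1, 10, 100, 1_000, 10_000, 100_000):
    steps, examples = training_cost(batch_size, noise_scale, min_steps)
    # Ratios show how far each setting is from the best achievable
    # time (steps) and the best achievable compute (examples).
    print(batch_size, steps / min_steps, examples / min_examples)
```

Under this model, doubling a batch far below the noise scale roughly halves the step count at almost no extra compute, while a batch far above it buys almost no further reduction in steps. At a batch size equal to the noise scale, both steps and examples sit at twice their minimum, which is the "bend".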
Increasing parallelism makes it possible to train more complex models in a reasonable amount of time. We find that a Pareto frontier chart is the most intuitive way to visualize comparisons between algorithms and scales.
…more powerful models have a higher gradient noise scale, but only because they achieve a lower loss. Thus, there’s some evidence that the increasing noise scale over training isn’t just an artifact of convergence, but occurs because the model gets better. If this is true, then we expect future, more powerful models to have higher noise scale and therefore be more parallelizable. Second, tasks that are subjectively more difficult are also more amenable to parallelization…we have evidence that more difficult tasks and more powerful models on the same task will allow for more radical data-parallelism than we have seen to date, providing a key driver for the continued fast exponential growth in training compute.