Grokking, the unusual phenomenon in which generalization on algorithmic datasets happens long after the training data is overfit, has remained elusive.
We aim to understand grokking by analyzing the loss landscapes of neural networks, identifying the mismatch between training and test losses as the cause of grokking.
We refer to this as the "LU mechanism" because, plotted against model weight norm, the training and test losses typically resemble an "L" and a "U", respectively (a picture based on the Goldilocks zone of small weight norms). This simple mechanism can nicely explain many aspects of grokking: data size dependence, weight decay dependence, the emergence of representations, etc.
Guided by this intuitive picture, we are able to induce grokking (Omnigrok) on tasks involving images, language, and molecules.
In the reverse direction, we are able to eliminate grokking for algorithmic datasets.
We attribute the dramatic nature of grokking for algorithmic datasets to representation learning.
Figure 1: (a) w: L2 norm of model weights. Generalizing solutions (green stars) are concentrated around a sphere in weight space where w ≈ wc (green). Overfitting solutions (orange) populate the w > wc region.
(b) The training loss (orange) and test loss (gray) have the shape of L and U, respectively. Their mismatch in the w > wc region leads to fast-slow dynamics, resulting in grokking.
LU mechanism: Although the loss landscapes of neural networks are nonlinear, Fort & Scherlis (2019) reveal a simple landscape picture: there is a spherical shell in weight space (the "Goldilocks" zone) where generalization is better than outside this zone. We illustrate the Goldilocks zone as the green area with average radius wc in Figure 1a; the green stars are the generalizing solutions.
The test loss is thus higher both when w > wc and when w < wc, forming a U-shape against w in Figure 1b (gray curve). By contrast, the training loss has an L-shape against weight norm. There are many solutions that overfit the training data for w > wc, but high training losses are incurred for w < wc. This corresponds to the L-shaped curve seen in Figure 1b (orange curve, no regularization).
In summary, the (reduced) training loss and test loss are L-shaped and U-shaped against weight norm, respectively, which we will refer to as the LU mechanism throughout the paper.
Grokking dynamics: We identify the "LU mechanism" as the cause of grokking.
If the weight norm is initialized to be large (e.g., the black square in the w > wc region), the model first quickly moves to a nearby overfitting solution by minimizing the training loss. Without any regularization, the model will stay where it is, because the gradient of the training loss is almost zero along the valley of overfitting solutions, so generalization does not happen.
Fortunately, there are usually explicit and/or implicit regularizations that can drive the weight vector towards the Goldilocks zone w ≈ wc.
When the regularization magnitude is non-zero but small, the radial motion can be arbitrarily slow. If weight decay is the only source of regularization, and the training loss is negligible after overfitting, then weight decay γ gives w(t) ∝ exp(−γt) w0 for w0 > wc, so it takes time t ≈ ln(w0/wc)/γ ∝ γ⁻¹ to generalize. A small γ results in a huge generalization delay (i.e., grokking).
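The estimate t ≈ ln(w0/wc)/γ above can be checked with a back-of-the-envelope simulation of the radial decay; the values of w0, wc, and γ below are illustrative, not taken from the paper's experiments.

```python
import numpy as np

# Sketch: under pure weight decay the weight norm shrinks as w(t) = w0 * exp(-gamma * t),
# so the time to reach the Goldilocks radius wc scales as 1/gamma.
# w0, wc, and the gamma values are illustrative.
w0, wc = 10.0, 1.0

def time_to_generalize(gamma, dt=1e-3):
    """Integrate dw/dt = -gamma * w until w crosses wc; return elapsed time."""
    w, t = w0, 0.0
    while w > wc:
        w -= gamma * w * dt
        t += dt
    return t

for gamma in [0.01, 0.1, 1.0]:
    t = time_to_generalize(gamma)
    print(f"gamma={gamma:<5} t={t:8.2f}  t*gamma={t * gamma:.3f}")
```

Across the sweep, t·γ stays roughly constant at ln(w0/wc) = ln(10) ≈ 2.30, confirming t ∝ γ⁻¹: a 100× smaller weight decay means a 100× longer delay before generalization.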
The dependence on regularization magnitude is illustrated in Figure 1b: no generalization happens for γ = 0, small γ leads to slow generalization (grokking), and large γ leads to faster generalization. The above analysis only applies to large initializations w > wc. Small initializations w < wc can always generalize fast, regardless of regularization (though w should not be so small that optimization is harmed).
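In practice, the initialization scale can be controlled by multiplying a standard initialization by a factor α, which places the model on either side of wc at the start of training. A minimal sketch, where the layer shapes and α values are illustrative:

```python
import numpy as np

# Sketch: scale a standard (Kaiming-style) Gaussian initialization by alpha
# to control the initial weight norm. alpha > 1 pushes the model into the
# w > wc regime. Shapes and alpha values are illustrative.
rng = np.random.default_rng(0)

def scaled_init(shape, alpha):
    fan_in = shape[1]
    return alpha * rng.standard_normal(shape) / np.sqrt(fan_in)

W_std = scaled_init((256, 128), alpha=1.0)   # standard initialization
W_big = scaled_init((256, 128), alpha=8.0)   # large-norm initialization

ratio = np.linalg.norm(W_big) / np.linalg.norm(W_std)
print(f"norm ratio: {ratio:.2f}")            # concentrates near alpha = 8
```

Because the norm of a large Gaussian matrix concentrates tightly, the ratio of weight norms tracks α almost exactly, so α directly sets where on the L/U curves training begins.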
§5. Representation Is Key to Grokking: In §4, we showed that increasing initialization scales can make grokking happen for standard ML tasks. However, this seems somewhat artificial and does not explain why standard initialization leads to grokking on algorithmic datasets but not on standard ML datasets, say MNIST.
The key difference is how much the task relies on representation learning. For the MNIST dataset, the quality of the representation determines whether the test accuracy is 95% or 100%; by contrast, in algorithmic datasets, it determines whether test accuracy is at random guessing (bad representation) or 100% (good representation). Overfitting under a bad representation therefore has a more dramatic effect on algorithmic datasets: the model weights increase quickly during overfitting while test accuracy remains low. During overfitting, the model weight norm is much larger than at initialization, but it then drops below the initialization norm when the model generalizes, as shown in Figure 7a and also observed by Nanda et al. (2023). As a byproduct, we are able to eliminate grokking by constraining the model to a small weight norm sphere, as shown in Figure 7b.
Figure 7: Training 1L transformer on modular addition (p = 113).
(a) Weight norm, train accuracy, and test accuracy over time, initialized and trained normally. Weight norm first increases, and is highest during the period of overfitting, but then drops to become lower than initial weight norm when the model generalizes.
(b) Constrained optimization at constant weight norm (α = 0.8) largely eliminates grokking, with test and train accuracy improving concurrently.
The above picture is supported by a transformer experiment. Figure 7a shows how the model weight norm changes over time: there is an initial increase, which peaks during overfitting, and the norm then drops below its initialization value during the period of generalization.
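The constrained optimization of Figure 7b amounts to projected gradient descent: after every gradient step, rescale the weights back onto a fixed-norm sphere. Below is a minimal numpy sketch on a toy least-squares loss; the loss, shapes, radius, and step size are illustrative stand-ins for the transformer setup.

```python
import numpy as np

# Sketch of constrained optimization at constant weight norm:
# take a gradient step, then project the weights back onto the sphere ||w|| = r.
# The quadratic loss and all hyperparameters are illustrative.
rng = np.random.default_rng(0)
A = rng.standard_normal((30, 10))
b = rng.standard_normal(30)

def loss(w):
    return 0.5 * np.mean((A @ w - b) ** 2)

r = 0.8                            # illustrative fixed weight-norm radius
w = rng.standard_normal(10)
w *= r / np.linalg.norm(w)         # start on the sphere

history = [loss(w)]
for _ in range(200):
    grad = A.T @ (A @ w - b) / len(b)
    w -= 0.05 * grad               # unconstrained gradient step
    w *= r / np.linalg.norm(w)     # project back onto the sphere ||w|| = r
    history.append(loss(w))

print(f"loss: {history[0]:.3f} -> {history[-1]:.3f}, ||w|| = {np.linalg.norm(w):.3f}")
```

The loss decreases while the weight norm stays exactly at r throughout training, which removes the slow radial phase that the LU mechanism identifies as the source of the grokking delay.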
Figure 10: Time to generalize as a function of weight decay.
We investigate to what extent the relation t ∝ γ⁻¹ holds, where t is the number of training steps needed for the model to generalize and γ is the AdamW weight decay. When a lower weight decay is used, models spend longer in the period of overfitting before eventually generalizing.
We show the generalization time t as a function of γ in (a, b) and full training curves for these runs in (c, d).