Recent research on the grokking phenomenon has illuminated the intricacies of neural networks' training dynamics and their generalization behaviors. Grokking refers to a sharp rise in the network's generalization accuracy on the test set, which occurs long after an extended overfitting phase during which the network perfectly fits the training set.
While existing research primarily focuses on shallow networks such as 2-layer MLPs and 1-layer Transformers, we explore grokking in deep networks (e.g., a 12-layer MLP), following Liu et al. (2022). We empirically replicate the phenomenon and find that deep neural networks can be more susceptible to grokking than their shallower counterparts.
Meanwhile, we observe an intriguing multi-stage generalization phenomenon when increasing the depth of the MLP model, where the test accuracy exhibits a secondary surge that is scarcely seen in shallow models. We further uncover compelling correspondences between the decrease in feature ranks and the phase transition from the overfitting stage to the generalization stage during grokking. Additionally, we find that the multi-stage generalization phenomenon often aligns with a double-descent pattern in feature ranks.
These observations suggest that internal feature rank could serve as a more promising indicator of the model's generalization behavior than the weight norm. We believe our work is the first to dive into grokking in deep neural networks and to investigate the relationship between feature rank and generalization performance.
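The feature rank discussed above can be estimated from the singular values of a layer's activation matrix. As a rough illustration (the tolerance rule and the function name `feature_rank` are our own assumptions, not details specified in the text), one might count the singular values that exceed a small fraction of the largest one:

```python
import numpy as np

def feature_rank(features: np.ndarray, tol_ratio: float = 1e-3) -> int:
    """Estimate the numerical rank of an (n_samples, n_features)
    activation matrix. A singular value counts as non-negligible when it
    exceeds `tol_ratio` times the largest singular value; this threshold
    is an illustrative choice, not the one used in the paper."""
    s = np.linalg.svd(features, compute_uv=False)
    if s.size == 0 or s[0] == 0:
        return 0
    return int(np.sum(s > tol_ratio * s[0]))

# A rank-1 "collapsed" feature matrix: every row is a multiple of one vector.
low_rank = np.outer(np.arange(1.0, 5.0), np.arange(1.0, 7.0))
print(feature_rank(low_rank))  # → 1
```

Tracking this quantity per layer over training is one way to observe the rank decrease that accompanies the transition from overfitting to generalization.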
…Our experiments reveal that deeper MLP networks are more prone to grokking than their shallower counterparts, which implies that the more sophisticated features learned by deeper layers do not necessarily help with generalization.
As illustrated in Figure 2a–c, as the network depth increases, the rise of both training and test accuracy is notably delayed.
Multi-stage generalization: We also observe an intriguing multi-stage generalization behavior in deep MLP networks, wherein the test accuracy demonstrates several distinct periods of improvement marked by sharp transitions. As shown in Figure 2d–f, when trained on 7,000 data points, the 4-layer MLP network exhibits no substantial grokking, while both the 8-layer and 12-layer MLP networks exhibit a notable second surge of test accuracy. Meanwhile, compared to the 8-layer model, both the first and second stages of generalization are delayed in the 12-layer model, and the test accuracy obtained from the first-stage generalization is substantially lower than that of the 8-layer network.
Figure 2: Generalization behaviors of MLP models of various depths.
(Top) Grokking: On a small training set (|Dtrain| = 5,000), the improvement of both training (blue curve) and test (orange curve) accuracy is delayed as the depth of the model increases.
(Bottom) Multi-stage Generalization: On a larger training set (|Dtrain| = 7,000), the deep (8- and 12-layer) MLP models exhibit a two-stage generalization, where the test accuracy experiences a second surge. The two-stage generalization phenomenon also correlates with a double descent of the feature rank (Figure 1c).
…3. Dependence on Weight Decay: …As shown in Figure 6, on 2,000 training data points, the smallest regularization (γ = 0.005) fails to generalize; with γ = 0.01, the model demonstrates delayed generalization (grokking); with the largest regularization (γ = 0.05), the model exhibits a two-stage generalization.
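The role of the coefficient γ can be made concrete with the standard decoupled weight-decay update, w ← w − lr·grad − lr·γ·w, applied at every optimization step. A minimal sketch, assuming plain SGD and an illustrative learning rate (the text does not specify the optimizer details):

```python
import numpy as np

def sgd_weight_decay_step(w, grad, lr=0.1, gamma=0.01):
    """One SGD step with decoupled weight decay:
    w <- w - lr * grad - lr * gamma * w.
    A larger gamma shrinks the weights toward zero more aggressively;
    gamma is the regularization strength swept in the experiments
    (gamma in {0.005, 0.01, 0.05})."""
    return w - lr * grad - lr * gamma * w

w = np.ones(3)
zero_grad = np.zeros(3)
# With a zero gradient, weight decay alone contracts the weights,
# and stronger gamma contracts them faster:
for gamma in (0.005, 0.01, 0.05):
    print(gamma, sgd_weight_decay_step(w, zero_grad, gamma=gamma))
```

This illustrates why the swept γ values lead to qualitatively different regimes: too little decay leaves the overfitted solution untouched, while stronger decay keeps pushing the weights toward simpler solutions during training.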