“Critical Data Size of Language Models from a Grokking Perspective”, Xuekai Zhu, Yao Fu, Bowen Zhou, Zhouhan Lin (2024-01-19):

We explore the critical data size in language models, a threshold that marks a fundamental shift from quick memorization to slow generalization. We formalize the phase transition under the grokking configuration as the Data Efficiency Hypothesis and identify data-insufficiency, data-sufficiency, and data-surplus regimes in language model training dynamics.

We develop a grokking configuration that stably reproduces grokking in simple language models by rescaling the initialization and weight decay.
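A minimal sketch of the two ingredients named above, with illustrative names and values (the factor `alpha` and the update rule are assumptions, not the paper's exact recipe): scale up the randomly initialized weights, and apply weight decay in the optimizer step.

```python
import numpy as np

def rescale_init(params, alpha=8.0):
    """Scale up freshly initialized weights by alpha; a larger initialization
    delays generalization, one ingredient of the grokking configuration."""
    return [alpha * w for w in params]

def sgd_step(params, grads, lr=1e-3, weight_decay=1e-2):
    """Plain SGD with a decoupled weight-decay term (the other ingredient),
    which steadily shrinks the weight norm during training."""
    return [w - lr * (g + weight_decay * w) for w, g in zip(params, grads)]

rng = np.random.default_rng(0)
params = [rng.standard_normal((4, 4)) * 0.02]   # small standard init
params = rescale_init(params, alpha=8.0)        # rescaled "grokking" init
params = sgd_step(params, [np.zeros((4, 4))])   # decay alone shrinks the norm
```

With zero gradients, each step multiplies the weights by `1 - lr * weight_decay`, so the decay term alone drives the slow norm decrease associated with the memorization-to-generalization transition.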

We show that generalization occurs only once language models reach the critical data size.

We analyze grokking both sample-wise and model-wise, verifying the proposed data efficiency hypothesis. Our experiments reveal smoother phase transitions occurring at the critical data size for language datasets. As the model size increases, this critical point also grows, indicating that larger models require more data.
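The step-wise analysis can be operationalized as follows (an illustrative helper, not the paper's code): record the first step at which each accuracy curve crosses a threshold; applied to the training curve this yields the memorization step, and applied to the test curve, the generalization step.

```python
def steps_to_threshold(acc_curve, threshold=0.99):
    """Return the first training step at which accuracy reaches `threshold`,
    or None if it never does."""
    for step, acc in enumerate(acc_curve):
        if acc >= threshold:
            return step
    return None

# Toy curves: memorization is fast, generalization is delayed (grokking).
train_acc = [0.2, 0.6, 0.99, 1.0, 1.0, 1.0, 1.0, 1.0]
test_acc  = [0.1, 0.2, 0.3, 0.3, 0.4, 0.7, 0.99, 1.0]
delay = steps_to_threshold(test_acc) - steps_to_threshold(train_acc)  # 4
```

Sweeping this measurement over dataset sizes gives the curves in the step-wise analyses: below the critical data size the test curve never crosses the threshold, and above it the delay shrinks as data is added.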

Our results deepen the understanding of language model training, offering a novel perspective on the role of data in the learning mechanism of language models.

Figure 1: Comprehensive analysis of training dynamics and accuracy curves verifies the data efficiency hypothesis on vanilla grokking [15, 21]. (A) Reproduced grokking phenomenon on modular addition using a 1-layer decoder-only Transformer trained on 2,000 samples. Delayed generalization (≈100% test acc) occurs during continued training after memorization is complete (≈100% train acc, overfitting). (B) Step-wise Analysis of Test Accuracy. We observe a clear peak indicating slow generalization at the critical data size, while more training samples markedly speed up generalization. Below the critical data size, no generalization happens. (C) Step-wise Analysis of Training Accuracy. Within 400 steps, the model can memorize all training data; across various dataset sizes, there is a very small difference in memorization steps. (D) 1D PCA visualization of modular addition datasets. Data pruning uniformly samples from the initial distribution [i.e. random deletion of data to reduce n]. (E) & (F) Test/training accuracy across the whole training process. The detailed training process is presented in Figure 10.

…Discussion: why is the phase transition in language datasets smoother? We speculate there are two main reasons: the initial data size and task complexity.

Larger data sizes can lead to de-grokking [7, 21]. A model trained on a larger initial dataset can learn a greater number of correlations; consequently, the relationship between the phase transition and dataset size becomes increasingly smooth. The “Effective Theory of Representation Learning” suggests that tasks such as modular addition require more refined (linear) representations, whereas representations for language tasks tend to be substantially more complex or “messier” (i.e. nonlinear).

Figure 3: (A) We employ a 1-layer, encoder-only Transformer to trigger the grokking phenomenon on 10% of the Yelp data [26]. The delayed generalization occurs after overfitting. (B) Step-wise Analysis of Test Accuracy on Yelp. The generalization steps first increase and subsequently decrease as the data fraction grows, consistent with results on the modular addition and IMDB datasets. (C) Step-wise Analysis of Training Accuracy on Yelp. As in the modular addition and IMDB experiments, we reach the same conclusion: memorization steps increase as the dataset size expands. The detailed training process is presented in Figure 11.
Figure 11: Training on the Yelp dataset with a 1-layer, encoder-only Transformer under the grokking configuration.

…4.3 Grokking on Yelp: As shown in Figure 3A, we successfully induced the grokking phenomenon on Yelp. This was achieved using a smaller, 10% subset of the Yelp dataset.

As illustrated in Figure 11, we discovered that using larger datasets could lead to de-grokking, i.e. memorization and generalization co-occur.

To promote grokking on the Yelp dataset, we strategically pruned the data under the grokking configuration. Across various fractions of the Yelp training samples, the results align with the proposed data efficiency hypothesis. This is evident in Figures 3B and 3C, where we observe two fundamental phenomena: (1) the presence of a critical data size, although the transition from slow to fast generalization is smoother: as illustrated in Figure 11, reducing the data size from 100% to 10% results in only about a 5% decrease in performance; (2) a trend where increasing the training sample size decreases the number of generalization steps. Compared with the IMDB and modular addition datasets, the Yelp dataset contains more samples, which leads to faster model fitting.

Additional data did not improve final performance, but it did accelerate convergence. In turn, faster fitting weakens the manifestations associated with the data efficiency hypothesis, so the phenomena we observe on Yelp are somewhat less pronounced.

…Discussion: Why is grokking not commonly observed in large language models with big datasets?

  1. The model converges faster with larger datasets.

    If we had a ‘magnifying glass’ [e.g. PassUntil?] to observe the learning process carefully, perhaps we could witness the grokking phenomenon in LLMs. Specifically, our experiments suggest that grokking can be more readily induced through strategic data pruning and the grokking configuration. This approach essentially represents a ‘slow’ learning version (i.e. reduced dataset size, increased initialization, decreased weight decay) of modern learning systems.

    From this perspective, we conjecture that grokking is a fundamental phenomenon hidden under complex conditions, which can only be seen under the dual effects of dataset pruning and grokking configuration.

  2. Large language models incorporate a variety of regularization methods, while our simplified grokking model uses only weight decay.

    As is well known, the various regularization techniques in modern large models help accelerate convergence and prevent overfitting. In our simplified setting, the model’s convergence is slowed down, allowing us to observe clear phase transitions.

Figure 4: Model-wise grokking experiments on IMDB demonstrate that the critical data size increases as the model size increases. (A) Test accuracy variations by hidden layer size and data fraction of the IMDB dataset. The data fraction required for higher accuracy increases as the model size increases. Training acc visualization is presented in Figure 14. (B) Average accuracy across all data fractions 10% → 100%. The white arrows indicate that the average accuracy decreases as the model size increases, suggesting that larger models require more data to maintain performance. The light blue area represents a 95% confidence interval. (C) Training curves for models with different layer counts under various data fractions. As the number of layers increases, larger models require larger data sizes for effective generalization.
Figure 5: Visualization of how the model transitions from memorization to generalization throughout training. We visualize the classification layer’s weights during learning using a 1-layer, encoder-only Transformer on the IMDB dataset. Notably, the parameter distribution evolves from a randomly initialized state to a fixed range of values, which we categorize into stages A to F. The transition from memorization to generalization is driven by weight decay and the loss, leading to a decrease in the 𝓁2 norm. More explanation of the 𝓁2 norm's evolution is in Figure 7.