We study whether transformers can learn to implicitly reason over parametric knowledge, a skill that even the most capable language models struggle with. Focusing on two representative reasoning types, composition and comparison, we consistently find that transformers can learn implicit reasoning, but only through grokking, i.e., extended training far beyond overfitting.
The levels of generalization also vary across reasoning types: when faced with out-of-distribution examples, transformers fail to systematically generalize for composition but succeed for comparison. We delve into the model's internals throughout training, conducting analytical experiments that reveal: (1) the mechanism behind grokking, such as the formation of the generalizing circuit and its relation to the relative efficiency of generalizing and memorizing circuits, and (2) the connection between systematicity and the configuration of the generalizing circuit.
Our findings guide data and training setup to better induce implicit reasoning and suggest potential improvements to the transformer architecture, such as encouraging cross-layer knowledge sharing.
Furthermore, we demonstrate that for a challenging reasoning task with a large search space, GPT-4-Turbo and Gemini-1.5-Pro based on non-parametric memory fail badly regardless of prompting styles or retrieval augmentation, while a fully grokked transformer can achieve near-perfect accuracy, showcasing the power of parametric memory for complex reasoning.
[Possible implication: scaling up LLMs may never trigger grokking without data pruning if the scaled-up datasets simply maintain the ratio of memorable to learnable datapoints, leaving them in the inferior-generalizing regime and requiring vastly larger sample sizes? See also Huang et al. 2024 on the harm a memorization task can do.]
Figure 1: We find that transformers can learn to reason implicitly, but this skill is only robustly acquired through grokking, i.e. an extended period of training far beyond overfitting. Moreover, the transformer fails to systematically generalize for composition, yet succeeds for comparison.
We conduct a mechanistic study of the model internals throughout grokking, which reveals distinct generalizing circuits across the two tasks (Figure 4, Figure 5) that explain the variations in systematicity.
…On two representative reasoning types, composition and comparison, we consistently observe the ubiquitous role of grokking in the transformer's acquisition of implicit reasoning. Further experiments reveal that the speed of grokking correlates with the ratio between inferred and atomic facts and depends little on the absolute size of the training data. This suggests a simple correction to prior explanations of grokking: the distribution of the training data, rather than its size, may be the actual critical factor behind grokking. Moreover, the level of systematicity varies across reasoning types: in the OOD scenario, the model fails to systematically generalize for composition, but succeeds for comparison.
…Inferred/atomic ratio ϕ correlates with generalization speed.
Figure 2(a) shows the ID accuracy across different ϕ. We omit the other splits since, for all settings, training performance saturates quickly and OOD accuracy remains at zero as before. It can be seen that the ratio ϕ strongly correlates with the speed of generalization. A very large ratio can push generalization to improve at a similar pace as the model fits the training data, reducing the need for extended training.
Figure 2: The speed of grokking on test_inferredID (a) correlates with the ratio between inferred and atomic facts, and (b) is not influenced by the size of training data.
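To make the ratio ϕ concrete, here is a minimal sketch of how a composition-task split with a target inferred/atomic ratio might be constructed. The entity/relation counts, function names, and sampling scheme are my assumptions for illustration, not the paper's actual data pipeline:

```python
import random

def build_split(num_entities=20, num_relations=10, phi=9.0, seed=0):
    """Sketch of a composition-task dataset with inferred/atomic ratio phi.

    Atomic facts are (h, r) -> t mappings; inferred facts are two-hop
    compositions (h, r1, r2) -> t2 derived by chaining two atomic facts.
    """
    rng = random.Random(seed)
    entities = list(range(num_entities))
    relations = list(range(num_relations))

    # One atomic fact per (entity, relation) pair: a random tail entity.
    atomic = {(h, r): rng.choice(entities) for h in entities for r in relations}

    # All two-hop inferred facts obtained by chaining atomic facts.
    inferred = [
        ((h, r1, r2), atomic[(atomic[(h, r1)], r2)])
        for (h, r1) in atomic for r2 in relations
    ]

    # Subsample inferred facts so that |inferred| / |atomic| ~= phi.
    k = min(len(inferred), int(phi * len(atomic)))
    train_inferred = rng.sample(inferred, k)
    return atomic, train_inferred

atomic, train_inferred = build_split()
print(len(train_inferred) / len(atomic))  # 9.0
```

Note that ϕ is capped by the number of relations here (each atomic fact can head at most `num_relations` two-hop chains), which is why sweeping ϕ while holding |ℰ| constant changes the training-set size.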
Training data distribution, not training data size, qualitatively influences generalization behavior. When ϕ increases and |ℰ| is held constant, the size of the training data also grows. Prior studies hypothesize that training data size plays a central role in whether grokking happens. In particular, previous work connects grokking with the notion of critical data size (CDS) [33, 61, 78, 21], where it is hypothesized that CDS marks the shift from memorization to generalization (via grokking), and that the speed of generalization improves as the training data scales further.
However, results from our controlled experiments contradict this hypothesis. Figure 2(b) shows the results of varying |ℰ| with a fixed ϕ = 9.0, where we change the horizontal axis from optimization step to epoch for better visualization. With the ratio ϕ fixed, the training data size does not qualitatively affect the model's generalization. Specifically, scaling the data affects neither the relative speed of ID generalization and training improvement (as seen in the rather constant "gap" between the train_inferredID and test_inferredID curves) nor the level of systematicity (OOD performance stays at zero). We also ran the experiments across different ϕ and found the results to be consistent.
This suggests that the critical data distribution, not its size, may be the actual deciding factor behind grokking and generalization. In addition, we find that scaling up the model size also does not qualitatively change the generalization behaviors observed here (Appendix B); the main pattern is that larger models converge in fewer optimization steps, consistent with prior findings [60, 28].
…So what happens during grokking, why does it happen, and why do transformers exhibit different levels of systematicity in generalization?
To answer these questions, we analyze the model internals throughout grokking.
We find distinct generalizing circuits for the two reasoning types, and strong patterns in the evolution of the circuits that explain the underlying mechanisms behind grokking and variations in systematicity.
Notably, for the comparison task, the transformer gradually forms a parallel circuit that allows the model to store/access the ID and OOD atomic facts in the same region, enabling systematicity.
Figure 4: The (evolution of) generalizing circuit for composition.
(a) The generalizing circuit.
(b) The change in causal strengths during grokking, where the target is the prediction state.
(c) Mean reciprocal rank (via logit lens) of the bridge entity b at S[5, r1] and the second relation r2 at S[5, r2].
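The logit-lens MRR measurement in Figure 4(c) can be sketched as follows: decode an intermediate hidden state directly through the unembedding matrix and measure the rank of the target token. The tensor shapes and tied-unembedding assumption are mine, not details confirmed by the paper:

```python
import torch

def logit_lens_mrr(hidden, W_U, targets):
    """Mean reciprocal rank of `targets` when intermediate hidden states are
    decoded directly through the unembedding matrix (the "logit lens").

    hidden:  (batch, d_model)  hidden states at some layer/position, e.g. S[5, r1]
    W_U:     (d_model, vocab)  unembedding matrix (assumed directly applicable)
    targets: (batch,)          token ids to rank (e.g. the bridge entity b)
    """
    logits = hidden @ W_U                               # (batch, vocab)
    target_logits = logits.gather(1, targets[:, None])  # (batch, 1)
    # rank = 1 + number of tokens scoring strictly higher than the target
    ranks = 1 + (logits > target_logits).sum(dim=1)     # (batch,)
    return (1.0 / ranks.float()).mean().item()

# toy usage with random weights
h = torch.randn(4, 8)
W_U = torch.randn(8, 100)
mrr = logit_lens_mrr(h, W_U, torch.randint(0, 100, (4,)))
```

An MRR near 1 at S[5, r1] would mean the bridge entity is (almost) the top decoded token at that state, which is the signature of the first hop being resolved there.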
…This also indicates that, before grokking, the model is most likely memorizing the examples in train_inferredID by directly associating (h, r1, r2) with t, without going through the first hop.
…Why does grokking happen? These observations suggest a natural explanation through the lens of circuit efficiency. Specifically, as illustrated above, there exist both a memorizing circuit Cmem and a generalizing circuit Cgen that can fit the training data. While Cmem is learned first (which causes training performance to saturate quickly), Cgen is relatively more efficient, in the sense that it can fit the data with lower complexity. To see this, we can compare the number of facts Cmem and Cgen need to store (denoted Nmem and Ngen) as a proxy for their complexity. Cmem stores both the atomic and the inferred facts in the weights. Cgen (Figure 4(a)) stores the atomic facts in the lower layers, plus another copy of the atomic facts that appear as the second hop of the inferred facts in the upper layers.
As the inferred/atomic ratio ϕ increases, Nmem grows rapidly while Ngen grows slowly and is always bounded by twice the total number of atomic facts; hence, the relative efficiency of Cgen increases. In the long run, the model is incentivized to transition from Cmem to Cgen due to the implicit bias of the optimization [53] and explicit regularization such as weight decay, which prefer more efficient circuits, and the transition happens faster as ϕ increases.
This also explains why the training data size does not affect the speed of grokking: increasing the size alone does not change the relative efficiency of Cmem and Cgen. The explanation also implies that a larger regularization factor should accelerate grokking (and vice versa), which we confirm by varying the degree of weight decay (Appendix E.1).
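The complexity proxy above reduces to simple counting. With N atomic facts and ratio ϕ, the memorizing circuit stores Nmem = N(1 + ϕ) facts, while the generalizing circuit stores at most Ngen = 2N, so the relative cost of memorizing grows linearly in ϕ and is unchanged by scaling N alone:

```python
def circuit_complexity(n_atomic, phi):
    """Facts each circuit must store, per the proxy in the text.

    Memorizing circuit: all atomic facts plus all phi * n_atomic inferred facts.
    Generalizing circuit: atomic facts in the lower layers, plus at most one
    more copy (the second-hop atomic facts) in the upper layers.
    """
    n_mem = n_atomic * (1 + phi)
    n_gen_upper_bound = 2 * n_atomic
    return n_mem, n_gen_upper_bound

for phi in (3.6, 9.0, 18.0):
    n_mem, n_gen = circuit_complexity(10_000, phi)
    # n_mem / n_gen = (1 + phi) / 2: grows with phi, independent of n_atomic
    print(phi, n_mem / n_gen)
```

This makes both observations immediate: raising ϕ raises Nmem/Ngen (faster grokking), while raising n_atomic with ϕ fixed leaves the ratio untouched (no effect on grokking speed).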
Figure 13: Effect of weight decay.
A larger weight decay can improve the speed of grokking, and vice versa.
Figure 5: The (evolution of) generalizing circuit for comparison.
(a) The generalizing circuit.
(b) The change in causal strengths during grokking, where the target is the prediction state.
(c) Mean reciprocal rank (via logit lens) of the two attribute values (v1, v2) at S[5, e1] and S[5, e2].
…Optimization is done with AdamW with learning rate 10⁻⁴, batch size 512, weight decay 0.1, and 2,000 warm-up steps with a linear schedule. More details are included in Appendix A.
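In PyTorch, the reported hyperparameters might be wired up roughly as below. The `total_steps` value and the linear-decay-after-warmup reading of "linear schedule" are my assumptions; batch size 512 would live in the dataloader, not here:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def make_optimizer(model, lr=1e-4, weight_decay=0.1,
                   warmup_steps=2000, total_steps=200_000):
    """AdamW + linear warm-up matching the reported hyperparameters.

    `total_steps` is a placeholder; the training runs in the paper continue
    far beyond the point where training accuracy saturates.
    """
    opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warm-up from 0
        # linear decay after warm-up (one common reading of "linear schedule")
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    sched = LambdaLR(opt, lr_lambda)
    return opt, sched

opt, sched = make_optimizer(torch.nn.Linear(8, 8))
```

Weight decay 0.1 is notable given the circuit-efficiency discussion above: the same regularizer that prefers the more efficient circuit is part of the default setup.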