[Twitter; code] Large language models like GPT-4 exhibit emergent capabilities across general-purpose tasks, such as basic arithmetic, when trained on extensive text data, even though these tasks are not explicitly encoded by the unsupervised, next-token prediction objective.
This study investigates how small transformers [NanoGPT & GPT-2], trained from random initialization, can efficiently learn arithmetic operations such as addition, multiplication, and elementary functions like square root, using the next-token prediction objective. We first demonstrate that conventional training data is not the most effective for arithmetic learning, and simple formatting changes can improve accuracy.
This leads to sharp phase transitions as a function of training data scale, which, in some cases, can be explained through connections to low-rank matrix completion. Building on prior work, we then train on chain-of-thought style data that includes intermediate step results. Even in the complete absence of pretraining, this approach simultaneously improves accuracy, sample complexity, and convergence speed.
We also study the interplay between arithmetic and text data during training and examine the effects of few-shot prompting, pretraining, and model scale. Additionally, we discuss length generalization challenges.
Our work highlights the importance of high-quality, instructive data that considers the particular characteristics of the next-word prediction objective for rapidly eliciting arithmetic capabilities.
…Data format and sampling matters: We first observe that teaching a model addition (or any other operation) using standard addition samples, i.e., A3A2A1 + B3B2B1 = C3C2C1, is suboptimal, as it requires the model to emit the most significant digit C3 of the result first, which depends globally on all the digits of the two summands. By training on samples with reversed results, i.e., A3A2A1 + B3B2B1 = C1C2C3, we enable the model to learn a simpler function, improving sample complexity. Additionally, balanced sampling of different “variations” of addition, based on the number of carries and digits involved, further enhances learning. Even in this simple setting, we observe relatively sharp phase transitions from 0% to 100% accuracy as a function of the size of the training data. Although this may seem surprising, we observe that learning an addition map on n digits from random samples is equivalent to completing a low-rank matrix. This connection allows us to offer a reasonable explanation for such phase transitions.
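The plain-vs-reverse contrast can be made concrete with a small generator (a sketch, not the paper's exact data pipeline; the `plain_sample`/`reverse_sample` names are mine). In the reverse format, the least significant digit of the sum comes first, so each output digit depends only on the two operand digits at that position plus one carry bit:

```python
import random

def plain_sample(a: int, b: int) -> str:
    # Standard format: most significant digit of the result first.
    return f"{a:03d}+{b:03d}={a + b}"

def reverse_sample(a: int, b: int) -> str:
    # Reversed result: least significant digit first, matching the
    # left-to-right order in which the model generates tokens.
    return f"{a:03d}+{b:03d}={str(a + b)[::-1]}"

random.seed(0)
a, b = random.randrange(1000), random.randrange(1000)
print(plain_sample(a, b))
print(reverse_sample(a, b))
```

The point of the reversal: to emit the first token of `C3C2C1` the model must already have resolved every carry, whereas the first token of `C1C2C3` needs only the two lowest digits.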
Figure 1: The 4 data formatting methods investigated in this work: (1) Plain: standard addition formatting (§4), (2) Reverse: reversing the output (§4), (3) Simplified Scratchpad: recording the digit-wise sum and carry-ons (§6), and (4) Detailed Scratchpad: providing detailed intermediate steps of addition (§6). We train small transformer models from scratch using data transformed with these various formatting methods for addition. The results (shown on the right) highlight the crucial role of data formatting in performance and sample efficiency. Plain never reaches 100% accuracy and the sample complexity for the remaining methods to learn addition perfectly steadily reduces as we increase the level of detail in the data format.
Figure 2: The 4 input formatting methods used for the addition task. We progressively increase the amount of detail with each format.
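The scratchpad formats can be illustrated with a generator in the same spirit (the exact string layout here is an assumption; the paper's Figure 2 gives the canonical forms). The simplified scratchpad records, least significant digit first, the digit-wise sum A and carry C at each position:

```python
def simplified_scratchpad(a: int, b: int, width: int = 3) -> str:
    # Emit one "A->digit, C->carry" line per digit position,
    # least significant first, then the final answer.
    lines = [f"Input: {a:0{width}d}+{b:0{width}d}"]
    carry = 0
    for da, db in zip(reversed(f"{a:0{width}d}"), reversed(f"{b:0{width}d}")):
        s = int(da) + int(db) + carry
        carry = s // 10
        lines.append(f"A->{s % 10}, C->{carry}")
    lines.append(f"Answer: {a + b}")
    return "\n".join(lines)

print(simplified_scratchpad(128, 367))
```

The detailed scratchpad extends this by spelling out each intermediate computation in full sentences, trading many more tokens per sample for fewer samples needed.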
…Chain-of-thought data during training: …We found that CoT-type training data markedly improved learning in terms of both sample complexity and accuracy, in agreement with the CoT fine-tuning literature (Nye et al. 2021; Chung et al. 2022), though our observation holds even in the absence of language pretraining. We conjecture that this is because breaking down the required compositional function into individual components allows the model to learn a higher-dimensional but easier-to-learn function map. In Figure 1, we provide examples of the 4 data formatting methods explored in our work.
…Instructional data/chain-of-thought: The idea of using detailed reasoning training data predates Transformers (Vaswani et al. 2017). Ling et al. 2017; Cobbe et al. 2021; Nye et al. 2021 use natural language to generate reasoning steps, while Roy & Roth 2016; Reed & De Freitas 2015; Chen et al. 2017; Cai et al. 2017 show that symbolic reasoning may suffice. Nogueira et al. 2021 note that a large number of samples with small digits is important for arithmetic tasks (Yuan et al. 2023). Razeghi et al. 2022 observe a correlation between the frequency of numbers in the dataset and performance on tasks involving them, whereas we find that transformers can learn to add numbers that were not seen during training. Chain-of-thought (Wei et al. 2022c) refers to the model’s improved performance when prompted to produce rationales. Zhou et al. 2022b show that this can be achieved by providing sufficiently informative exemplars as a few-shot prompt (Brown et al. 2020). Zhou et al. 2022a showed that least-to-most prompting—first decomposing a complex problem into easier subproblems, then sequentially solving them—can help GPT-3 solve problems that decompose into simpler sub-problems. We extend this notion to simple addition and show that asking the model to output the least significant digit first has a similar effect…Zaremba & Sutskever 2014 show that RNNs can learn how to execute simple programs with for-loops, provided they are trained with curriculum learning…encoder-decoder models have also been extensively studied in the context of learning arithmetic (Kim et al. 2021; Wang et al. 2021). Qian et al. 2022; Lightman et al. 2023; Uesato et al. 2022 explore techniques to improve the arithmetic abilities of pretrained LLMs. Wallace et al. 2019, on the other hand, focus on the impact of the learned embeddings…Charton 2021 & Charton 2022 show that Transformers can learn linear algebra operations with carefully chosen encodings.
Hanna et al. 2023 use mechanistic interpretability techniques to explain the limited numerical reasoning capabilities of GPT-2.
…We find that learning all arithmetic operations discussed earlier (from addition to square root) can improve the individual performance of each task, and that going from zero-shot to 1-shot prompting (showing one arithmetic example) yields a large accuracy improvement, but there is no large improvement in accuracy by showing more examples. [Consistent with the Bayesian/meta-learning interpretation of task location: there are not many possibilities so more examples don’t offer any relevant evidence—either it knows the necessary task, or it doesn’t.]
Figure 3: Performance of 3-digit addition under the three data-sampling methods: (1) Random: uniform sampling of operands; (2) Balanced digits: assigning higher sampling weights to operations involving 1- and 2-digit numbers; (3) Balanced carry: balancing the dataset to contain an equal number of examples for each number of carry operations. All experiments zero-pad the operands and the output to 3 and 4 digits, respectively.
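The balanced-carry scheme can be sketched as rejection sampling over the carry count (the exact balancing procedure used in the paper is not specified here, so treat this as an assumed implementation; both function names are mine). A 3-digit addition has 0 to 3 carry operations:

```python
import random

def carry_count(a: int, b: int) -> int:
    # Count base-10 carry operations performed when computing a + b.
    carries, carry = 0, 0
    while a or b or carry:
        s = a % 10 + b % 10 + carry
        carry = 1 if s >= 10 else 0
        carries += carry
        a, b = a // 10, b // 10
    return carries

def balanced_carry_samples(n_per_class: int, seed: int = 0):
    # Rejection-sample 3-digit additions so every carry count 0..3
    # appears equally often in the training set.
    rng = random.Random(seed)
    buckets = {k: [] for k in range(4)}
    while any(len(v) < n_per_class for v in buckets.values()):
        a, b = rng.randrange(1000), rng.randrange(1000)
        k = carry_count(a, b)
        if len(buckets[k]) < n_per_class:
            buckets[k].append((a, b))
    return buckets
```

Under uniform sampling, zero-carry and one-carry examples dominate, so balancing ensures the rarer multi-carry cases are seen often enough to be learned.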
Figure 4: Comparison of NanoGPT model performance on the addition task, trained on plain and reverse formatted data. The conventional plain format exhibits suboptimal performance, even with a larger number of addition examples, whereas a distinct phase transition is observed for the reverse format around 2,500 train samples where it learns addition perfectly.
Figure 5a: We run Algorithm 1, a simple iterative algorithm for rank-2 matrix completion, on the addition matrix (n = 20, 50, 100, 500) and report the success probability over multiple random trials while varying the number of revealed entries. As anticipated, a sharp phase transition occurs when ~𝒪(n) entries are revealed.
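A sketch in the spirit of Algorithm 1 (the paper's exact procedure may differ): the n×n addition matrix M[i][j] = i + j has rank 2, so every 2×2 submatrix satisfies M[i1][j1] + M[i2][j2] = M[i1][j2] + M[i2][j1]. An unknown entry can therefore be filled whenever the other three corners of some 2×2 pattern are known, and iterating this propagates the revealed entries:

```python
import random

def complete_addition_matrix(n: int, revealed: dict) -> dict:
    # revealed: {(i, j): value}. Repeatedly fill any unknown entry (i, j)
    # for which some (i, j2), (i2, j), (i2, j2) are all known, using the
    # rank-2 identity M[i][j] = M[i][j2] + M[i2][j] - M[i2][j2].
    m = dict(revealed)
    changed = True
    while changed:
        changed = False
        rows = sorted({i for i, _ in m})
        cols = sorted({j for _, j in m})
        for i in range(n):
            for j in range(n):
                if (i, j) in m:
                    continue
                for i2 in rows:
                    if i2 == i or (i2, j) not in m:
                        continue
                    for j2 in cols:
                        if j2 != j and (i, j2) in m and (i2, j2) in m:
                            m[i, j] = m[i, j2] + m[i2, j] - m[i2, j2]
                            changed = True
                            break
                    if (i, j) in m:
                        break
    return m

# Reveal ~3n random entries of the 20x20 addition matrix and propagate.
n = 20
random.seed(1)
entries = random.sample([(i, j) for i in range(n) for j in range(n)], 3 * n)
revealed = {(i, j): i + j for i, j in entries}
result = complete_addition_matrix(n, revealed)
print(f"recovered {len(result)} of {n * n} entries")
```

Whether the whole matrix is recovered depends on which entries were revealed, which is exactly why success probability jumps sharply once ~𝒪(n) entries connect all rows and columns.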
Figure 5b: We compare the performance of a NanoGPT model trained on 2-digit addition (n = 100) to that of the corresponding LRMC problem using the same sample set. Notably, both NanoGPT and Algorithm 1 exhibit a remarkably similar phase transition at around 1,500 samples, past which both perform addition almost flawlessly.
…The phase transition of LRMC offers insights into NanoGPT’s learning process. Nevertheless, further experiments clearly demonstrate that NanoGPT’s mechanism for learning addition is fundamentally different from LRMC. It can successfully learn addition even when numbers or digits are intentionally excluded from the training data, thereby exhibiting generalization capabilities that far exceed that of typical LRMC algorithms.
Figure 6: Comparison of sample efficiency: evaluating performance on training datasets with different numbers of addition samples. While all modified methods (Reverse, Simplified Scratchpad, and Detailed Scratchpad) achieve 100% test accuracy, they exhibit varying requirements in terms of the number of addition examples in the training dataset to reach optimal performance.
…the Detailed Scratchpad format, which provides even more detailed information, achieves perfect addition with just 1,000 samples. This indicates that incorporating more information enables the model to learn addition more efficiently, requiring fewer examples. We conjecture that this is because breaking down the required compositional function to be learned into individual components allows the model to learn a higher-dimensional but easier-to-learn function map. We note that while CoT-style training enhances sample efficiency, it may not necessarily be the most “token-efficient” approach. We delve into this aspect in more detail in §11. In summary, incorporating scratchpad data and decomposing the addition task into steps offer a promising strategy to improve the performance and efficiency of small models in learning addition from scratch…[but] the detailed scratchpad method uses considerably more tokens compared to other techniques.
…Anil et al. 2022 suggest that models can only perform out-of-distribution tasks by combining fine-tuning, prompting, and scratchpad techniques. Nonetheless, there have been cases where length generalization was observed. Nye et al. 2021 demonstrated length generalization, but only for models with more than 10⁸ parameters.
…Noisy intermediate steps in the scratchpad data: We further investigate the importance of providing accurate intermediate steps in the scratchpad during the training process. While this was inspired by the findings of Min et al. 2022, it is inherently different. Min et al. 2022 show that using random labels in ICL [in-context learning] demonstrations caused minimal degradation when compared to the gold labels. However, those models were trained on gold labels and then evaluated on multiple downstream tasks. In our setting, the model is trained and evaluated on a single arithmetic task. Further, the final result (or label) is left untouched as the correct answer to the arithmetic operation. We only replace the intermediate steps. The goal of this study is to verify whether the model actually learns to reason using the given intermediate steps or merely uses the scratchpad to improve its expressivity. We compare the performance of training with our simplified scratchpad formatting, which includes accurate A (digit sum) and C (carry) information, with formatting that includes random A, random C, or random A and C for each intermediate step, as depicted in Figure 1.
The results in Figure 9 demonstrate that the inclusion of noisy labels can impede sample efficiency. However, with enough samples, the model ultimately achieves full accuracy. This suggests that while the model is capable of leveraging the information contained in the intermediate steps, it can also gradually learn to perform addition while disregarding noisy intermediate steps.
Figure 9: Comparison of training with simplified scratchpad formatting using correct A and C information with formatting using random A/C and their effect on sample efficiency and accuracy. Results show that noisy labels degrade sample efficiency, but with sufficient training data, the model eventually reaches full accuracy.
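The noisy-scratchpad ablation can be sketched as follows (the string layout and the `noisy_scratchpad` name are assumptions for illustration): intermediate A and/or C values are replaced with random digits, while the final answer is always left correct, matching the setup described above:

```python
import random

def noisy_scratchpad(a: int, b: int, noise: str, width: int = 3, seed: int = 0) -> str:
    # noise: "" (clean), "A", "C", or "AC" -- which intermediate values
    # to replace with random digits. The final answer stays correct.
    rng = random.Random(seed)
    lines = [f"Input: {a:0{width}d}+{b:0{width}d}"]
    carry = 0
    for da, db in zip(reversed(f"{a:0{width}d}"), reversed(f"{b:0{width}d}")):
        s = int(da) + int(db) + carry
        carry = s // 10
        A = rng.randrange(10) if "A" in noise else s % 10
        C = rng.randrange(2) if "C" in noise else carry
        lines.append(f"A->{A}, C->{C}")
    lines.append(f"Answer: {a + b}")  # label untouched, per the ablation design
    return "\n".join(lines)

print(noisy_scratchpad(128, 367, noise="AC"))
```

If the model were strictly reasoning through the steps, random A/C would be catastrophic; that it eventually reaches full accuracy anyway supports the expressivity interpretation.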
Figure 15: Performance of 3-digit subtraction, 2-digit multiplication, and 4-digit-precision sine and square root with varying data formats. As with addition, the reverse format consistently improves sample complexity and performance across all operations. For sine and square root, scratchpad formatting provides limited improvement. This discrepancy can be attributed to the complexity of the intermediate steps involved in the detailed scratchpad.
…The results presented in Table 3 reveal intriguing findings. We observe that the reverse format consistently outputs a result that deviates by no more than 1 from the true answer, regardless of whether the preceding outputs O3O2 are subjected to random or precise perturbation. This consistency can be explained by Lemma 2, indicating that the reverse format only requires learning a straightforward function of digit-wise addition for each corresponding position, along with the carry-on (0 or 1). Therefore, even with noise in the preceding tokens, the model accurately performs digit-wise addition, albeit with occasional carry-on prediction errors. With an exact accuracy of 81.26% even in the presence of random perturbation, the reverse format demonstrates the model’s ability to rely less on the preceding output tokens, indicating a robust learned output mapping.
…Going from character-level tokenization to BPE: …Figure 20 shows that GPT-2 achieves high performance on addition with both character-level tokenization and Tiktoken with spaces between digits. This aligns with the results of Wallace et al. 2019, suggesting that character-level tokenization exhibits stronger numeracy capabilities than word- or subword-level methods. Furthermore, comparing models trained from scratch against models fine-tuned from the pretrained checkpoint, we observe that fine-tuning a pretrained model results in better performance than training from scratch.
Figure 20: Performance of various configurations of the GPT-2 model on the addition task. We compare the effects of tokenization methods, specifically character-level tokenization versus Tiktoken (OpenAI’s BPE tokenizer), training initialization (training from scratch versus training from a pretrained GPT-2 model), and the inclusion or exclusion of spaces between numbers. The results highlight the importance of using pretrained models and incorporating spaces for consistent tokenization of numbers when training a model for arithmetic tasks.
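The tokenization contrast is easy to see in miniature (a sketch; the vocabulary and `char_tokenize` name are mine, not the paper's). A character-level tokenizer maps every digit to its own token, so "123" and "124" differ in exactly one position, whereas a BPE tokenizer may merge digit runs into arbitrary multi-digit tokens unless spaces force a per-digit split:

```python
def char_tokenize(text: str, vocab: dict) -> list:
    # One token per character: digits, operators, and separators
    # each get a stable, dedicated id.
    return [vocab[c] for c in text]

# Minimal arithmetic vocabulary: digits 0-9, '+', '=', space.
vocab = {c: i for i, c in enumerate("0123456789+= ")}
print(char_tokenize("12+34=46", vocab))  # → [1, 2, 10, 3, 4, 11, 4, 6]
```

Inserting spaces between digits gives a BPE tokenizer the same per-digit consistency "for free", which is why the spaced Tiktoken variant closes most of the gap in Figure 20.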