“To Grok or Not to Grok: Disentangling Generalization and Memorization on Corrupted Algorithmic Datasets”, Darshil Doshi, Aritra Das, Tianyu He, Andrey Gromov (2023-10-19)⁠:

Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large.

In general, it is very difficult to know whether the network has memorized a particular set of examples or understood the underlying rule (or both). Motivated by this challenge, we study an interpretable model where generalizing representations are understood analytically and are easily distinguishable from the memorizing ones. Namely, we consider multi-layer perceptron (MLP) and Transformer architectures trained on modular arithmetic tasks, where a fraction ξ of the labels (i.e., ξ × 100%) are corrupted—that is, some results of the modular operations in the training set are incorrect.
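The corrupted modular-arithmetic setup can be sketched concretely. The snippet below (a minimal illustration, not the paper's code; the function name and split convention are my own assumptions) builds the full (a + b) mod p table, splits it into train/test sets, and replaces a fraction ξ of training labels with random incorrect residues:

```python
import numpy as np

def make_corrupted_modular_data(p=97, xi=0.1, train_frac=0.5, seed=0):
    """Build the full (a + b) mod p table, split it, and corrupt a
    fraction xi of the training labels with random incorrect residues."""
    rng = np.random.default_rng(seed)
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    X = np.stack([a.ravel(), b.ravel()], axis=1)
    y = X.sum(axis=1) % p  # true labels

    perm = rng.permutation(len(X))
    n_train = int(train_frac * len(X))
    train_idx, test_idx = perm[:n_train], perm[n_train:]

    y_train = y[train_idx].copy()
    n_corrupt = int(xi * n_train)
    corrupt_idx = rng.choice(n_train, size=n_corrupt, replace=False)
    # adding a nonzero offset mod p guarantees the corrupted label is wrong
    y_train[corrupt_idx] = (y_train[corrupt_idx]
                            + rng.integers(1, p, size=n_corrupt)) % p
    return X[train_idx], y_train, X[test_idx], y[test_idx], corrupt_idx
```

Note that test labels are left uncorrupted, so test accuracy always measures agreement with the true modular rule.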

We show that (1) it is possible for the network to memorize the corrupted labels and achieve 100% generalization at the same time; (2) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (3) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization, and achieve 100% accuracy on the uncorrupted dataset; and (4) the effect of these regularization methods is (“mechanistically”) interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones.
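The paper identifies memorizing neurons through their learned representations; as an illustrative stand-in, the sketch below uses a simple ablation heuristic (my own hypothetical construction, not the authors' method) on a one-hidden-layer ReLU network: a unit is flagged as "memorizing" if zeroing it out does not hurt accuracy on uncorrupted data.

```python
import numpy as np

def accuracy(W1, W2, X, y, keep=None):
    """Accuracy of a one-hidden-layer ReLU network; `keep` masks hidden units."""
    h = np.maximum(X @ W1, 0.0)
    if keep is not None:
        h = h * keep
    return float((np.argmax(h @ W2, axis=1) == y).mean())

def prune_memorizing_units(W1, W2, X_clean, y_clean):
    """Return a keep-mask over hidden units: a unit is pruned if ablating
    it does not reduce accuracy on the uncorrupted examples."""
    n = W1.shape[1]
    base = accuracy(W1, W2, X_clean, y_clean)
    keep = np.ones(n)
    for j in range(n):
        trial = np.ones(n)
        trial[j] = 0.0
        # generalizing units hurt clean accuracy when removed; memorizing ones do not
        if accuracy(W1, W2, X_clean, y_clean, trial) >= base:
            keep[j] = 0.0
    return keep
```

Applying the resulting mask at inference time corresponds to pruning: accuracy on corrupted examples falls (the net can no longer reproduce their wrong labels) while accuracy on clean examples is preserved or improved.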

Finally, we show that in the presence of regularization, the training dynamics involve two consecutive stages: first, the network undergoes grokking dynamics, reaching high train and test accuracy; second, it unlearns the memorizing representations, and the train accuracy suddenly drops from 100% to 100 × (1 − ξ)%.

…A grokking MLP on the modular addition dataset is remarkably robust to label corruption. Even without explicit regularization, the model can generalize to near-100% accuracy under sizable label corruption. In many cases, the network surprisingly manages to “correct” some corrupted examples, resulting in Inversion (i.e., test accuracy > training accuracy). We emphasize that this is in stark contrast to the common belief that grokking requires explicit regularization. Adding regularization makes grokking more robust to label corruption, with stronger Inversion.

Partial Inversion: In this phase, the network generalizes on the test data but only memorizes a fraction of the corrupted training data. Remarkably, the network predicts the “true” labels on the remaining corrupted examples. In other words, the network corrects a fraction of the corrupted data. Consequently, we see <100% train accuracy but near-100% test accuracy, resulting in a negative generalization gap (Figure 1b)!

We term this phenomenon Partial Inversion, where “inversion” refers to the test accuracy being higher than the train accuracy.
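Partial Inversion can be quantified directly from model predictions. The helper below (an illustrative sketch; the function name and dictionary keys are my own) reports train accuracy against the given noisy labels, test accuracy, the generalization gap (negative under Inversion), and the fraction of corrupted examples the model "corrects" to the true answer:

```python
import numpy as np

def inversion_metrics(pred_train, y_given, y_true, pred_test, y_test):
    """Summarize (Partial) Inversion: train/test accuracy, generalization
    gap, and the fraction of corrupted training examples on which the
    model predicts the *true* (pre-corruption) label."""
    corrupted = y_given != y_true
    train_acc = float((pred_train == y_given).mean())  # against the given (noisy) labels
    test_acc = float((pred_test == y_test).mean())
    corrected = (float((pred_train[corrupted] == y_true[corrupted]).mean())
                 if corrupted.any() else 0.0)
    return {"train_acc": train_acc,
            "test_acc": test_acc,
            "gap": train_acc - test_acc,  # negative gap = Inversion
            "corrected_frac": corrected}
```

A model in the Partial Inversion phase shows `gap < 0` together with `corrected_frac > 0`: it fails to reproduce some corrupted labels precisely because it outputs the true answer on them.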

Remarkably, partial inversion occurs even in the absence of any explicit regularization, but only when there is ample training data (leftmost panels in Figure 3a, Figure 3b).

Figure 3: Modular Addition phase diagrams with various regularization methods. A larger data fraction yields more “corrected” examples and hence greater robustness to corruption. Increasing regularization, in the form of weight decay or dropout, enhances robustness to label corruption and facilitates better generalization.