"Why Are Sensitive Functions Hard for Transformers?", 2024-02-15:
Empirical studies have identified a range of learnability biases and limitations of transformers, such as a persistent difficulty in learning to compute simple formal languages like PARITY, and a bias towards low-degree functions. However, theoretical understanding remains limited, with existing expressiveness theory either overpredicting or underpredicting realistic learning abilities.
We prove that, under the transformer architecture, the loss landscape is constrained by the input-space sensitivity: Transformers whose output is sensitive to many parts of the input string inhabit isolated points in parameter space, leading to a low-sensitivity bias in generalization.
We show theoretically and empirically that this theory unifies a broad array of empirical observations about the learning abilities and biases of transformers, such as their generalization bias towards low sensitivity and low degree, and difficulty in length generalization for PARITY [and success of inner-monologue methods].
This shows that understanding transformers' inductive biases requires studying not just their in-principle expressivity, but also their loss landscape.
…Intermediate Steps Reduce Sensitivity
In the realm of multiplication over finite monoids, foundational to regular languages and the simulation of finite automata (e.g. et al 2023; et al 2023; et al 2023), average sensitivity creates a dichotomy: if f computes n-fold multiplication over a finite monoid, then as_n(f) = 𝒪(1) if the monoid is aperiodic, and as_n(f) = Θ(n) otherwise. The theory thus predicts that transformer shortcuts to automata (et al 2023) will be brittle for many non-aperiodic automata.
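The dichotomy can be checked by brute force at small n. The sketch below (illustrative, not from the paper) computes the average sensitivity as_n(f), i.e. the expected number of single-bit flips that change f(x) for a uniformly random x, for two word problems over two-element monoids: PARITY (the group ℤ/2, non-aperiodic) and OR (the aperiodic monoid ({0,1}, ∨)). PARITY gives as_n(f) = n exactly, while OR gives as_n(f) = 2n/2^n, vanishing with n.

```python
from itertools import product

def avg_sensitivity(f, n):
    """Average over all x in {0,1}^n of the number of bit flips that change f(x)."""
    total = 0
    for x in product((0, 1), repeat=n):
        fx = f(x)
        total += sum(f(x[:i] + (1 - x[i],) + x[i + 1:]) != fx for i in range(n))
    return total / 2 ** n

def parity(x):
    return sum(x) % 2   # word problem of the group Z/2 (non-aperiodic)

def or_fn(x):
    return max(x)       # word problem of the aperiodic monoid ({0,1}, OR)

for n in (4, 8):
    print(n, avg_sensitivity(parity, n), avg_sensitivity(or_fn, n))
# parity: as_n = n exactly (Theta(n)); OR: as_n = 2n / 2^n, i.e. O(1)
```

Every bit flip changes PARITY, so its sensitivity is n at every input; OR is only sensitive near the all-zeros string, which has vanishing probability mass.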
PARITY can be solved well with a scratchpad (et al 2022; et al 2023). Existing theoretical accounts of the benefit of intermediate steps for transformers' expressive capacity (e.g. 2023a; et al 2023) do not explain the benefit of intermediate steps for PARITY-like problems, as the upper bounds on transformer abilities used in these studies do not account for the hardness of learning to compute PARITY in a single step (Fact 1). The concept of average sensitivity provides a simple explanation. Formally, we can consider the problem of simulating a finite automaton with state set 𝒳, either translating to the final state t_n in one go ("standard"), or autoregressively translating to the sequence of states t_1, …, t_n ("scratchpad"). Then (proof in Appendix D): Theorem 7.
Simulating an automaton with scratchpad has sensitivity 𝒪(1) for each autoregressive step.
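A hedged illustration of why Theorem 7 holds, using PARITY as the automaton (transition t_i = x_i ⊕ t_{i-1}): with a scratchpad, the i-th step reads the whole prefix x_1…x_n, t_1, …, t_{i-1}, but its output depends only on x_i and the already-written t_{i-1}. So flipping any single input position changes t_i in at most 2 cases, independent of n. The function names below are illustrative, not the paper's.

```python
def step_output(x, t_prefix, i):
    """Autoregressive step i: reads (x, t_1..t_{i-1}) but only uses x[i] and t_{i-1}."""
    prev = t_prefix[-1] if t_prefix else 0   # t_0 = 0
    return x[i] ^ prev

def step_sensitivity(x, t_prefix, i):
    """Number of positions in the step's full input whose flip changes t_i."""
    base = step_output(x, t_prefix, i)
    count = 0
    for j in range(len(x)):                  # flip input bits x_j
        x2 = list(x); x2[j] ^= 1
        count += step_output(x2, t_prefix, i) != base
    for j in range(len(t_prefix)):           # flip previously written states t_j
        t2 = list(t_prefix); t2[j] ^= 1
        count += step_output(x, t2, i) != base
    return count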
…Scratchpad Eliminates Sharpness: By Theorem 7, the sensitivity of each autoregressive step when computing PARITY with a scratchpad is 𝒪(1). Hence, Theorem 6 provides no nontrivial lower bound for L_σ,n(T). We trained an encoder-decoder transformer to predict the parity of the i-th prefix at the i-th autoregressive step: t_i = PARITY(x_1:i) = x_i ⊕ t_{i-1} (t_0 = 0). The dependence of sharpness on input length for PARITY with a scratchpad is shown in Figure 11. Even for lengths around 300, sharpness is low and increases little with input length. Thus, the decrease in sensitivity due to the scratchpad can explain why prior work (et al 2022) found that PARITY is easy for transformers with a scratchpad… Our theory suggests that transformers generalize well to the extent that real-world data has bounded sensitivity (e.g. et al 2021).
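The training setup described above can be sketched as a data generator (a minimal sketch under our reading of the setup, not the authors' code): for each input string, the "standard" target is the single final parity bit, while the "scratchpad" target is the full sequence of running parities t_1, …, t_n that the decoder predicts one step at a time.

```python
import random

def make_example(n, scratchpad=True, rng=random):
    """One (input, target) pair for training PARITY with or without a scratchpad."""
    x = [rng.randint(0, 1) for _ in range(n)]
    t, acc = [], 0
    for bit in x:               # running parity: t_i = x_i XOR t_{i-1}, t_0 = 0
        acc ^= bit
        t.append(acc)
    return (x, t) if scratchpad else (x, [t[-1]])
```

Each scratchpad target bit is a low-sensitivity function of the decoder's visible context, whereas the single standard target bit is maximally sensitive to the whole input.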
[Why inner-monologue helps: it breaks a problem down into individual steps with much tamer loss landscapes, which are thus more learnable.]