“Why Are Sensitive Functions Hard for Transformers?”, Michael Hahn & Mark Rofin (2024-02-15):

Empirical studies have identified a range of learnability biases and limitations of transformers, such as a persistent difficulty in learning to compute simple formal languages like PARITY, and a bias towards low-degree functions. However, theoretical understanding remains limited, with existing expressiveness theory either overpredicting or underpredicting realistic learning abilities.

We prove that, under the transformer architecture, the loss landscape is constrained by the input-space sensitivity: Transformers whose output is sensitive to many parts of the input string inhabit isolated points in parameter space, leading to a low-sensitivity bias in generalization.

We show theoretically and empirically that this theory unifies a broad array of empirical observations about the learning abilities and biases of transformers, such as their generalization bias towards low sensitivity and low degree, and difficulty in length generalization for PARITY [and success of inner-monologue methods].

This shows that understanding transformers’ inductive biases requires studying not just their in-principle expressivity, but also their loss landscape.

…Intermediate Steps Reduce Sensitivity

In the realm of multiplication over finite monoids, foundational to regular languages and the simulation of finite automata (eg. Liu et al 2023; Delétang et al 2023; Angluin et al 2023), average sensitivity creates a dichotomy: if f computes n-fold multiplication over a finite monoid, then as_n(f) = 𝒪(1) if the monoid is aperiodic, and as_n(f) = Θ(n) otherwise. The theory thus predicts that transformer shortcuts to automata (Liu et al 2023) will be brittle for many non-aperiodic automata.
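To make the dichotomy concrete, here is a minimal illustrative sketch (ours, not from the paper) that computes average sensitivity by exact enumeration, contrasting PARITY (n-fold multiplication in the group ℤ/2, which is not aperiodic) with disjunction (n-fold multiplication in the aperiodic monoid ({0,1}, max)):

```python
import itertools

def avg_sensitivity(f, n):
    """Average sensitivity as_n(f): the expected number of positions i such
    that flipping bit i changes f(x), over uniformly random x in {0,1}^n.
    Computed exactly by enumerating all 2^n inputs."""
    total = 0
    for x in itertools.product([0, 1], repeat=n):
        fx = f(x)
        for i in range(n):
            y = list(x)
            y[i] ^= 1  # flip one bit
            total += (f(tuple(y)) != fx)
    return total / 2 ** n

parity = lambda x: sum(x) % 2   # product in Z/2: every flip changes the output
disjunction = lambda x: max(x)  # product in ({0,1}, max): aperiodic

for n in [4, 8, 12]:
    print(n, avg_sensitivity(parity, n), avg_sensitivity(disjunction, n))
```

PARITY's average sensitivity is exactly n (every bit flip changes the output), while disjunction's is 2n/2ⁿ, which vanishes as n grows, matching the 𝒪(1)-vs-Θ(n) dichotomy.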

PARITY can be solved well with a scratchpad (Anil et al 2022; Liu et al 2023). Existing theoretical accounts of the benefit of intermediate steps for transformers’ expressive capacity (eg. Merrill & Sabharwal 2023a; Feng et al 2023) do not explain the benefit of intermediate steps on PARITY-like problems, as the upper bounds on transformer abilities used in these studies do not capture the hardness of learning to compute PARITY in a single step (Fact 1). The concept of average sensitivity provides a simple explanation. Formally, we consider the problem of simulating a finite automaton with state set 𝒳, either by translating the input to the final state t_n in one go (‘standard’), or by autoregressively translating it into the sequence of states t_1, …, t_n (‘scratchpad’). Then (proof in Appendix D):

Theorem 7. Simulating an automaton with a scratchpad has sensitivity 𝒪(1) at each autoregressive step.
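A toy illustration of why each step is 𝒪(1)-sensitive (our sketch, not the paper's proof): with the scratchpad visible in context, step i only needs to read the current input symbol x_i and the previous state token t_{i−1}, so flipping any other context position leaves that step's output unchanged. Verifying this exhaustively for PARITY:

```python
import itertools

def avg_sensitivity(f, m):
    """Exact average sensitivity of f over {0,1}^m, by enumeration."""
    total = 0
    for x in itertools.product([0, 1], repeat=m):
        fx = f(x)
        for j in range(m):
            y = list(x)
            y[j] ^= 1
            total += (f(tuple(y)) != fx)
    return total / 2 ** m

def make_step(i):
    # Step i of the PARITY scratchpad sees the context
    # (x_1..x_i, t_0..t_{i-1}) and must output t_i = x_i XOR t_{i-1}.
    # It reads exactly two of the 2i context positions.
    return lambda ctx: ctx[i - 1] ^ ctx[2 * i - 1]

# Per-step average sensitivity stays at 2, independent of prefix length,
# whereas PARITY computed in one shot over an i-bit prefix has sensitivity i.
for i in [2, 3, 5]:
    print(i, avg_sensitivity(make_step(i), 2 * i))
```

The constant 2 here (versus n for one-shot PARITY) is the content of the 𝒪(1) bound in this special case.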

…Scratchpad Eliminates Sharpness: By Theorem 7, the sensitivity of each autoregressive step when computing PARITY with a scratchpad is 𝒪(1). Hence, Theorem 6 provides no nontrivial lower bound on L_{ρ,n}(T). We trained an encoder-decoder Transformer to predict the PARITY of the i-th prefix at the i-th autoregressive step: t_i = PARITY(x_{1:i}) = x_i ⊕ t_{i−1} (t_0 = 0). The dependency of sharpness on input length for PARITY with a scratchpad is shown in Figure 11: even for lengths around 300, sharpness is low and increases little with input length. Thus, the decrease in sensitivity due to the scratchpad can explain why prior work (Anil et al 2022) found that PARITY is easy for Transformers with a scratchpad… Our theory suggests that transformers generalize well to the extent that real-world data has bounded sensitivity (eg. Hahn et al 2021).
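For concreteness, a hypothetical data-generation helper for this training setup (function and variable names are ours, not the paper's): the encoder reads the bit string x_1..x_n, and the decoder is trained to emit the running-parity targets t_1..t_n.

```python
import random

def parity_scratchpad_example(n, rng=random):
    """One (input, scratchpad-target) training pair for PARITY with scratchpad:
    the model reads x_1..x_n and autoregressively emits t_1..t_n, where
    t_i = x_i XOR t_{i-1} and t_0 = 0, so t_n is the parity of the whole string."""
    x = [rng.randint(0, 1) for _ in range(n)]
    targets, t = [], 0
    for bit in x:
        t ^= bit          # running parity of the prefix seen so far
        targets.append(t)
    return x, targets

x, targets = parity_scratchpad_example(10)
print(x, targets)
```

Each target depends only on the current bit and the previous target, which is exactly what keeps the per-step sensitivity constant.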

[Why inner-monologue helps: it breaks the task down into individual steps with much tamer loss landscapes, and is thus more learnable.]