“S4: Efficiently Modeling Long Sequences With Structured State Spaces”, Albert Gu, Karan Goel, Christopher Ré, 2021-10-31:

[cf. LSSL, HiPPO; Github (example); talk; explainer] A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Although conventional models including RNNs, CNNs, and Transformers have specialized variants for capturing long dependencies, they still struggle to scale to very long sequences of 10,000 or more steps.

A promising recent approach proposed modeling sequences by simulating the fundamental state space model (SSM) x′(t) = Ax(t) + Bu(t), y(t) = Cx(t) + Du(t), and showed that for appropriate choices of the state matrix A, this system could handle long-range dependencies mathematically and empirically. However, this method has prohibitive computation and memory requirements, rendering it infeasible as a general sequence modeling solution.
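To run the continuous-time SSM on a discrete sequence, it first has to be discretized into a linear recurrence. The following is a minimal NumPy sketch of that step. It assumes the bilinear (Tustin) transform, which I believe is the discretization the S4 paper uses. The function names `discretize_bilinear` and `run_recurrence` are hypothetical, and the toy 1-D system is illustrative only.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t).

    Returns (Ab, Bb) so that the state update is x_k = Ab @ x_{k-1} + Bb * u_k.
    """
    N = A.shape[0]
    I = np.eye(N)
    inv = np.linalg.inv(I - (step / 2.0) * A)
    Ab = inv @ (I + (step / 2.0) * A)
    Bb = step * (inv @ B)
    return Ab, Bb

def run_recurrence(Ab, Bb, C, D, u):
    """Run the discretized SSM over a 1-D input sequence, one step at a time."""
    x = np.zeros(Ab.shape[0])
    ys = []
    for u_k in u:
        x = Ab @ x + Bb * u_k        # linear state update (no nonlinearity)
        ys.append(C @ x + D * u_k)   # readout y_k = C x_k + D u_k
    return np.array(ys)

# Toy 1-D system x'(t) = -x(t) + u(t), y = x, driven by a constant input u = 1.
A = np.array([[-1.0]])
B = np.array([1.0])
C = np.array([1.0])
D = 0.0
Ab, Bb = discretize_bilinear(A, B, step=0.1)
y = run_recurrence(Ab, Bb, C, D, np.ones(200))
print(y[-1])  # approaches the continuous-time steady state x = 1
```

In this recurrent form each step costs O(N²), and the steps must run sequentially. That sequential cost is the motivation for the convolutional view discussed below.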

We propose the Structured State Space sequence (S4) model based on a new parameterization for the SSM, and show that it can be computed much more efficiently than prior approaches while preserving their theoretical strengths. Our technique involves conditioning A with a low-rank correction, allowing it to be diagonalized stably and reducing the SSM to the well-studied computation of a Cauchy kernel.
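The “low-rank correction” can be checked numerically for the HiPPO-LegS matrix. The sketch below assumes the LegS definitions as I understand them from the paper: A[n,k] = −√(2n+1)√(2k+1) for n > k, −(n+1) on the diagonal, and 0 above it. It also assumes a rank-1 term P[n] = √(n + 1/2). Adding PPᵀ leaves −½I plus a skew-symmetric matrix. That sum is a normal matrix, so it can be diagonalized by a unitary (well-conditioned) basis. The helper names are hypothetical. This sketch covers only the structural claim, not S4's full Cauchy-kernel algorithm.

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS state matrix A and input vector B (assumed definitions)."""
    n = np.arange(N)
    r = np.sqrt(2 * n + 1)
    # A[n, k] = -sqrt(2n+1) sqrt(2k+1) below the diagonal, -(n+1) on it, 0 above it
    A = -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)
    B = r.copy()
    return A, B

def legs_low_rank_term(N):
    """Hypothetical helper: rank-1 correction vector P with P[n] = sqrt(n + 1/2)."""
    return np.sqrt(np.arange(N) + 0.5)

N = 16
A, B = hippo_legs(N)
P = legs_low_rank_term(N)

# A + P P^T = -1/2 I + S with S skew-symmetric, i.e. a normal matrix.
S = A + np.outer(P, P) + 0.5 * np.eye(N)
assert np.allclose(S, -S.T)

# i*S is Hermitian, so eigh returns a unitary eigenbasis V.
mu, V = np.linalg.eigh(1j * S)
Lam = -0.5 - 1j * mu  # eigenvalues of A + P P^T; all have real part -1/2

# Normal-plus-low-rank form: A = V diag(Lam) V^* - P P^T
A_rebuilt = V @ np.diag(Lam) @ V.conj().T - np.outer(P, P)
print(np.allclose(A_rebuilt, A))            # reconstruction check
print(np.linalg.cond(V))                    # unitary basis: condition number near 1
print(np.linalg.cond(np.linalg.eig(A)[1]))  # eigenbasis of A itself, for comparison
```

The last line is there for contrast. The paper argues that A's own eigenvectors have entries that grow exponentially in N, which is why diagonalizing A directly is numerically unstable.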

S4 achieves strong empirical results across a diverse range of established benchmarks, including (1) 91% accuracy on sequential CIFAR-10 with no data augmentation or auxiliary losses, on par with a larger 2-D ResNet; (2) substantially closing the gap to Transformers on image and language modeling tasks, while performing generation 60× faster; and (3) SoTA on every task from the Long Range Arena benchmark, including solving the challenging Path-X task of length 16k that all prior work fails on, while being as efficient as all competitors.

[Parrot:

I find it easiest to think of it as a “super RNN”—an RNN with all the long-term dependency and vanishing gradient issues fixed. My best TLDR for why it works:

  1. It’s like a linear RNN with an N-dimensional hidden state xₜ.

  2. The key is to initialize and parameterize the RNN in a very special way.

  3. This makes xₜ evolve in a special way: each xₜ lets you reconstruct all past inputs u₀, u₁, …, uₜ with high accuracy.

    IIUC, just being able to “memorize” like this is apparently enough to break SOTA on Long Range Arena.

  4. And with the special initialization, the RNN’s parameter matrix is so simple that you can compute a very large number of time steps entirely in parallel, using FFT. FFT is the key computational trick; the other part is initializing with a matrix that is “almost diagonal” and therefore easy to work with.
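The FFT trick in point 4 rests on a simple identity. Unrolling a time-invariant linear recurrence gives a causal convolution with the kernel K = (CB̄, CĀB̄, CĀ²B̄, …), and a convolution can be computed with FFTs in O(L log L). The sketch below checks that identity on a random stable system, with the D term omitted. It materializes K with naive matrix powers, which is *not* how S4 does it (the paper computes K through the Cauchy-kernel reduction). The names and random parameters are hypothetical.

```python
import numpy as np

def ssm_kernel(Ab, Bb, C, L):
    """Naively materialize K = (C Bb, C Ab Bb, ..., C Ab^(L-1) Bb)."""
    K = np.empty(L)
    v = Bb.copy()
    for k in range(L):
        K[k] = C @ v   # K[k] = C Ab^k Bb
        v = Ab @ v
    return K

def causal_conv_fft(K, u):
    """y[k] = sum_{j<=k} K[j] u[k-j], via zero-padded FFTs in O(L log L)."""
    L = len(u)
    n = 2 * L  # padding to 2L makes circular convolution equal linear convolution
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)
    return y[:L]

def scan(Ab, Bb, C, u):
    """Reference: the same SSM run sequentially as a linear RNN."""
    x = np.zeros(Ab.shape[0])
    ys = []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append(C @ x)
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 4, 256
Ab = rng.standard_normal((N, N))
Ab /= 1.1 * np.max(np.abs(np.linalg.eigvals(Ab)))  # rescale to spectral radius < 1
Bb = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

y_fft = causal_conv_fft(ssm_kernel(Ab, Bb, C, L), u)
y_rnn = scan(Ab, Bb, C, u)
print(np.allclose(y_fft, y_rnn))
```

This equivalence depends on Ā, B̄ and C being the same at every step. Given that, the same model can be trained in parallel as a convolution and then run step by step as an RNN for fast generation.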

TLDR: RNNs can be really really good if you parameterize them the right way.

Narendra Patwardhan:

Transformers would perform better than S4 (in its current form) on any task that can’t easily be expressed as a simple differential equation, such as language modeling, question answering, object detection, image segmentation, etc.]