Large Transformer models yield impressive results on many tasks, but are expensive to train, or even fine-tune, and are so slow at decoding that their use and study become out of reach.
We address this problem by leveraging sparsity. We study sparse variants for all layers in the Transformer and propose Scaling Transformers, a family of next generation Transformer models that use sparse layers to scale efficiently and perform unbatched decoding much faster than the standard Transformer as we scale up the model size. Surprisingly, the sparse layers are enough to obtain the same perplexity as the standard Transformer with the same number of parameters.
We also integrate with prior sparsity approaches to attention and enable fast inference on long sequences even with limited memory.
This results in performance competitive with the state-of-the-art on long-text summarization.
…With the growing popularity and size of these models, it is increasingly valuable to make them scale efficiently. In this work we propose Scaling Transformers with a separate sparse mechanism for the query, key, value and output layers (QKV layers for short) and combine it with sparse feedforward blocks to get a fully sparse Transformer architecture.
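To make the idea concrete, here is a minimal NumPy sketch of one plausible sparse feedforward step: a controller keeps only one activation per contiguous block of the hidden layer. The function name, the argmax controller, and the single-vector interface are illustrative assumptions for this sketch; the paper uses a learned low-rank controller with a discretization scheme rather than a plain argmax.

```python
import numpy as np

def sparse_ff(x, W1, W2, block_size):
    """Feedforward layer that keeps one hidden unit per block.

    x: [d_model], W1: [d_model, d_ff], W2: [d_ff, d_model].
    The argmax 'controller' below is a hypothetical stand-in for the
    learned controller described in the paper.
    """
    h = x @ W1                                   # dense hidden pre-activation
    blocks = h.reshape(-1, block_size)           # [d_ff // block_size, block_size]
    mask = np.zeros_like(blocks)
    mask[np.arange(blocks.shape[0]), blocks.argmax(axis=1)] = 1.0
    h_sparse = np.maximum(blocks * mask, 0.0)    # ReLU on the surviving units
    return h_sparse.reshape(-1) @ W2             # only d_ff/block_size rows of W2 matter
```

Note that this sketch still computes the full dense product `x @ W1`; the decoding speedup comes from evaluating the controller first and then touching only the selected columns of W1 and rows of W2.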
…We were surprised that the fully sparse Scaling Transformers are indeed enough to match the results of the baseline Transformer on the large C4 dataset (Figure 1). The improvement in complexity holds not just asymptotically: it yields over a 2.6× speedup in wall-clock decoding time already for a model with 800M parameters, and a 20× improvement for a model with 17B parameters, as shown in Table 1.
Figure 1: Log-perplexity of Scaling Transformers (equivalent to T5 large with ~800M parameters) on C4 dataset with proposed sparsity mechanisms (FF, QKV, FF+QKV) is similar to baseline dense model. Other models used in this paper are shown in grey lines; raw data is available in the appendix.
To verify that Scaling Transformers can be used with other Transformer improvements on real tasks, we create Terraformer, a Transformer model that uses reversible layers for memory efficiency and sparse attention to handle long sequences. We pre-train Terraformer on the C4 dataset and fine-tune it on the challenging task of summarizing arXiv articles. Terraformer yields results competitive with the state-of-the-art BigBird-Pegasus without using the Pegasus loss in pre-training (Table 5).
…We also checked the performance of the feedforward block with Mixture-of-Experts-style sparsity. As expected, this technique achieved decoding time comparable to sparse FF (0.11s instead of 0.09s), but, owing to its coarser granularity, it reached a log-perplexity of 1.64, worse than both our method and the dense baseline.
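For contrast, a Mixture-of-Experts-style feedforward block routes each token to one whole expert, so the selection happens at a much coarser granularity than the per-block choice above. A hedged sketch, assuming hard top-1 routing and an argmax gate (both illustrative simplifications of typical MoE layers, not the exact variant benchmarked here):

```python
import numpy as np

def moe_ff(x, experts_W1, experts_W2, gate_W):
    """Route x to a single expert: one coarse choice per token,
    instead of one choice per small hidden block.

    x: [d_model]; experts_W1: [n_experts, d_model, d_ff];
    experts_W2: [n_experts, d_ff, d_model]; gate_W: [d_model, n_experts].
    """
    e = int(np.argmax(x @ gate_W))            # hard top-1 gating (illustrative)
    h = np.maximum(x @ experts_W1[e], 0.0)    # ReLU inside the chosen expert
    return h @ experts_W2[e]
```

The single gating decision selects an entire d_ff-wide expert at once, which is why this style of sparsity is less granular than picking one unit per block.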
…4.3 Recurrence for Generalization: In addition to incorporating sparse attention and reversibility, we also add recurrence to the feedforward block of Terraformer. Recurrent layers allow information to propagate in time, even within a single decoder block. It is challenging, though, to use them without decreasing model speed, especially in training. For that reason, we use simple recurrent units (SRUs), which parallelize well during training.
SRUs contain dense layers, so their use could negate the benefits of sparsity elsewhere. We tried a few methods to alleviate that, but it turns out that simply reducing the dimensionality of the SRUs works. So we first project from d_model to a small dimension (32 in our experiments), then apply the SRU, and then project back to d_model and add the result to the feedforward block. This low-rank recurrence is, in our experiments, sufficient to transfer enough information through time for the network to generalize. Since the effects of SRUs on C4 are minimal (as the training and evaluation data are very similar), we use synthetic tasks to investigate out-of-distribution generalization. We train the models on long addition and on the task of copying decimal digits. We train on inputs with at most 128 digits and evaluate on input lengths of 256–300, so over 2× longer. As can be seen in the table below, the baseline Transformer does not generalize well, while Terraformer manages to get a large portion correct, even if it is not perfect like the Neural GPU.
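The low-rank recurrence can be sketched as follows, assuming a simplified SRU cell whose forget gate depends only on the input (all weight names, the exact cell form, and the residual placement on the input rather than inside the feedforward block are illustrative assumptions of this sketch):

```python
import numpy as np

def low_rank_sru(xs, P_down, W, W_f, b_f, P_up):
    """Project to a small dimension, run a simplified SRU, project back.

    xs: [T, d_model]; P_down: [d_model, d_low]; W, W_f: [d_low, d_low];
    b_f: [d_low]; P_up: [d_low, d_model].
    """
    zs = xs @ P_down                       # low-rank bottleneck: [T, d_low]
    c = np.zeros(P_down.shape[1])
    states = []
    for z in zs:
        f = 1.0 / (1.0 + np.exp(-(z @ W_f + b_f)))  # forget gate from input only
        c = f * c + (1.0 - f) * (z @ W)             # elementwise state update
        states.append(c)
    return xs + np.stack(states) @ P_up    # project back and add residually
```

Because the gate and candidate terms depend only on the inputs, all the matrix multiplications can be batched over time during training; only the cheap elementwise state update remains sequential, which is what makes SRUs fast to train.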
…Table 6 shows the speedup in decoding with sparse layers when we scale up Terraformer to 17B parameters. Note that sparsifying all the layers gives us a 37× speedup in decoding.
…Further, we hope that the community will take inspiration from Scaling Transformers and tune them for their needs. We ran experiments using layer sizes and hyperparameters borrowed from dense Transformers, and they are most probably not optimal for Scaling Transformers. With proper tuning and further improvements, we believe one could train a Scaling Transformer that matches GPT-3 in accuracy but also runs inference in reasonable time on a laptop. We pose this as a fascinating challenge to the community, since such Scaling Transformers will not only be more sustainable but will also make large models accessible to everyone.