“Scaling Vision Transformers to 22 Billion Parameters”, Mostafa Dehghani, Josip Djolonga, Basil Mustafa, Piotr Padlewski, Jonathan Heek, Justin Gilmer, Andreas Steiner, Mathilde Caron, Robert Geirhos, Ibrahim Alabdulmohsin, Rodolphe Jenatton, Lucas Beyer, Michael Tschannen, Anurag Arnab, Xiao Wang, Carlos Riquelme, Matthias Minderer, Joan Puigcerver, Utku Evci, Manoj Kumar, Sjoerd van Steenkiste, Gamaleldin F. Elsayed, Aravindh Mahendran, Fisher Yu, Avital Oliver, Fantine Huot, Jasmijn Bastings, Mark Patrick Collier, Alexey Gritsenko, Vighnesh Birodkar, Cristina Vasconcelos, Yi Tay, Thomas Mensink, Alexander Kolesnikov, Filip Pavetić, Dustin Tran, Thomas Kipf, Mario Lučić, Xiaohua Zhai, Daniel Keysers, Jeremiah Harmsen, Neil Houlsby. 2023-02-10:

The scaling of Transformers has driven breakthrough capabilities for language models. At present, the largest large language models (LLMs) contain upwards of 100b parameters. Vision Transformers (ViT) have introduced the same architecture to image and video modeling, but these have not yet been successfully scaled to nearly the same degree; the largest dense ViT contains 4b parameters (Chen et al 2022).

We present a recipe for highly efficient and stable training of a 22b-parameter ViT (ViT-22b) and perform a wide variety of experiments on the resulting model.

When evaluated on downstream tasks (often with a lightweight linear model on frozen features), ViT-22b demonstrates increasing performance with scale. We further observe other interesting benefits of scale, including an improved tradeoff between fairness and performance, state-of-the-art alignment to human visual perception in terms of shape/texture bias, and improved robustness.

ViT-22b demonstrates the potential for “LLM-like” scaling in vision, and provides key steps towards getting there.

…For example, even when used as a frozen visual feature extractor, ViT-22b achieves an accuracy of 89.5% on ImageNet. With a text tower trained to match these visual features (Zhai et al 2022b), it achieves 85.9% accuracy on ImageNet in the zero-shot setting. The model is furthermore a great teacher—used as a distillation target, we train a ViT-B student that achieves 88.6% on ImageNet, state-of-the-art at this scale. This performance comes with improved out-of-distribution behavior, reliability, uncertainty estimation, and fairness tradeoffs. Finally, the model’s features are better aligned with human perception, achieving a previously unseen shape bias of 87%.

Dataset: ViT-22b is trained on a version of JFT (Sun et al 2017), extended to around 4b images (Zhai et al 2022a) [JFT-4b]. These images have been semi-automatically annotated with a class hierarchy of 30k labels. Following the original Vision Transformer, we flatten the hierarchical label structure and use all the assigned labels in a multi-label classification setting, employing a sigmoid cross-entropy loss.
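The multi-label setup treats each of the 30k classes as an independent binary prediction, so the loss is a sum of per-class sigmoid cross-entropies rather than a single softmax. A minimal sketch of that loss (the function name and list-based interface are illustrative, not from the paper's codebase, which operates on large batched arrays):

```python
import math

def sigmoid_xent(logits, targets):
    """Multi-label sigmoid cross-entropy, summed over classes.

    logits: one real-valued score per class.
    targets: 1.0 for each assigned label, 0.0 otherwise (several
    labels may be active at once, unlike softmax classification).
    """
    total = 0.0
    for z, t in zip(logits, targets):
        # Numerically stable form of -[t*log(sigmoid(z)) + (1-t)*log(1-sigmoid(z))]:
        # max(z, 0) - z*t + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
    return total
```

The stable formulation avoids overflow in `exp` for large-magnitude logits; at `z = 0` with a positive label the per-class loss is `log 2`, as expected for a maximally uncertain prediction.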

…Using these techniques, ViT-22b processes 1.15k tokens per second per core during training (forward and backward pass) on TPUv4 (Jouppi et al 2020). ViT-22b’s model FLOPs utilization (MFU) (Chowdhery et al 2022; Dehghani et al 2021a) is 54.9%, indicating a very efficient use of the hardware. Note that PaLM reports 46.2% MFU (Chowdhery et al 2022; Pope et al 2022) and we measured 44.0% MFU for ViT-e (data-parallel only) on the same hardware.
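MFU, as defined by Chowdhery et al 2022, is the ratio of the FLOP/s actually spent on model computation (forward and backward, excluding rematerialization) to the accelerator's peak FLOP/s. A hedged sketch of the accounting, with an illustrative function name and placeholder numbers (the true per-token FLOP count for ViT-22b is not restated here):

```python
def model_flops_utilization(tokens_per_sec, flops_per_token, peak_flops_per_sec):
    """MFU = observed model FLOP/s divided by hardware peak FLOP/s.

    tokens_per_sec: measured training throughput (per core).
    flops_per_token: analytic model FLOPs per token, forward + backward,
    counting only the model's own computation.
    peak_flops_per_sec: the chip's advertised peak (per core).
    """
    return tokens_per_sec * flops_per_token / peak_flops_per_sec
```

With placeholder values of 1,000 tokens/s, 10^9 FLOPs/token, and a 2x10^12 FLOP/s peak, this yields an MFU of 0.5; the paper's 54.9% is computed from the real ViT-22b throughput and FLOP counts.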

4.5.2 Human Alignment: How well do ViT-22b classification decisions align with human classification decisions? Using the model-vs-human toolbox (Geirhos et al 2021), we evaluate 3 ViT-22b models fine-tuned on ImageNet at different resolutions (224, 384, 560). Across all toolbox metrics, ViT-22b is SOTA: ViT-22b-224 for highest OOD robustness (Figure 19(a)), ViT-22b-384 for the closest alignment with human classification accuracies (Figure 19(b)), and ViT-22b-560 for the largest error consistency (i.e., most human-like error patterns, Figure 19(d)). The ViT-22b models have the highest shape bias ever recorded in vision models: while most models have a strong texture bias (approx. 20–30% shape bias / 70–80% texture bias) (Geirhos et al 2019), humans are at 96% shape / 4% texture bias, and ViT-22b-384 achieves a previously unseen 87% shape bias / 13% texture bias (Figure 8). Overall, ViT-22b measurably improves alignment to human visual object recognition.

Figure 8: Shape bias: many vision models have a low shape / high texture bias, whereas ViT-22b fine-tuned on ImageNet (red, green, blue; trained on 4b images as indicated by brackets after model names, unless trained on ImageNet only) have the highest shape bias recorded in an ML model to date, bringing them closer towards a human-like shape bias.
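The shape-bias metric of Geirhos et al 2019 is computed on cue-conflict images (e.g., a cat shape with elephant texture): among images the model classifies as either the shape class or the texture class, shape bias is the fraction of shape decisions. A minimal sketch under that definition (function name and list encoding are illustrative):

```python
def shape_bias(decisions):
    """Shape bias per Geirhos et al 2019.

    decisions: for each cue-conflict image where the model picked either
    the shape class or the texture class, the string "shape" or "texture".
    Images classified as neither class are excluded upstream.
    """
    shape = decisions.count("shape")
    texture = decisions.count("texture")
    return shape / (shape + texture)
```

Under this metric, ViT-22b-384's 87% means that on cue-conflict stimuli it sides with object shape 87% of the time, versus roughly 20–30% for typical texture-biased models and 96% for humans.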

4.5.4 Calibration: Along with the robustness of §4.2.3, it is also natural to wonder how the calibration of ViT evolves as the scale increases. To this end, we extend the study of Minderer et al 2021 with ViT-22b. In Figure 9, we consider ViT-22b fine-tuned on ImageNet (resolution 384) and report the error (i.e., one minus accuracy) versus the calibration, as measured by the expected calibration error (ECE) (Naeini et al 2015; Guo et al 2017). We see that ViT-22b remarkably improves the tradeoff between accuracy and calibration. The conclusion holds both without (left) and with (right) temperature scaling of the logits, which was observed to better capture calibration trends across model families (Minderer et al 2021). More details can be found in Appendix H.
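ECE bins predictions by confidence and measures, per bin, the gap between average confidence and actual accuracy, weighted by bin size. A sketch of the standard equal-width-bin estimator (the function name and the choice of 10 bins are illustrative defaults, not taken from the paper):

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE (Naeini et al 2015): bin by confidence, then average
    |accuracy - mean confidence| over bins, weighted by bin size.

    confidences: top-1 predicted probability per example, in [0, 1].
    correct: whether each prediction was right.
    """
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)  # clamp conf == 1.0 into last bin
        bins[idx].append((conf, ok))
    n = len(confidences)
    ece = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        acc = sum(1 for _, ok in b if ok) / len(b)
        ece += (len(b) / n) * abs(acc - avg_conf)
    return ece
```

A perfectly calibrated model (confidence matches accuracy in every bin) scores 0; an overconfident model that predicts 0.8 but is always wrong scores 0.8. Temperature scaling, as used in the right panel of Figure 9, rescales logits by a single scalar fit on held-out data before computing this quantity.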

4.5.5 Distillation: We perform model distillation (Hinton et al 2015) to compress ViT-22b into smaller, more widely usable ViTs. We distill ViT-22b into ViT-B/16 and ViT-L/16 by following the procedure of Beyer et al 2022b. Using the ImageNet-finetuned (at 384px) ViT-22b, we annotate 500 random augmentations and MixUp transforms of each ImageNet image with ViT-22b logits. Then, we minimize the KL divergence between the student and teacher predictive distributions. We train for 1,000 epochs after initializing the student architecture from checkpoints pre-trained on JFT. The results are shown in Table 8, and we see that we achieve a new SOTA at both the ViT-B and ViT-L sizes.