“GPT-J-6B: 6B JAX-Based Transformer”, 2021-06-08:
Summary:
We have released GPT-J-6B, a 6B-parameter JAX-based (Mesh) Transformer LM (GitHub).
GPT-J-6B performs nearly on par with the 6.7B-parameter GPT-3 variant (Curie) on various zero-shot downstream tasks.
You can try it out in this Colab notebook or the free web demo.
This library also serves as an example of model parallelism with xmap on JAX.
Below, we refer to GPT-J-6B as GPT-J for short.
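To make the xmap-based model parallelism mentioned above concrete, here is a toy NumPy sketch (not the library's actual code; xmap's experimental API has changed across JAX versions) of the underlying idea: a layer's weight matrix is sharded column-wise across devices, each device computes a partial matmul, and gathering the shards reproduces the full output.

```python
# Toy NumPy illustration of tensor/model parallelism (an assumption-laden
# sketch, not Mesh Transformer JAX's real code): with xmap, the device
# dimension is a named mesh axis and the gather is implicit in the
# sharded output; here we simulate it explicitly.
import numpy as np

rng = np.random.default_rng(0)
n_devices = 4                      # e.g. TPU cores in the mesh
x = rng.standard_normal((2, 8))    # batch of activations
W = rng.standard_normal((8, 16))   # full weight matrix of a linear layer

# Shard W column-wise: each "device" holds an (8, 16 // n_devices) slice.
W_shards = np.split(W, n_devices, axis=1)

# Each device computes its partial output independently (in parallel on
# real hardware); concatenating the shards recovers the full layer output.
y_parallel = np.concatenate([x @ w for w in W_shards], axis=1)
y_full = x @ W
assert np.allclose(y_parallel, y_full)
```

Column-wise sharding is only one layout choice; splitting along the rows (with a sum instead of a concatenation) works equally well, and real implementations alternate the two to avoid communication between consecutive layers.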
Why does this project matter?
GPT-J is the best-performing publicly available autoregressive Transformer LM in terms of zero-shot performance on various downstream tasks. [There are public T5 checkpoints, but they are bidirectional.]
GPT-J allows more flexible and faster inference than TensorFlow + TPU counterparts.
This project required substantially fewer person-hours than other large-scale model developments, which demonstrates that JAX + xmap + TPUs are the right set of tools for quickly developing large-scale models.
Credit assignment:
- Ben Wang
Wrote the code and the Colab notebook, built part of the API, and ran the experiments.
- Aran Komatsuzaki
Proposed this project, designed the high-level plan and the configs, wrote this article, and advised Ben.
[GPT-J has been avidly picked up by GPT API users, including NovelAI’s Japanese model, Vietnamese, Korean, and French models, AI Dungeon, academic research as a baseline (e.g. code generation or accuracy), story critique, PurpleSmart, and self-imitations.]