Supporting state-of-the-art AI research requires balancing rapid prototyping, ease of use, and quick iteration, with the ability to deploy experiments at a scale traditionally associated with production systems. Deep learning frameworks such as TensorFlow, PyTorch and JAX allow users to transparently make use of accelerators, such as TPUs and GPUs, to offload the more computationally intensive parts of training and inference in modern deep learning systems. Popular training pipelines that use these frameworks for deep learning typically focus on (un-)supervised learning. How to best train reinforcement learning (RL) agents at scale is still an active research area.
In this report we argue that TPUs are particularly well suited for training RL agents in a scalable, efficient and reproducible way. Specifically we describe two architectures designed to make the best use of the resources available on a TPU Pod (a special configuration in a Google data center that features multiple TPU devices connected to each other by extremely low latency communication channels).
Figure 4: (a) FPS for Anakin, as a function of the number of TPU cores, ranging from 16 (i.e. 2 replicas) to 128 (i.e. 16 replicas). (b) FPS for a Sebulba implementation of IMPALA’s V-trace algorithm, as a function of the actor batch size, from 32 (as in IMPALA) to 128. (c) FPS for a Sebulba implementation of MuZero, as a function of the number of TPU cores, from 16 (i.e. 2 replicas) to 128 (i.e. 16 replicas).
Anakin: When using small neural networks and grid-world environments, an Anakin architecture can easily perform 5 million steps per second, even on the 8-core TPU accessible for free through Google Colab.
This can be very useful for experimenting with and debugging research ideas in the friendly Colab environment. In Figure 4a we show how, thanks to the efficient network connecting different TPU cores in a Pod, performance scales almost linearly with the number of cores; the collective operations used to average gradients across replicas appear to cause only minimal overhead.
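The Anakin pattern can be sketched in a few lines of JAX: the environment step and the parameter update are both pure functions, fused into a single computation that is replicated across cores with `jax.pmap`, while a collective `pmean` averages gradients across replicas. The toy environment, linear "network" and step size below are hypothetical stand-ins for illustration only, not the agents used in the experiments.

```python
import jax
import jax.numpy as jnp

def env_step(state, action):
    # Toy grid-world stand-in: reward 1 when the action matches state parity.
    reward = jnp.where(action == state % 2, 1.0, 0.0)
    return (state + 1) % 4, reward

def loss_fn(params, state):
    logits = params * jnp.ones(2)   # toy "network" (a single scalar parameter)
    action = jnp.argmax(logits)
    _, reward = env_step(state, action)
    return -reward * logits[action]

def update(params, state):
    grads = jax.grad(loss_fn)(params, state)
    # Average gradients across all replicas over the pod's fast interconnect.
    grads = jax.lax.pmean(grads, axis_name='i')
    return params - 0.1 * grads

n = jax.local_device_count()
params = jnp.zeros(n)                    # one replica of the parameters per core
states = jnp.zeros(n, dtype=jnp.int32)   # one environment state per core
new_params = jax.pmap(update, axis_name='i')(params, states)
```

Because acting, learning and the environment itself all live inside one jitted, pmapped function, there is no host–device traffic in the inner loop; this is what lets Anakin saturate the TPU cores.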
In a recent paper, Oh et al. (2021) used Anakin, at a much larger scale, to discover a general reinforcement learning update from experience of interacting with a rich set of environments implemented in JAX. In that work, Anakin was used to learn a single shared update rule from 60K JAX environments and 1K policies running and training in parallel.
Despite the complex nature of the system, based on the use of neural networks to meta-learn not just a policy but the entire RL update, Anakin delivered over 3 million steps per second on a 16-core TPU. Training the update rule to a good level of performance required running Anakin for ~24 hours; this would cost ~$100 on GCP’s preemptible instances.
Sebulba: Our second podracer architecture has also been used extensively to explore a variety of RL ideas at scale, on environments that cannot be compiled to run on TPU (e.g. Atari, DMLab and MuJoCo). Since both IMPALA and Sebulba are based on a decomposition between actors and learners, agents designed for the IMPALA architecture can be easily mapped onto Sebulba; for instance, a Podracer version of the V-trace agent easily reproduced the results from Espeholt et al. (2018). However, we found that training an agent for 200 million frames of an Atari game could be done in just ~1 hour, by running Sebulba on an 8-core TPU. This comes at a cost of ~$2.88 on GCP’s preemptible instances. This is similar in cost to training with the more complex SEED RL framework, and much cheaper than training an agent for 200 million Atari frames using either IMPALA or a single-stream GPU-based system such as that traditionally used by DQN.
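The actor/learner decomposition at the heart of Sebulba can be illustrated with a minimal host-side sketch using only Python threads and a queue: actor threads step the (non-compilable) environment and hand finished trajectories to a learner that consumes them. The environment and policy below are trivial hypothetical stand-ins for an Atari-style environment and a batched TPU inference call.

```python
import queue
import threading

TRAJECTORY_LENGTH = 4   # assumed, for illustration
NUM_ACTORS = 2
TRAJECTORIES_PER_ACTOR = 3

trajectory_queue = queue.Queue(maxsize=8)

def fake_env_step(action):
    # Stand-in for a real environment step (e.g. Atari) running on the host.
    return action + 1, 1.0   # (next observation, reward)

def fake_policy(obs_batch):
    # Stand-in for a batched inference call offloaded to the TPU.
    return [o % 2 for o in obs_batch]

def actor(actor_id):
    obs = 0
    for _ in range(TRAJECTORIES_PER_ACTOR):
        trajectory = []
        for _ in range(TRAJECTORY_LENGTH):
            (action,) = fake_policy([obs])   # batch of 1 for simplicity
            obs, reward = fake_env_step(action)
            trajectory.append((obs, action, reward))
        trajectory_queue.put(trajectory)     # hand trajectory to the learner

threads = [threading.Thread(target=actor, args=(i,)) for i in range(NUM_ACTORS)]
for t in threads:
    t.start()

# Learner loop: consume trajectories and (here, trivially) perform "updates".
updates = 0
for _ in range(NUM_ACTORS * TRAJECTORIES_PER_ACTOR):
    batch = trajectory_queue.get()
    updates += 1

for t in threads:
    t.join()
```

In the real architecture the actors additionally batch observations across many environments before each inference call, and the learner runs on dedicated TPU cores; the queue-based handoff above is the structural point shared with IMPALA that makes porting agents between the two architectures straightforward.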
In addition to the trajectory length, the effective batch size used to compute each update also depends on how many times we replicate the basic 8-TPU setup. Sebulba scales effectively along this dimension as well: using 2048 TPU cores (an entire Pod) we were able to scale all the way to 43 million frames per second, solving the classic Atari videogame Pong in less than one minute.
Sebulba has also been used to train search-based agents inspired by MuZero (Schrittwieser et al., 2020). The workloads associated with these agents are very different from those of model-free agents like IMPALA. The key difference is the cost of action selection, which increases because MuZero’s policy combines search with deep neural networks (used to guide and/or truncate the search). Search-based agents like MuZero have typically required custom C++ implementations of the search to deliver good performance; using Sebulba and a pure JAX implementation of MCTS, we could nonetheless reproduce results from MuZero (without Reanalyse) on multiple RL environments. Training a MuZero agent with Sebulba for 200M Atari frames takes 9 hours on a 16-core TPU (at a cost of ~$40 on GCP’s preemptible instances).
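A back-of-the-envelope sketch (with an assumed search budget, not a figure from these experiments) illustrates why action selection dominates the cost of search-based agents: every simulation in the search calls a learned network, whereas a model-free agent needs a single forward pass per environment step.

```python
NUM_SIMULATIONS = 50   # assumed MuZero-style search budget per action

def network_calls_model_free(num_env_steps):
    # Model-free agents: one policy evaluation per environment step.
    return num_env_steps

def network_calls_search(num_env_steps, num_simulations=NUM_SIMULATIONS):
    # Search-based agents: one network evaluation per simulation, per step.
    return num_env_steps * num_simulations

steps = 1000
model_free = network_calls_model_free(steps)
search_based = network_calls_search(steps)
```

This multiplicative factor is what makes the efficiency of the search implementation so important, and why a pure JAX MCTS (rather than custom C++) keeping the whole computation on-device is attractive.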
We found that scalability via replication was particularly useful in this context. Figure 4c reports the number of frames per second processed by Sebulba when running MuZero on Atari, as a function of the number of TPU cores. The throughput increased linearly with the number of cores.