“DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining”, Sang Michael Xie, Hieu Pham, Xuanyi Dong, Nan Du, Hanxiao Liu, Yifeng Lu, Percy Liang, Quoc V. Le, Tengyu Ma, Adams Wei Yu, 2023-05-17:

[Twitter] The mixture proportions of pretraining data domains (e.g., Wikipedia, books, web text) greatly affect language model (LM) performance. In this paper, we propose Domain Reweighting with Minimax Optimization (DoReMi), which first trains a small proxy model using group distributionally robust optimization (Group DRO) over domains to produce domain weights (mixture proportions) without knowledge of downstream tasks.
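The core of this step is a multiplicative-weights update on the domain weights, driven by how much worse the proxy model is than a pretrained reference model on each domain. Below is a minimal NumPy sketch of one such update; the function name, step size, and smoothing constant are illustrative choices, not necessarily the paper's exact settings.

```python
import numpy as np

def update_domain_weights(alpha, proxy_losses, ref_losses, eta=1.0, smoothing=1e-3):
    """One DoReMi-style domain-weight update (a minimal sketch).

    alpha        : current domain weights, shape (k,), sums to 1
    proxy_losses : per-domain average loss of the small proxy model, shape (k,)
    ref_losses   : per-domain average loss of a pretrained reference model, shape (k,)
    """
    # Excess loss: how much worse the proxy is than the reference on each domain,
    # clipped at zero so already-well-fit domains are not upweighted.
    excess = np.maximum(proxy_losses - ref_losses, 0.0)

    # Exponentiated-gradient (multiplicative-weights) step on the domain weights.
    alpha = alpha * np.exp(eta * excess)
    alpha = alpha / alpha.sum()

    # Mix with the uniform distribution for stability.
    k = len(alpha)
    return (1.0 - smoothing) * alpha + smoothing / k

# The final mixture is obtained by averaging alpha over the proxy-training steps.
```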

We then resample a dataset with these domain weights and train a larger, full-sized model. In our experiments, we use DoReMi on a 280M-parameter proxy model to find domain weights for training an 8B-parameter model (30× larger) more efficiently.
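Once DoReMi has produced the weights, the large model simply trains on data drawn according to them. A toy sketch of sampling a batch this way (the names and the with-replacement sampling are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_batch(domain_datasets, domain_weights, batch_size):
    """Draw a batch by first sampling each example's domain from the
    DoReMi weights, then sampling an example from that domain."""
    domain_ids = rng.choice(len(domain_datasets), size=batch_size, p=domain_weights)
    return [domain_datasets[d][rng.integers(len(domain_datasets[d]))]
            for d in domain_ids]
```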

On The Pile, DoReMi improves perplexity across all domains, even when it downweights a domain. DoReMi improves average few-shot downstream accuracy by 6.5 percentage points over a baseline model trained using The Pile’s default domain weights and reaches the baseline accuracy with 2.6× fewer training steps.

On the GLaM dataset, DoReMi, which has no knowledge of downstream tasks, even matches the performance of domain weights tuned on downstream tasks.

[…we actually found that loss rankings transfer very well across scales even at the example level (in some preliminary tests, 95%+ Spearman rank correlation). So it seems: more learnable for a large model → more learnable for a small model.]
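A rank correlation like this can be checked directly from per-example losses. The sketch below uses randomly generated placeholder losses standing in for the proxy and large models' actual per-example losses, just to show the computation:

```python
import numpy as np
from scipy.stats import spearmanr

# Hypothetical per-example losses from a small proxy model and a large model
# evaluated on the same held-out examples (random placeholders here).
rng = np.random.default_rng(0)
small_model_losses = rng.gamma(shape=2.0, scale=1.0, size=10_000)
large_model_losses = 0.8 * small_model_losses + rng.normal(0.0, 0.2, size=10_000)

rho, _ = spearmanr(small_model_losses, large_model_losses)
print(f"Spearman rank correlation of per-example losses: {rho:.3f}")
```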

Figure 2: DoReMi optimizes domain weights with a small model (280M params) and uses these domain weights to train a much larger model (8B params, 30× larger). Here, optimizing the domain weights (training a small model twice) takes 8% of the compute of training the large model. DoReMi improves average one-shot downstream accuracy by 6.5 percentage points and reaches the baseline accuracy 2.6× faster when pretraining on The Pile.

DoReMi can reduce perplexity across all domains without a tradeoff: Figure 4 shows the per-domain perplexity of the 8B models on The Pile. DoReMi substantially reduces the perplexity over the baseline across all domains, despite allocating lower weight to some domains. How can this occur? Intuitively, the domains with the lowest and highest entropy can be downweighted without affecting perplexity much. The lowest-entropy domains statistically require few samples to learn. The highest-entropy domains have token distributions that are close to a uniform prior (for example, models at random initialization tend to output a near-uniform next-token distribution), so we need fewer samples to fit them. Positive transfer from allocating more samples to medium-entropy domains can then improve perplexity on all domains. In Appendix D, we provide a simple example where reweighting domains improves perplexity on all domains, and DoReMi finds such domain weights in simulations.
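A tiny calculation illustrating the "high-entropy domains are nearly fit at initialization" part of this argument: the excess loss of a uniform next-token prediction on a domain with token distribution p is KL(p ‖ uniform) = log V − H(p), which is essentially zero for a near-uniform domain. The vocabulary size and the three toy distributions below are made up for illustration, not taken from the paper's Appendix D example:

```python
import numpy as np

V = 50  # toy vocabulary size

def excess_at_init(p):
    """Excess cross-entropy (nats) of a uniform next-token prediction,
    i.e. KL(p || uniform) = log V - H(p). A uniform prediction is roughly
    what a randomly initialized LM produces."""
    p = np.asarray(p, dtype=float)
    entropy = -np.sum(p * np.log(p))
    return np.log(V) - entropy

low    = np.full(V, 0.01 / (V - 1)); low[0] = 0.99            # near-deterministic domain
medium = 1.0 / np.arange(1, V + 1); medium /= medium.sum()    # Zipf-like domain
high   = np.full(V, 1.0 / V)                                  # near-uniform domain

for name, p in [("low entropy", low), ("medium entropy", medium), ("high entropy", high)]:
    print(f"{name:>15}: excess loss at uniform init = {excess_at_init(p):.2f} nats")
```

The near-uniform domain starts with essentially zero excess loss, so extra training samples buy it little; the near-deterministic domain starts with a large gap but, with a single dominant token, closes most of it from very few samples.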

What is a domain? We define a domain by data provenance in our experiments, but this only enables coarse-grained control. Using fine-grained domains could improve the gains from DoReMi. For example, DoReMi is more effective on The Pile (22 domains) than the GLaM dataset (8 domains).