“TTT-NN: Test-Time Training on Nearest Neighbors for Large Language Models”, Moritz Hardt, Yu Sun (2023-05-29):

[code; cf. AlphaZero retrieval, training on doc clusters for better context use] Many recent efforts aim to augment language models with relevant information retrieved from a database at test time. We avoid the need for prompt engineering by directly fine-tuning the model on data retrieved at test time using its standard training setup. For this purpose, we build a large-scale distributed nearest neighbor index based on text embeddings of the Pile dataset. Given a query to a language model, our system retrieves the neighbors of the query and fine-tunes the model on the text data corresponding to those neighbors.
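The core loop described above—embed the query, retrieve its nearest neighbors from an index, then take one gradient step per neighbor—can be sketched in miniature. This is a toy numpy stand-in, not the paper's implementation: the embedding function, database, and linear "model" below are all hypothetical placeholders for a real text embedder, the Pile index, and a GPT-2-class model trained with its ordinary objective.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the paper's components: a database of unit-norm
# document embeddings and a tiny linear model fine-tuned at test time.
EMB_DIM, N_DOCS, K = 8, 100, 20

db_embeddings = rng.normal(size=(N_DOCS, EMB_DIM))
db_embeddings /= np.linalg.norm(db_embeddings, axis=1, keepdims=True)

def retrieve(query_emb, k):
    """Return indices of the k nearest documents by cosine similarity."""
    q = query_emb / np.linalg.norm(query_emb)
    sims = db_embeddings @ q
    return np.argsort(-sims)[:k]  # closest first

# Toy "language model": squared error on a scalar target stands in for
# the LM loss on the neighbor's text.
targets = rng.normal(size=N_DOCS)

def one_gradient_step(w, x, y, lr=0.1):
    """One SGD step on squared error -- the paper likewise takes a single
    gradient iteration per retrieved neighbor."""
    grad = 2 * (w @ x - y) * x
    return w - lr * grad

query = rng.normal(size=EMB_DIM)
neighbors = retrieve(query, K)
w = np.zeros(EMB_DIM)
for i in neighbors:  # fine-tune once on each retrieved neighbor
    w = one_gradient_step(w, db_embeddings[i], targets[i])
```

The model is then evaluated on the query and the test-time updates are discarded before the next query, so each test instance gets its own freshly adapted model.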

Surprisingly, retrieving and training on as few as 20 neighbors, each for only one gradient iteration, drastically improves performance across more than 20 language modeling tasks in the Pile benchmark.

For example, test-time training narrows the performance gap between a small GPT-2 model and a GPT-Neo model, more than 10× larger, that was specifically trained to convergence on the Pile. Sufficient index quality and size, however, are important.

Our work establishes a valuable first baseline for implementing test-time training in the context of large language models, opening the door to numerous promising research avenues.

[This is analogous to retrieval-augmented AlphaGo: dynamic evaluation is to an LLM as tree search is to AlphaGo, in terms of expert iteration on an imperfect model at runtime; it is equivalent to taking a single large gradient descent step onto the current problem (since self-attention ≈ gradient descent at runtime).]

Figure 5: Bits per byte results on all Pile tasks for a small GPT-2 model (117M parameters) before and after training on 50 nearest neighbors.

…We begin with an evaluation of test-time training on a small GPT-2 model with 117M parameters, the default HuggingFace gpt2 model from the transformers library. See Figure 5.

Figure 6 shows the decrease in 3 different perplexity measures as we increase the number of nearest neighbors trained on. Using 20 neighbors already achieves most of the decrease in perplexity, at less than half the computational cost of using all the neighbors; additional neighbors continue to decrease perplexity. Figure 6 focuses on the 6 largest Pile tasks, in increasing order of size; combined, these 6 tasks make up more than 70% of the Pile benchmark.

Figure 6: How different perplexities decrease on average with the number of neighbors on the 6 largest Pile tasks, in ascending order. Results for GPT-2 (117M parameters).
Top: Bits per byte. Center: Byte perplexity. Bottom: Word perplexity.
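The three metrics in Figure 6 are deterministic transformations of the same summed log-likelihood, normalized per byte or per word. The conversions below follow the standard lm-evaluation-harness conventions for the Pile (an assumption on my part; the excerpt does not spell out the definitions):

```python
import math

def pile_metrics(nll_nats, num_bytes, num_words):
    """Convert a summed negative log-likelihood (in nats) over a document
    into the three perplexity-style metrics reported for the Pile.
    Assumes lm-evaluation-harness definitions: normalize per byte for
    bits-per-byte, exponentiate for the perplexities."""
    bits_per_byte = nll_nats / (num_bytes * math.log(2))
    byte_perplexity = 2 ** bits_per_byte              # == exp(nll / bytes)
    word_perplexity = byte_perplexity ** (num_bytes / num_words)
    return bits_per_byte, byte_perplexity, word_perplexity
```

Because all three are monotone in the same underlying loss, the three panels of Figure 6 necessarily move together; they differ only in normalization (per byte vs. per word) and scale (log vs. exponentiated).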
Figure 7: Results for GPT-2-Large (774M parameters).
Top: Before and after TTT-NN with 50 neighbors on the top 6 tasks. Bottom: How bits per byte decreases with additional neighbors.
Figure 8: Results for GPT-Neo (1.3B parameters). Top: Before and after TTT-NN with 50 neighbors on the top 6 tasks. Bottom: How bits per byte decreases with each additional neighbor.
Figure 9: Training costs in seconds per neighbor on each task.

…Many methods use retrieved neighbors as additional context for the test instance, without test-time training. Unlike ours, those methods require models that are also trained with retrieval and additional context. Due to the prohibitive cost of training with retrieval, we experiment with a baseline that simply uses the neighbors in-context at test time: we concatenate the neighbors in order of increasing distance and include as much of the concatenated text as fits in the context window of the test instance, in addition to its original context. This baseline improves performance in a few cases, e.g. pile_enron, but does not help much overall.
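The in-context baseline amounts to a simple packing rule: prepend neighbors, closest first, until the context budget is exhausted. A minimal sketch, using characters as a stand-in for tokens (the function name and budgeting scheme here are illustrative, not from the paper):

```python
def pack_neighbors_in_context(query_text, neighbor_texts, max_chars):
    """In-context baseline (sketch): prepend retrieved neighbors, sorted by
    increasing distance, keeping as much text as fits in the budget before
    the test instance's own text. Real systems budget in tokens; characters
    stand in for tokens in this toy version."""
    budget = max_chars - len(query_text)   # reserve room for the query itself
    prefix = ""
    for text in neighbor_texts:            # neighbor_texts sorted closest-first
        if budget <= 0:
            break
        take = text[:budget]               # truncate the last neighbor to fit
        prefix += take
        budget -= len(take)
    return prefix + query_text
```

Note the contrast with test-time training: here the neighbors compete with the test instance for a fixed context budget, whereas fine-tuning moves their information into the weights and leaves the window free.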

…As long as the database contains some high-quality data from a particular underrepresented group, test-time training may benefit that group. Our experiments on the Pile lend some support to this intuition: the smaller tasks, each less than 1% of the benchmark (e.g. pile_enron and pile_europarl), see much more substantial improvements than the larger ones (e.g. pile_pile-cc and pile_pubmed-central, which together make up more than 30%).

Test-time training might also help mitigate adversarial behaviors, including data poisoning attacks, by superimposing data at test time from a trusted data source. We hope to see more future work in this context.