information flow in transformers
In machine learning, the transformer architecture is a very commonly used type of neural network model. Many of the well-known neural nets introduced in the last few years use this architecture, including GPT-2, GPT-3, and GPT-4.
This post is about the way that computation is structured inside of a transformer.
Internally, these models pass information around in a constrained way that feels strange and limited at first glance.
Specifically, inside the “program” implemented by a transformer, each segment of “code” can only access a subset of the program’s “state.” If the program computes a value and writes it into the state, that doesn’t make the value available to every block of code that might run after the write; instead, only some operations can access the value, while others are prohibited from seeing it.
This sounds vaguely like the kind of constraint that human programmers often put on themselves: “separation of concerns,” “no global variables,” “your function should only take the inputs it needs,” that sort of thing.
However, the apparent analogy is misleading. The transformer constraints don’t look much like anything that a human programmer would write, at least under normal circumstances. And the rationale behind them is very different from “modularity” or “separation of concerns.”
(Domain experts know all about this already – this is a pedagogical post for everyone else.)
1. setting the stage
For concreteness, let’s think about a transformer that is a causal language model.
So, something like GPT-3, or the model that wrote text for @nostalgebraist-autoresponder.
Roughly speaking, this model’s input is a sequence of words, like [“Fido”, “is”, “a”, “dog”].
Since the model needs to know the order the words come in, we’ll include an integer offset alongside each word, specifying the position of this element in the sequence. So, in full, our example input is
(“Fido”, 0),
(“is”, 1),
(“a”, 2),
(“dog”, 3)
The model itself – the neural network – can be viewed as a single long function, which operates on a single element of the sequence. Its task is to output the next element.
Let’s call the function f. If f does its job perfectly, then when applied to our example sequence, we will have
f(“Fido”, 0) = “is”
f(“is”, 1) = “a”
f(“a”, 2) = “dog”
(Note: I’ve omitted the index from the output type, since it’s always obvious what the next index is.
Also, in reality the output type is a probability distribution over words, not just a word; the goal is to put high probability on the next word. I’m ignoring this to simplify exposition.)
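As a small plain-Python sketch (no library involved), here is the example input written out as (word, offset) pairs, each matched with the word f should ideally return for it. The last element, (“dog”, 3), gets no target here, simply because the text ends there.

```python
words = ["Fido", "is", "a", "dog"]

# Each input element is a (word, offset) pair.
inputs = [(word, i) for i, word in enumerate(words)]

# The ideal output of f at position i is the word at position i + 1.
# zip stops at the shorter list, so ("dog", 3) gets no target.
examples = list(zip(inputs, words[1:]))

for (word, i), target in examples:
    print(f'f("{word}", {i}) = "{target}"')
```

This prints the same three lines as the list above.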
You may have noticed something: as written, this seems impossible!
Like, how is the function supposed to know that after (“a”, 2), the next word is “dog”!? The word “a” could be followed by all sorts of things.
What makes “dog” likely, in this case, is the fact that we’re talking about someone named “Fido.”
That information isn’t contained in (“a”, 2). To do the right thing here, you need info from the whole sequence thus far – from “Fido is a”, as opposed to just “a”.
How can f get this information, if its input is just a single word and an index?
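To make the problem concrete, suppose f really were a pure function of (word, offset). The sketch below collects the outputs such a function would need to produce across two texts. The second text, “Tweety is a bird,” is a made-up example added for illustration. Both texts contain (“a”, 2), but they demand different next words, so no pure function of (word, offset) can get both right.

```python
texts = [
    ["Fido", "is", "a", "dog"],
    ["Tweety", "is", "a", "bird"],   # hypothetical second text
]

# For every (word, offset) input, record every next word it would need to produce.
required = {}
for words in texts:
    for i in range(len(words) - 1):
        required.setdefault((words[i], i), set()).add(words[i + 1])

# Inputs that would need more than one output can't be handled by a pure function.
conflicts = {key: outs for key, outs in required.items() if len(outs) > 1}
print(conflicts)
```

The only conflict is at (“a”, 2), which needs to be both “dog” and “bird.”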
This is possible because f isn’t a pure function. The program has an internal state, which f can access and modify.
But f doesn’t just have arbitrary read/write access to the state. Its access is constrained, in a very specific sort of way.
2. transformer-style programming
Let’s get more specific about the program state.
The state consists of a series of distinct “memory regions” or “blocks,” which have an order assigned to them.
Let’s use the notation memory_i for these. The first block is memory_0, the second is memory_1, and so on.
In practice, a small transformer might have around 10 of these blocks, while a very large one might have 100 or more.
Each block contains a separate data-storage “cell” for each offset in the sequence.
For example, memory_0 contains a cell for position 0 (“Fido” in our example text), and a cell for position 1 (“is”), and so on. Meanwhile, memory_1 contains its own, distinct cells for each of these positions. And so does memory_2, etc.
So the overall layout looks like:
memory_0: [cell 0, cell 1, …]
memory_1: [cell 0, cell 1, …]
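In Python terms, one simple way to picture this layout is a list of blocks, each holding its own list of cells indexed by position. The sizes here are arbitrary illustrative values.

```python
NUM_BLOCKS = 3   # arbitrary; real models have many more
SEQ_LEN = 4      # one cell per position: "Fido", "is", "a", "dog"

# memory[k][i] is cell i of memory_k. None means "not written yet".
# Each block is built separately, so blocks never share cells.
memory = [[None] * SEQ_LEN for _ in range(NUM_BLOCKS)]

memory[1][2] = "something f wrote while processing position 2"
print(memory[0][2])   # None: memory_0 has its own, distinct cell 2
```

The list comprehension matters: writing `[[None] * SEQ_LEN] * NUM_BLOCKS` instead would make every block the same list object, so a write into memory_1 would also show up in memory_0.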
Our function f can interact with this program state. But it must do so in a way that conforms to a set of rules.
Here are the rules:
- The function can only interact with the blocks by using a specific instruction.
- This instruction is an “atomic write+read”. It writes data to a block, then reads data from that block for f to use.
- When the instruction writes data, it goes in the cell specified by the function’s offset argument. That is, the “i” in f(…, i).
- When the instruction reads data, the data comes from all cells up to and including the offset argument.
- The function must call the instruction exactly once for each block.
- These calls must happen in order. For example, you can’t do the call for memory_1 until you’ve done the one for memory_0.
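Here is one way to sketch these rules as code: a hypothetical `Memory` class whose `write_read` method is the only way to touch the blocks, and which raises an error if a call skips a block, repeats one, or returns early. The class and method names are made up for illustration; they aren’t taken from any library.

```python
class Memory:
    """Program state made of ordered blocks, each holding one cell per position."""

    def __init__(self, num_blocks):
        self.blocks = [{} for _ in range(num_blocks)]

    def start_call(self, i):
        """Begin one invocation f(..., i)."""
        return Call(self, i)


class Call:
    """Tracks which block a single invocation f(..., i) may touch next."""

    def __init__(self, memory, i):
        self.memory = memory
        self.i = i
        self.next_block = 0

    def write_read(self, block, value):
        """Atomic write+read: write to cell i of `block`, then read cells 0..i."""
        if block != self.next_block:
            # Rule: exactly one call per block, in order.
            raise RuntimeError(f"expected memory_{self.next_block}, got memory_{block}")
        self.next_block += 1
        cells = self.memory.blocks[block]
        cells[self.i] = value   # the write goes to cell i only
        # The read covers cells 0..i; here we assume earlier calls already filled 0..i-1.
        return [cells[j] for j in range(self.i + 1)]

    def finish(self):
        """Check that every block was visited before f returns."""
        if self.next_block != len(self.memory.blocks):
            raise RuntimeError("f must call write_read once for every block")


mem = Memory(num_blocks=2)
first = mem.start_call(0)
first.write_read(0, "a")
first.write_read(1, "b")
first.finish()

second = mem.start_call(1)
print(second.write_read(0, "c"))   # ['a', 'c']: cells 0 and 1 of memory_0
```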
Here’s some pseudo-code, showing a generic computation of this kind:
f(x, i) {
    calculate some things using x and i;
    // next 2 lines are a single instruction
    write to memory_0 at position i;
    z0 = read from memory_0 at positions 0…i;
    calculate some things using x, i, and z0;
    // next 2 lines are a single instruction
    write to memory_1 at position i;
    z1 = read from memory_1 at positions 0…i;
    calculate some things using x, i, z0, and z1;
    // …and so on, once for each remaining block…
    return the predicted next word;
}
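For anyone who wants something executable, here is a toy Python translation of that pseudocode. Everything concrete in it is a hypothetical stand-in: words become integers via a tiny lookup table, “calculate some things” is plain integer arithmetic, and a read returns the sum of the visible cells (in a real transformer the read is attention over those cells, not a sum). It only shows the shape of the computation, nothing a trained model actually does.

```python
NUM_BLOCKS = 3                                   # arbitrary, for illustration
VOCAB = {"Fido": 1, "is": 2, "a": 3, "dog": 4}   # hypothetical word -> number table

def write_read(memory, block, i, value):
    """Atomic write+read: write `value` into cell i of `block`, then read cells 0..i."""
    cells = memory[block]
    assert len(cells) == i, "in this sketch, cells are filled in position order"
    cells.append(value)
    return sum(cells[: i + 1])   # toy stand-in for attention over cells 0..i

def f(x, i, memory):
    h = VOCAB[x] * 10 + i                  # calculate some things using x and i
    for block in range(NUM_BLOCKS):        # one write+read per block, in order
        z = write_read(memory, block, i, h * (block + 2))
        h = h + z                          # calculate some things using everything so far
    return h                               # toy stand-in for "the predicted next word"

memory = [[] for _ in range(NUM_BLOCKS)]
outputs = [f(x, i, memory) for i, x in enumerate(["Fido", "is", "a"])]
print([len(cells) for cells in memory])    # [3, 3, 3]: cells 0-2 of every block
```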
The rules impose a tradeoff between the amount of processing required to produce a value, and how early the value can be accessed within the function body.
Consider the moment when data is written to memory_0. This happens before anything is read (even from memory_0 itself).
So the data in memory_0 has been computed only on the basis of individual inputs like (“a”, 2). It can’t leverage any information about multiple words and how they relate to one another.
But just after the write to memory_0, there’s a read from memory_0. This read pulls in data computed by f when it ran on all the earlier words in the sequence.
If we’re processing (“a”, 2) in our example, then this is the point where our code is first able to access facts like “the word ‘Fido’ appeared earlier in the text.”
However, we still know less than we might prefer.
Recall that memory_0 gets written before anything gets read. The data living there only reflects what f knows before it can see all the other words, while it still only has access to the one word that appeared in its input.
The data we’ve just read does not contain a holistic, “fully processed” representation of the whole sequence so far (“Fido is a”). Instead, it contains:
- a representation of (“Fido”, 0) alone, computed in ignorance of the rest of the text
- a representation of (“is”, 1) alone, computed in ignorance of the rest of the text
- a representation of (“a”, 2) alone, computed in ignorance of the rest of the text
Now, once we get to memory_1, we will no longer face this problem. Stuff in memory_1 gets computed with the benefit of whatever was in memory_0. The step that computes it can “see all the words at once.”
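A toy model makes this checkable. In the sketch below, all names and arithmetic are hypothetical stand-ins, and a read simply sums the visible cells. Changing the first word from “Fido” to a made-up “Rex” leaves the memory_0 cell for (“a”, 2) untouched, while the memory_1 cell for that same position changes.

```python
NUM_BLOCKS = 2
VOCAB = {"Fido": 1, "Rex": 5, "is": 2, "a": 3}   # hypothetical word -> number table

def run(words):
    """Run a toy f on each position in turn and return the filled memory blocks."""
    memory = [[] for _ in range(NUM_BLOCKS)]
    for i, x in enumerate(words):
        h = VOCAB[x] * 10 + i                       # computed from (x, i) alone
        for block in range(NUM_BLOCKS):
            memory[block].append(h * (block + 2))   # write cell i
            h = h + sum(memory[block][: i + 1])     # read cells 0..i
    return memory

a = run(["Fido", "is", "a"])
b = run(["Rex", "is", "a"])
print(a[0][2], b[0][2])   # 64 64: this memory_0 cell was written without seeing the first word
print(a[1][2], b[1][2])   # 474 714: this memory_1 cell was written after reading memory_0
```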
Nonetheless, the whole function is affected by a generalized version of the same quirk.
All else being equal, data stored in later blocks ought to be more useful. Suppose for instance that
- memory_4 gets read/written 20% of the way through the function body, and
- memory_16 gets read/written 80% of the way through the function body
Here, strictly more computation can be leveraged to produce the data in memory_16. Calculations which are simple enough to fit in the program, but too complex to fit in just 20% of the program, can be stored in memory_16 but not in memory_4.
All else being equal, then, we’d prefer to read from memory_16 rather than memory_4 if possible.
But in fact, we can only read from memory_16 once – at a point 80% of the way through the code, when the read/write happens for that block.
The general picture looks like:
- The early parts of the function can see and leverage what got computed earlier in the sequence – by the same early parts of the function. This data is relatively “weak,” since not much computation went into it. But, by the same token, we have plenty of time to further process it.
- The late parts of the function can see and leverage what got computed earlier in the sequence – by the same late parts of the function. This data is relatively “strong,” since lots of computation went into it. But, by the same token, we don’t have much time left to further process it.
3. why?
There are multiple ways you can “run” the program specified by f.
Here’s one way, which is used when generating text, and which matches popular intuitions about how language models work:
- First, we run f(“Fido”, 0) from start to end. The function returns “is.” As a side effect, it populates cell 0 of every memory block.
- Next, we run f(“is”, 1) from start to end. The function returns “a.” As a side effect, it populates cell 1 of every memory block.
- Etc.
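A minimal sketch of this sequential mode, using a hypothetical toy f (integer arithmetic standing in for the real computation, a sum standing in for the read). During actual generation, each call’s output would become the next call’s input; to keep things short, this sketch feeds in a fixed text instead.

```python
NUM_BLOCKS = 3
VOCAB = {"Fido": 1, "is": 2, "a": 3, "dog": 4}   # hypothetical word -> number table

def f(x, i, memory):
    """Toy f: one write+read per block, in order; returns a number standing in for a word."""
    h = VOCAB[x] * 10 + i
    for block in range(NUM_BLOCKS):
        memory[block].append(h * (block + 2))   # write cell i of this block
        h = h + sum(memory[block][: i + 1])     # read cells 0..i of this block
    return h

def run_sequential(words):
    """Run f on one position at a time, each call from start to end."""
    memory = [[] for _ in range(NUM_BLOCKS)]
    outputs = []
    for i, x in enumerate(words):
        outputs.append(f(x, i, memory))   # side effect: fills cell i of every block
        print(f"after f({x!r}, {i}): cells per block = {[len(c) for c in memory]}")
    return outputs, memory

outputs, memory = run_sequential(["Fido", "is", "a"])
```

Each call adds exactly one cell to every block, which is the side effect described in the list above.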
If we’re running the code like this, the constraints described earlier feel weird and pointlessly restrictive.
By the time we’re running f(“is”, 1), we’ve already populated some data into every memory block, all the way up to memory_16 or whatever.
This data is already there, and contains lots of useful insights.
And yet, during the function call f(“is”, 1), we “forget about” this data – only to progressively remember it again, block by block. The early parts of this call have only memory_0 to play with, and then memory_1, etc. Only at the end do we allow access to the juicy, extensively processed results that occupy the final blocks.
Why? Why not just let this call read memory_16 immediately, on the first line of code? The data is sitting there, ready to be used!
Because the constraint enables a second way of running this program.
The second way is equivalent to the first, in the sense of producing the same outputs. But instead of processing one word at a time, it processes a whole sequence of words, in parallel.
Here’s how it works:
- In parallel, run f(“Fido”, 0) and f(“is”, 1) and f(“a”, 2), up until the first write+read instruction. You can do this because the functions are causally independent of one another, up to this point. We now have 3 copies of f, each at the same “line of code”: the first write+read instruction.
- Perform the write part of the instruction for all the copies, in parallel. This populates cells 0, 1 and 2 of memory_0.
- Perform the read part of the instruction for all the copies, in parallel. Each copy of f receives some of the data just written to memory_0, covering offsets up to its own. For instance, f(“is”, 1) gets data from cells 0 and 1.
- In parallel, continue running the 3 copies of f, covering the code between the first write+read instruction and the second.
- Perform the second write. This populates cells 0, 1 and 2 of memory_1.
- Perform the second read.
- Repeat like this until done.
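Here is a sketch of both modes side by side, using the same kind of hypothetical toy f (integer arithmetic for the computation, a sum for the read). Python list comprehensions don’t actually run in parallel; each comprehension just marks one step that a GPU could perform for all positions at once. The final check confirms that, for this toy, the two modes produce identical outputs and identical memory contents.

```python
NUM_BLOCKS = 3
VOCAB = {"Fido": 1, "is": 2, "a": 3, "dog": 4}   # hypothetical word -> number table

def run_sequential(words):
    """Mode 1: run the whole toy f for position 0, then position 1, and so on."""
    memory = [[] for _ in range(NUM_BLOCKS)]
    outputs = []
    for i, x in enumerate(words):
        h = VOCAB[x] * 10 + i
        for block in range(NUM_BLOCKS):
            memory[block].append(h * (block + 2))   # write cell i
            h = h + sum(memory[block][: i + 1])     # read cells 0..i
        outputs.append(h)
    return outputs, memory

def run_parallel(words):
    """Mode 2: advance every position's copy of f together, one block at a time."""
    n = len(words)
    hs = [VOCAB[x] * 10 + i for i, x in enumerate(words)]   # code before the first write
    memory = []
    for block in range(NUM_BLOCKS):
        cells = [h * (block + 2) for h in hs]               # all writes at once
        memory.append(cells)
        zs = [sum(cells[: i + 1]) for i in range(n)]        # all reads at once
        hs = [h + z for h, z in zip(hs, zs)]                # code up to the next write
    return hs, memory

words = ["Fido", "is", "a", "dog"]
assert run_sequential(words) == run_parallel(words)
print("same outputs and memory in both modes")
```

Note where the loops sit: the sequential version loops over positions on the outside, while the parallel version only loops over blocks.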
Observe that this mode of operation only works if you have a complete input sequence ready before you run anything.
(You can’t parallelize over later positions in the sequence if you don’t know, yet, what words they contain.)
So, this won’t work when the model is generating text, word by word.
But it will work if you have a bunch of texts, and you want to process those texts with the model, for the sake of updating the model so it does a better job of predicting them.
This is called “training,” and it’s how neural nets get made in the first place. In our programming analogy, it’s how the code inside the function body gets written.
The fact that we can train in parallel over the sequence is a huge deal, and probably accounts for most (or even all) of the benefit that transformers have over earlier architectures like RNNs.
Accelerators like GPUs are really good at doing the kinds of calculations that happen inside neural nets, in parallel.
So if you can make your training process more parallel, you can effectively multiply the computing power available to it, for free. (I’m omitting many caveats here – see this great post for details.)
Transformer training isn’t maximally parallel. It’s still sequential in one “dimension,” namely the layers, which correspond to our write+read steps here. You can’t parallelize those.
But it is, at least, parallel along some dimension, namely the sequence dimension.
The older RNN architecture, by contrast, was inherently sequential along both these dimensions. Training an RNN is, effectively, a nested for loop. But training a transformer is just a regular, single for loop.
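The two loop shapes can be sketched by counting sequential steps, under the simplifying assumption that one batched step over all positions costs the same as one step. This follows the straightforward loop form described here; the function names and sizes are illustrative.

```python
def rnn_steps(seq_len, num_layers):
    """Nested loop: each (position, layer) update runs one after another."""
    steps = 0
    for t in range(seq_len):
        for layer in range(num_layers):
            steps += 1
    return steps

def transformer_steps(seq_len, num_layers):
    """Single loop over layers; seq_len doesn't change the count."""
    steps = 0
    for layer in range(num_layers):
        steps += 1   # one batched step covers positions 0..seq_len-1
    return steps

print(rnn_steps(1024, 12), transformer_steps(1024, 12))   # 12288 12
```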
4. tying it together
The “magical” thing about this setup is that both ways of running the model do the same thing. You are, literally, doing the same exact computation. The function can’t tell whether it is being run one way or the other.
This is crucial, because we want the training process – which uses the parallel mode – to teach the model how to perform generation, which uses the sequential mode. Since both modes look the same from the model’s perspective, this works.
This constraint – that the code can run in parallel over the sequence, and that this must do the same thing as running it sequentially – is the reason for everything else we noted above.
Earlier, we asked: why can’t we allow later (in the sequence) invocations of f to read earlier data out of blocks like memory_16 immediately, on “the first line of code”?
And the answer is: because that would break parallelism. You’d have to run f(“Fido”, 0) all the way through before even starting to run f(“is”, 1).
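One way to see this is to count the longest chain of dependencies. The sketch below, a simplified model with hypothetical names, computes how many sequential steps must pass before every cell can be written. It does this twice: once under the real rule, where writing block k at position i only waits on reads of block k-1 at positions 0…i, and once under a hypothetical rule where each call reads the last block at all earlier positions on its first line.

```python
from functools import lru_cache

def critical_path(seq_len, num_blocks, read_last_block_first):
    """Longest chain of sequential steps needed to write every cell."""

    @lru_cache(maxsize=None)
    def depth(block, pos):
        deps = []
        if block > 0:
            # Real rule: before writing this block, f read the previous block at 0..pos.
            deps += [depth(block - 1, j) for j in range(pos + 1)]
        if read_last_block_first and pos > 0:
            # Hypothetical rule: on its first line, f reads the last block at 0..pos-1,
            # so every write this call makes waits on those cells.
            deps += [depth(num_blocks - 1, j) for j in range(pos)]
        return 1 + max(deps, default=0)

    return max(depth(b, p) for b in range(num_blocks) for p in range(seq_len))

print(critical_path(4, 3, False), critical_path(4, 3, True))   # 3 12
```

Under the real rule, the chain length equals the number of blocks no matter how long the sequence is. Under the hypothetical rule, it grows to seq_len × num_blocks, which means running f(“Fido”, 0) to completion before f(“is”, 1) can start.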
By structuring the computation in this specific way, we provide the model with the benefits of recurrence – writing things down at earlier positions, accessing them at later positions, and writing further things down which can be accessed even later – while breaking the sequential dependencies that would ordinarily prevent a recurrent calculation from being executed in parallel.
In other words, we’ve found a way to create an iterative function that takes its own outputs as input – and does so repeatedly, producing longer and longer outputs to be read off by its next invocation – with the property that this iteration can be run in parallel.
We can run the first 10% of every iteration – of f() and f(f()) and f(f(f())) and so on – at the same time, before we know what will happen in the later stages of any iteration.
The call f(f()) uses all the information handed to it by f() – eventually. But it cannot make any requests for information that would leave itself idling, waiting for f() to fully complete.
Whenever f(f()) needs a value computed by f(), it is always the value that f() – running alongside f(f()), simultaneously – has just written down, a mere moment ago.
No dead time, no idling, no waiting-for-the-other-guy-to-finish.
The “memory blocks” here correspond to what are called “keys and values” in usual transformer lingo.
If you’ve heard the term “KV cache,” it refers to the contents of the memory blocks during generation, when we’re running in “sequential mode.”
Usually, during generation, one keeps this state in memory and appends a new cell to each block whenever a new token is generated (and, as a result, the sequence gets longer by 1).
This is called “caching” to contrast it with the worse approach of throwing away the block contents after each generated token, and then re-generating them by running f on the whole sequence so far (not just the latest token). And then having to do that over and over, once per generated token.
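A minimal sketch of the difference, using a hypothetical toy f (a four-word vocabulary, integer arithmetic, and a sum as the read). Both generators produce the same text. The cached one keeps the memory blocks between tokens and runs f once per new position; the uncached one rebuilds the blocks from scratch for every generated token. Both assume a non-empty prompt and n_new of at least 1.

```python
NUM_BLOCKS = 3
WORDS = ["Fido", "is", "a", "dog"]   # hypothetical tiny vocabulary

def f(x, i, memory):
    """Toy f: appends cell i to every block, returns a toy next-word prediction."""
    h = WORDS.index(x) * 10 + i
    for block in range(NUM_BLOCKS):
        memory[block].append(h * (block + 2))   # write cell i
        h = h + sum(memory[block][: i + 1])     # read cells 0..i
    return WORDS[h % len(WORDS)]

def generate_with_cache(prompt, n_new):
    """Keep the memory blocks (the "KV cache") and extend them by one cell per token."""
    memory = [[] for _ in range(NUM_BLOCKS)]
    text, calls = list(prompt), 0
    for i, x in enumerate(text):              # process the prompt once
        nxt = f(x, i, memory)
        calls += 1
    for _ in range(n_new - 1):
        text.append(nxt)
        nxt = f(nxt, len(text) - 1, memory)   # only the newest position is computed
        calls += 1
    text.append(nxt)
    return text, calls

def generate_without_cache(prompt, n_new):
    """Throw the blocks away after each token and recompute them for the whole sequence."""
    text, calls = list(prompt), 0
    for _ in range(n_new):
        memory = [[] for _ in range(NUM_BLOCKS)]
        for i, x in enumerate(text):
            nxt = f(x, i, memory)
            calls += 1
        text.append(nxt)
    return text, calls

cached = generate_with_cache(["Fido", "is"], 3)
uncached = generate_without_cache(["Fido", "is"], 3)
print(cached[0] == uncached[0], cached[1], uncached[1])   # True 4 9
```

The cached version makes len(prompt) + n_new - 1 calls to f, while the uncached version’s call count is the sum of the sequence lengths at each step, so it grows roughly quadratically with the length of the text.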
nostalgebraist posted this