A NEW PAELLA: SIMPLE & EFFICIENT TEXT-TO-IMAGE GENERATION
by: Dominic Rampas and Pablo Pernias, 15 Apr, 2023
Overview.
We are releasing a new Paella model that builds on top of our initial paper: https://arxiv.org/abs/2211.07292. Paella is a text-to-image model that works in a quantized latent space and learns similarly to MUSE and diffusion models. Paella is similar to MUSE in that it also works on discrete tokens, but differs in the way tokens are noised as well as in the architecture: MUSE uses a transformer, whereas we use a CNN, which comes with many benefits. There are also subtle differences in the conditioning Paella uses and in how images are sampled. On the other hand, Paella can also be seen as a discrete diffusion process, which noises images during training and iteratively removes noise during sampling.
Since the paper release we have worked intensively to bring Paella to a level comparable to other state-of-the-art models, and with this release we are coming a step closer to that goal. However, our main intention is not to build the greatest text-to-image model out there (at least for now); it is to make text-to-image models technically accessible to people outside the field. For example, many models have codebases with many thousands of lines of code, which makes it hard for people to dive into the code and understand it easily. That is our proudest achievement with Paella: the training and sampling code is minimalistic and can be understood in a few minutes, making further extensions, quick tests, and idea testing extremely fast. For instance, the entire sampling code can be written in just 12 lines of code. In this blog post we will briefly explain how Paella works, give technical details, and release the model.
How does Paella work?
Paella works in a quantized latent space, just like Stable Diffusion and others, to reduce the computational power needed.
Images are encoded into a smaller latent space and converted to visual tokens of shape h x w. During training,
these visual tokens are noised by replacing a random fraction of them with other randomly selected tokens
from the codebook of the VQGAN. The noised tokens are given to the model, along with a timestep and the conditioning
information, which is text in our case. The model is tasked with predicting the un-noised version of the tokens.
And that's it. The model is optimized with the cross-entropy loss between the original tokens and the predicted tokens.
The amount of noise added during training follows a simple linear schedule: we uniformly sample a percentage
between 0% and 100% and noise that fraction of the tokens.
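To make this concrete, here is a minimal sketch of a single training step as described above. The vqgan.encode call, the train_step helper, and the exact model signature are illustrative assumptions for this post, not code taken verbatim from our codebase:

import torch
import torch.nn.functional as F

def train_step(model, vqgan, images, model_inputs):
    # Encode images into the quantized latent space: a (B, h, w) grid of codebook indices.
    with torch.no_grad():
        tokens = vqgan.encode(images)

    # Linear noise schedule: sample a noise level t uniformly in [0, 1] per image.
    t = torch.rand(tokens.size(0), device=tokens.device)

    # Replace a t-fraction of the tokens with random tokens from the codebook.
    mask = torch.rand(tokens.shape, device=tokens.device) < t[:, None, None]
    random_tokens = torch.randint_like(tokens, model.num_labels)
    noised_tokens = torch.where(mask, random_tokens, tokens)

    # Predict the un-noised tokens from the noised tokens, the timestep and the text condition.
    logits = model(noised_tokens, t, **model_inputs)  # (B, num_labels, h, w)

    # Cross-entropy between the predicted token distributions and the original tokens.
    return F.cross_entropy(logits, tokens)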
Sampling is also extremely simple: we start with the entire image being random tokens. Then we feed the latent image,
the timestep, and the condition into the model and let it predict the final image. The model outputs a distribution
over every token, which we sample from with standard multinomial sampling.
Since there are infinitely many ways the result could look, a single step only produces very basic
shapes without any details. That is why we add noise to the image again and feed it back to the model. We repeat
this process a number of times, adding less noise each time, and slowly arrive at our final image.
You can see how images emerge here.
The following is the entire sampling code needed to generate images:
def sample(model_inputs, latent_shape, unconditional_inputs, steps=12, renoise_steps=11, temperature=(0.7, 0.3), cfg=8.0):
    with torch.inference_mode():
        # Start from a latent image made entirely of random tokens.
        sampled = torch.randint(low=0, high=model.num_labels, size=latent_shape)
        initial_noise = sampled.clone()
        timesteps = torch.linspace(1.0, 0.0, steps + 1)
        temperatures = torch.linspace(temperature[0], temperature[1], steps)
        for i, t in enumerate(timesteps[:steps]):
            t = torch.ones(latent_shape[0]) * t
            # Predict a distribution over the un-noised tokens.
            logits = model(sampled, t, **model_inputs)
            if cfg:
                # Classifier-free guidance: push the conditional prediction away from the unconditional one.
                logits = logits * cfg + model(sampled, t, **unconditional_inputs) * (1 - cfg)
            # Temperature-scaled multinomial sampling over the token distribution at every position.
            sampled = logits.div(temperatures[i]).softmax(dim=1).permute(0, 2, 3, 1).reshape(-1, logits.size(1))
            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
            if i < renoise_steps:
                # Re-noise the prediction with decreasing noise and feed it back in the next step.
                t_next = torch.ones(latent_shape[0]) * timesteps[i + 1]
                sampled = model.add_noise(sampled, t_next, random_x=initial_noise)[0]
    return sampled
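As a rough example of how this function might be called, where text_encoder, vqgan, the clip_embeddings key and the 32 x 32 latent grid are illustrative placeholders rather than the actual names from the released code:

# Hypothetical usage sketch; names and sizes below are placeholders.
model_inputs = {"clip_embeddings": text_encoder(["a painting of a paella dish"])}
unconditional_inputs = {"clip_embeddings": text_encoder([""])}  # empty prompt for CFG

latent_shape = (1, 32, 32)  # batch x h x w grid of visual tokens
tokens = sample(model_inputs, latent_shape, unconditional_inputs, steps=12, cfg=8.0)
images = vqgan.decode(tokens)  # map the sampled tokens back to pixel space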