Anifusion: diffusion models for anime pictures

Generations of portraits of renowned characters (in order, cherry-picked): Ash Ketchum, Link, Tohsaka Rin, Saber, Naruto, Hatsune Miku

Intro

Continuing the topic of image processing in Machine Learning, this post covers an experiment in building a diffusion-based image generation model from scratch, trained on commodity hardware.

There has been an explosion of diffusion-based image models recently, including the proprietary DALL-E 2 (OpenAI, available as a paid service), Imagen and Parti (by Google, not available to the general public at all), and most recently the publicly available Stable Diffusion (slightly inferior to its proprietary counterparts, but the checkpoint is publicly available). They are all focused on generating arbitrary images from text prompts, with LAION being the standard training dataset.

The downside of these models is the cost of training: it takes an enormous number of GPU hours, amounting to millions of dollars for a single model. This is not an issue when fine-tuning existing models, but if someone wants to, say, train a model with a new architecture, it is practically impossible unless they are very rich or work at a big tech company.

I decided to do this project as an attempt to see whether restricting the image space lets us train models on commodity hardware (in my case, a single RTX 3090). I chose anime-style drawings as the target dataset, hoping that they are simpler to generate than arbitrary internet images. I started this work before Stable Diffusion was released (otherwise, from the standpoint of sample quality, fine-tuning it might be a better idea; I might try this in the future, but this work is focused on training the model entirely from scratch).

Dataset

There is a large-scale dataset of anime drawings available: Danbooru2021 by Gwern Branwen. It is a very good fit for this project; however, there are two catches:

  • Instead of textual descriptions, images come with sets of tags. These sets are rather descriptive and contain a lot of tags, but it means we cannot reuse existing text encoders (like CLIP in Stable Diffusion).
  • A big fraction of the dataset is NSFW. Moreover, some of this NSFW content is very … fringe, and can be upsetting to many people. Luckily, the images come with an "explicitness" rating, which we can feed as an input to the conditioning model to restrict the generation of NSFW images. Another approach would be to filter such images out of training altogether (as DALL-E/Imagen try to do), but I believe this approach is flawed: making models blind to something is strictly worse than making them aware of it, with a handle to control it, as in this case.

The original description offers more insight into the dataset structure, so we will not repeat it here. Instead, let's do some basic statistical exploration before moving on to model details.

First of all, the dataset contains around 5M images, and given its enormous size (tiny, I understand, compared to LAION), we take a 1.5M subsample, stratifying by score to get a higher share of high-quality images.
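
To make this concrete, here is a minimal sketch of how such a score-stratified subsample could be drawn with pandas. The file name, column names and the exact bucketing/weighting scheme are illustrative assumptions, not the actual pipeline.

```python
import pandas as pd

# Hypothetical metadata frame, one row per image; the file name and column
# names ("id", "score", ...) are assumptions, not the actual dataset layout.
meta = pd.read_parquet("danbooru_metadata.parquet")

TARGET = 1_500_000
N_BUCKETS = 10

# Split images into score deciles; higher deciles get a larger share of the
# subsample, so the result skews towards higher-scored images.
meta["bucket"] = pd.qcut(meta["score"], q=N_BUCKETS, labels=False, duplicates="drop")
share = {b: (b + 1) / sum(range(1, N_BUCKETS + 1)) for b in range(N_BUCKETS)}

parts = []
for bucket, group in meta.groupby("bucket"):
    n = min(len(group), int(TARGET * share[bucket]))
    parts.append(group.sample(n=n, random_state=0))
subsample = pd.concat(parts)
```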

Next, we resize the images to the (512, 512, 3) shape, padding with black where necessary. Note that this is different from what models like Stable Diffusion do: they resize but don't pad. As a result, we get 300GB of resized images.
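
A minimal sketch of this preprocessing step with Pillow (the function name is ours; the actual pipeline may differ):

```python
from PIL import Image, ImageOps

def resize_and_pad(path: str, size: int = 512) -> Image.Image:
    """Resize so the longer side becomes `size` (keeping the aspect ratio),
    then pad the shorter side with black to get an exact (size, size, 3) image."""
    img = Image.open(path).convert("RGB")
    # ImageOps.pad fits the image inside (size, size) and fills the rest with `color`
    return ImageOps.pad(img, (size, size), method=Image.LANCZOS, color=(0, 0, 0))
```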

Before training on them, let's take a look at the metadata. There are around 250K distinct tags; we take the top 2500 by frequency for conditioning (looking back, it might have been a good idea to use more). Here are the top ones:

There are three NSFW ratings, "explicit", "questionable" and "safe"; we'll also use these for conditioning.

Many images have the “source” URL annotated. We won’t use it later, but it is interesting to see where these anime images are coming from. Here are the top sources:

Let's look at the dynamics over time:

Fraction of total images coming from a domain, by time

Note the pixiv/pximg switch: it is the same source, they just changed their domain structure. Still, we see Pixiv's share steadily declining in favor of Twitter (!). Let's slice it a bit more and look only at NSFW images:

Fraction of NSFW images coming from a domain, by time

OK, this is a bit surprising to me: Twitter doesn't strike me as an image hosting platform, let alone an NSFW image hosting platform (given that it is mainstream-ish), but that is what the data says.

Anyhow, let's move on. The final piece of metadata we'll need is the score: the balance of upvotes and downvotes from actual human users. It is distributed roughly log-normally, with a strange tail in the negative range. We could use it as a proxy for image quality; however, there are two issues with this:

  • It correlates with the age of the image (r=0.4), and not in the direction one might expect: older images have lower scores! This might mean that the number of users up-/down-voting images increases over time, and these users primarily look at new images.
  • It strongly correlates with tags/topics. E.g. the average score of images tagged 2girls is 2x higher than that of images tagged 2boys.

To mitigate these issues, we come up with a notion of "adjusted score": we train a prior model F(timestamp, tags) ~ score (it has a 0.64 correlation with the actual score) and take adjusted_score = score - F(timestamp, tags).

Note that this approach has its downsides: not all tags are purely topical, e.g. it is natural for tags like "highres" to correlate with image quality, so this "demographic parity" enforcement is not perfect. However, it gives us an additional way to control model outputs later during sampling, without introducing topic biases when one wants to sample high-score images.

This concludes the dataset part: we have 1.5M images with tag annotations, an explicitness annotation and two score annotations (to feed them to the transformer, we discretize the scores, replacing each with its rounded percentile).
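
Below is a rough sketch of how the adjusted score and the percentile discretization could be computed, with a simple linear regressor standing in for the prior F(timestamp, tags); the column names and the choice of regressor are assumptions:

```python
from scipy.sparse import csr_matrix, hstack
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import Ridge

# `meta` is the metadata frame from before; "tags" as a space-separated string,
# "created_at" as a unix timestamp and "score" as the raw score are assumptions.
tag_features = CountVectorizer(analyzer=str.split, binary=True).fit_transform(meta["tags"])
time_feature = csr_matrix(meta["created_at"].to_numpy(dtype=float).reshape(-1, 1))
X = hstack([tag_features, time_feature])

# Prior model F(timestamp, tags) ~ score; the residual is the "adjusted score".
prior = Ridge(alpha=1.0).fit(X, meta["score"])
meta["adjusted_score"] = meta["score"] - prior.predict(X)

# For conditioning, both scores are discretized into rounded percentiles.
for col in ["score", "adjusted_score"]:
    meta[col + "_pct"] = meta[col].rank(pct=True).mul(100).round().astype(int)
```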

Model

We aren't going to go into details of how diffusion models work: there are plenty of good tutorials about this already. At a high level, the structure of the model is similar to Stable Diffusion, so for readers unfamiliar with it, it might be a good idea to read something before diving into this section. Some pointers: the wiki article, a tutorial with lots of pictures, the huggingface tutorial; for a more formal description, one can read the Latent Diffusion paper.

We’ll primarily focus on how this work is different from Stable Diffusion. Let’s go right into technical details.

We're using the Latent Diffusion architecture and codebase. The model consists of an encoder+decoder that transforms images into a smaller latent space, and the actual conditional diffusion model. For the encoder/decoder, we took one of the pretrained VQGAN models and fine-tuned it (both encoder and decoder) on the anime images. There are two differences from Stable Diffusion here:

  • We use a VQ-VAE, unlike the "normal" VAE with KL-based regularization in Stable Diffusion. I don't know why the KL variant was chosen for SD, as the original Latent Diffusion paper seems to imply that the VQ version might have higher sample quality:

Interestingly, we find that LDMs trained in VQ-regularized latent spaces sometimes achieve better sample quality, even though the reconstruction capabilities of VQ-regularized first stage models slightly fall behind those of their continuous counterparts.

  • I started with 4x encoder compression (256 -> 64); however, when Stable Diffusion came out and produced awesome results with higher compression (512 -> 64), we felt embarrassed to release a model which generates smaller images, and fine-tuned the encoder/decoder, prepending/appending additional blocks to increase the compression factor (see the sketch below).
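
For illustration, here is a minimal sketch of what "prepending/appending additional blocks" could look like in PyTorch; the post doesn't specify the architecture of the added blocks, so the layer choices below are assumptions:

```python
import torch
import torch.nn as nn

class ExtraCompression(nn.Module):
    """Wraps a pretrained f=4 encoder/decoder pair into an f=8 one by adding one more
    down-/up-sampling stage. `base_encoder`/`base_decoder` stand for the pretrained
    VQGAN modules (hypothetical names); the new blocks are trained from scratch,
    the rest is fine-tuned."""

    def __init__(self, base_encoder: nn.Module, base_decoder: nn.Module, channels: int = 3):
        super().__init__()
        # Prepended block: stride-2 conv halves the resolution before the original encoder.
        self.pre = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, channels, kernel_size=3, padding=1),
        )
        self.base_encoder = base_encoder
        self.base_decoder = base_decoder
        # Appended block: learned 2x upsampling after the original decoder.
        self.post = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(channels, 64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, channels, kernel_size=3, padding=1),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:   # (B, 3, 512, 512) -> f=8 latent
        return self.base_encoder(self.pre(x))

    def decode(self, z: torch.Tensor) -> torch.Tensor:   # latent -> (B, 3, 512, 512)
        return self.post(self.base_decoder(z))
```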

Next goes the actual conditional diffusion model. We followed the standard approach, except that we had to replace the conditioner. Since we have tags rather than text, we cannot use conditioners like CLIP or pre-trained language models. So we train a BERT-Base-sized encoder transformer (12 layers, 768 embedding dimensions) which has no positional embeddings (tags are sets, not sequences, so their order doesn't matter; zeroing out positional embeddings enforces this invariance).
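
A sketch of such a tag-set conditioner in PyTorch (vocabulary size and padding conventions are assumptions): the key point is that no positional embeddings are added, so the output depends only on which tags are present, not on their order.

```python
import torch
import torch.nn as nn

class TagSetEncoder(nn.Module):
    """BERT-Base-sized encoder over a *set* of tags: 12 layers, 768-dim embeddings,
    and no positional embeddings, so the output is invariant to tag order."""

    def __init__(self, vocab_size: int = 2504, dim: int = 768, layers: int = 12, heads: int = 12):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim, padding_idx=0)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=4 * dim,
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=layers)

    def forward(self, tag_ids: torch.Tensor) -> torch.Tensor:
        # tag_ids: (batch, n_tags) integer ids; 0 marks padding.
        mask = tag_ids.eq(0)                      # True where padded
        x = self.embed(tag_ids)                   # no positional embeddings added
        return self.encoder(x, src_key_padding_mask=mask)  # (batch, n_tags, dim) conditioning
```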

We also have to do some data augmentation via dropout on the conditioning. To produce high-quality samples, diffusion models typically require classifier-free guidance, which means the model should be able to run in an "unconditional" mode. We follow the standard practice and remove all tags on 10% of training examples. Additionally, we remove some tags on another 10% of examples, and most of the tags on yet another 10%: the goal was to make the model work on short out-of-distribution prompts like "1girl, solo" (real tag sets have many more tags). It didn't help that much though, so additional techniques had to be employed at inference time.
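
A sketch of this condition-dropout augmentation; the per-tag drop rates for the "some" and "most" cases are assumptions, since only the example-level percentages are given above:

```python
import random

def drop_tags_for_cfg(tags: list[str]) -> list[str]:
    """Condition dropout used during training (a sketch; thresholds match the text):
    10% of examples lose all tags (unconditional mode for classifier-free guidance),
    10% lose some tags, another 10% lose most of them."""
    r = random.random()
    if r < 0.10:                                   # fully unconditional
        return []
    if r < 0.20:                                   # drop some tags (~30% each, assumed rate)
        return [t for t in tags if random.random() > 0.3]
    if r < 0.30:                                   # keep only a few tags (~80% dropped, assumed rate)
        return [t for t in tags if random.random() > 0.8]
    return tags                                    # the remaining 70% keep the full tag set
```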

Besides these changes, the training process is standard for diffusion models. The model itself is roughly 2x smaller than SD (we chose the hyperparams before SD was published), making it runnable with less VRAM. We trained the model for a total of 40 GPU-days (on an RTX 3090), making it roughly 200 times cheaper than Stable Diffusion. However, sample quality after 7 GPU-days was already decent.

Sampling

Here comes the biggest divergence from the “normal” diffusion models. The conditioning in our model is based on a set of tags, which is typically rather large. We cannot expect users to enter 30 tags for a prompt, so some changes are necessary.

During training, we tried to add dropout to make short prompts plausible, but the results were not satisfactory (below are some examples).

Samples for “1girl, solo” without augmentation (no cherrypicking)

So we had to come up with something to mitigate it. This “something” is yet another model — the one which “augments” the user prompt by adding more tags. One can think of it in the following way: the model has to come up with a bunch of details for such short prompts, and doing it during the diffusion process is hard (for the same reason unconditionally generated images are much worse than conditional ones). So we do it before the actual diffusion, generating a set of consistent concepts with an additional model.

This additional model is a GPT-like autoregressive transformer (decoder-only). For simplicity, we train it the usual way, ignoring the fact that tags are unordered (and instead randomly shuffling them during training as data augmentation).
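
A sketch of how a single training example for this augmenter could be built; the special tokens and maximum length are assumptions:

```python
import random
import torch

def make_augmenter_example(tags: list[str], tag_to_id: dict[str, int],
                           bos_id: int, eos_id: int, max_len: int = 64) -> torch.Tensor:
    """One training example for the GPT-like tag augmenter. Tags are an unordered set,
    so we shuffle them on every pass as data augmentation and train the usual
    next-token objective on the resulting sequence."""
    tags = tags[:]                 # copy before shuffling in place
    random.shuffle(tags)
    ids = [bos_id] + [tag_to_id[t] for t in tags if t in tag_to_id][: max_len - 2] + [eos_id]
    return torch.tensor(ids, dtype=torch.long)
```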

During inference, we take the user prompt and sample additional tags from the model until we reach a fixed number of tags (50). This gives us two more knobs to control at the sampling stage (tag count and sampling temperature). We use the augmented tags to create the conditioning and run the usual diffusion sampling. Note that we do this only for the prompt itself, not for the negative prompt from classifier-free guidance (the one which is typically empty). Surprisingly, by the way, sampling a random set of tags for the negative prompt still yields decent results.
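
A sketch of this inference-time augmentation loop (function and model names are hypothetical):

```python
import torch

@torch.no_grad()
def augment_prompt(prompt_tags, augmenter, tag_to_id, id_to_tag,
                   n_tags: int = 50, temperature: float = 1.0):
    """Keep the user's tags as a prefix and autoregressively sample extra tags
    from the augmenter until `n_tags` distinct tags are collected.
    `augmenter` is the hypothetical GPT-like model returning (1, seq, vocab) logits."""
    tags = list(prompt_tags)
    ids = [tag_to_id[t] for t in tags]
    for _ in range(4 * n_tags):                    # safety cap on sampling steps
        if len(tags) >= n_tags:
            break
        logits = augmenter(torch.tensor([ids]))[0, -1] / temperature
        next_id = torch.multinomial(torch.softmax(logits, dim=-1), 1).item()
        ids.append(next_id)                        # always advance the sequence
        tag = id_to_tag.get(next_id)
        if tag is not None and tag not in tags:    # tags form a set: skip duplicates/specials
            tags.append(tag)
    return tags
```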

This process yields much better results (both higher quality and more diverse), e.g. examples for the prompt above:

Samples for “1girl, solo” with augmentation (no cherrypicking)

Note that the conditioning includes the NSFW rating and the score percentiles, so the autoregressive transformer is affected by them as well (e.g. it will sample NSFW-related tags only if someone requests NSFW images).

For sampling hyperparameters, we observe that a classifier-free guidance scale around 8 and ≥ 50 steps are optimal in terms of image quality (with ODE-based samplers; DDIM/DDPM are of course different). This is in contrast with Stable Diffusion, where a lower number of steps (around 30) already produces decent results (our model is faster though, because it is smaller).
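
For reference, a generic sketch of the classifier-free guidance step with the scale we found optimal; the denoiser signature is an assumption, not the exact Latent Diffusion API:

```python
import torch

def guided_eps(model, x_t, t, cond, uncond, scale: float = 8.0) -> torch.Tensor:
    """Classifier-free guidance at one denoising step (a generic sketch):
    extrapolate from the unconditional prediction towards the conditional one."""
    eps_cond = model(x_t, t, cond)      # noise prediction with the (augmented) tag conditioning
    eps_uncond = model(x_t, t, uncond)  # noise prediction with the empty / negative prompt
    return eps_uncond + scale * (eps_cond - eps_uncond)
```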

Results

Note: we don't cherry-pick images in this section; otherwise one could find good-looking images for any prompt given enough attempts.

Unlike the images typically used to showcase other diffusion models, our dataset mostly consists of drawings of human-like characters, so we don't expect to generate beautiful landscapes. Let's start by testing the easiest kind of prompt: human portraits. Images for the following prompt: "1girl, solo, smile, portrait":

One thing immediately stands out: the generated backgrounds are very simple. Another: something is off with the characters' fingers, which will be a recurring pattern in the generated images. Interestingly, a similar pattern occurs in Stable Diffusion; it might be that such details are difficult in general for smaller diffusion models.

The next test is to make the problem slightly harder: let's generate full-body pictures. Here is the prompt: "1girl, solo, smile".

There is nothing immediately interesting, besides the observations from the portrait-only images (weirdness with hands continues).

Next comes a more challenging test: can we use the prompt to control the pose of the character? Stable Diffusion usually fails at such tasks, yielding exquisite body horror and forcing users to do lots of cherry-picking. Our model has mostly seen human-like characters, so there is hope it can do better. There is a bunch of modifiers for this in the tag set; below are samples for "all_fours", "sitting", "standing".

One can see that it does unexpectedly well with these; perhaps the fact that the space of poses is discrete (a fixed set of tags) also helps. The next test is to draw more than one character at once. Let's try "2girls":

We see that there is a bias towards certain actions, perhaps introduced by the prompt augmentation. Can we get a neutral pose? It seems possible by adding "kiss" to the negative prompt and "looking_at_viewer" to the positive one (side note: the prompt augmentation doesn't respect the negative prompt, so the same tag can end up in both):

What about "composability" of objects in the image? We can try human/human and human/object interactions. For human/human, let's try a hug, as this is a rather complex action:

Interestingly, there are lots of tags in the set which control human/human composability, but unfortunately many of them are NSFW-related, so this is not very interesting. Let's try human/object interaction, e.g. "1girl, holding_food, apple":

OK, this seems to work pretty decently, although it would be nice to test some more complex cases; tag-based prompting really restricts us here, however.

A curious reader might have noticed that we have used the 1girl/2girls tags up until now, but not their male counterparts (1boy/2boys). Why is this? There are several issues. For one, these "male counterparts" apparently correlate with NSFW-ness (in the presence of other characters in the image). But let's restrict the examples to "solo": what do we see (e.g. "1boy, male_focus, solo, portrait, smile")?

These don't look like typical male characters. What happened? Here is the interesting catch: we were adding the "Score Percentile 100%" modifier to our generations to get higher-quality images. Let's remove it and try the same prompt:

This did help with matching the tags, but image quality clearly suffers without the quality modifier. Perhaps playing with the score modifiers or finding some helpful tags could improve these results.

We could keep exploring the properties of the model (e.g. how well it generates particular characters, or how well it imitates the styles of particular artists; in short, both work quite well), but for the sake of brevity, let's stop here: it is already clear that the model is reasonably good despite the low computational resources invested in it, which was the main point of this experiment.

Let's just mention one last thing: textual inversion. This is a technique for teaching models new concepts from just a few example images. We've tried it with character and artist concepts; both work (albeit somewhat worse than in the case of Stable Diffusion), but with a caveat: it is very sensitive to the "prompt templates" being used. These have to be sampled from the real distribution; short un-augmented prompts produce very poor results.

Technical: integration with existing tools

Since the rise of Stable Diffusion, lots of amazing tools have appeared for playing with it. Because this work is architecturally very similar to SD, it makes sense to convert the results into a format compatible with these tools.

It turns out the conversion itself is quite straightforward; the exported model can be found on HuggingFace. The major difference is the prompt conditioner: it is tag-based and includes prompt augmentation, which makes it very different from the CLIP-based conditioning in SD.

We hacked the SD codebase a bit to include this conditioner; the resulting fork can be found here. A more challenging part is modifying the relevant tools, since they (1) hack inside SD themselves to achieve various effects (like up-/down-weighting some terms in the prompt), and (2) require the prompt augmentation controls to be propagated throughout their stack.

For these reasons, we focused on one of the most popular tools, the Stable Diffusion web UI. The hacky fork is available here. Not all functionality works though (some limitations are obvious: "face fix" wouldn't work on anime images; some are less obvious, e.g. up-/down-weighting of individual terms just isn't implemented).

In the future, it would be interesting to fine-tune the actual Stable Diffusion on this dataset (keeping the prompt augmentation), with some form of alignment between the tag-based conditioner and the original CLIP conditioner. This should provide much better out-of-the-box compatibility with existing tools.

Broader impact

This experiment shows that very decent models producing diverse images can be trained with two orders of magnitude fewer resources if we restrict the training set. This can speed up future research in this direction, allowing researchers to try out different architectures and training setups much faster and cheaper: not requiring millions of dollars to train these models opens doors for independent researchers to contribute to this area. In fact, we'd advocate using this dataset more broadly in the research community, as opposed to datasets like LAION, which take a lot of resources to train on (or low-modality datasets like CelebA).

The model can produce NSFW images if asked to, including fringe NSFW. In this sense it is no different from the existing world, where artists produce and share such content (of all places, according to the analysis above, on Twitter). Generative models can make the production process easier, which might upset puritans, but all in all should have a positive effect on the world. If people want to see such images, they'll see them one way or another, and reducing the amount of resources that goes into producing them is good. Also, if generated NSFW images displace "real" photographic ones, it will affect the porn industry, which is infamous for employing shady methods to produce such real images.

Pointers