Multimodal LLMs and DALL-E Explained

2023/11/26

I’m finding that writing here is a great way to make sure I have a complete understanding of various concepts. With that in mind, I decided to keep going along the LLM line. Scheduled programming on the Brunton textbook will probably resume soon.

I imagined there would be some cool trick to multimodal LLMs, but that doesn’t actually seem to be the case! If there is a cool trick, it’s the cool trick of realizing that you can do an unusual amount of things with transformers, including image processing, which itself lets you do a surprising number of things.

Image Input/Output

If you want your LLM to accept image inputs, all you have to do is generate embeddings for the image. A common way to do this is with a vision transformer (ViT). The idea behind ViT is to turn an image into a sequence of tokens by breaking it into 16x16-pixel patches, where each input token is a linear projection of a patch. Now you can take the output of one (one!) transformer and run it through a linear layer to get an embedding. One minor point is that the transformer’s sequence length needs to cover the whole image: 1024 tokens of non-overlapping 16x16 patches for a 512x512 image.
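
To make that concrete, here’s a minimal sketch of the patchify-and-project step in PyTorch (the sizes are just the ones from the example above, not from any particular model):

```python
import torch
import torch.nn as nn

# Toy image: (batch, channels, height, width). 512x512, as in the example above.
image = torch.randn(1, 3, 512, 512)
patch_size, embed_dim = 16, 768

# Carve out non-overlapping 16x16 patches and flatten each one into a 3*16*16 vector.
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * patch_size * patch_size)

# Each token is a linear projection of a patch: (1, 1024, 768) -- 1024 tokens for a 512x512 image.
project = nn.Linear(3 * patch_size * patch_size, embed_dim)
tokens = project(patches)
```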

For image output, of course, you just use a diffusion model like DALL-E. In fact, you can be slightly more efficient, since text-to-whatever diffusion models need to use a transformer model to convert a sentence into an embedded vector. If you already have an embedded vector from your language model, you can just use the generation part of the diffusion model as an output head (similar to how you use a linear layer as an output head for next word prediction).
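
Here’s a sketch of the “output head” idea, with a made-up `TinyDenoiser` standing in for the generation half of a real diffusion model. The point is just that the same hidden state that feeds a linear next-word head can also be the conditioning vector for a denoiser:

```python
import torch
import torch.nn as nn

hidden_dim, vocab_size, image_dim = 768, 32000, 64 * 64 * 3
h = torch.randn(1, hidden_dim)  # the LLM's hidden state at the position producing output

# Head 1: next-word prediction is just a linear layer over the hidden state.
next_token_head = nn.Linear(hidden_dim, vocab_size)
token_logits = next_token_head(h)                            # (1, vocab_size)

# Head 2: a stand-in for the generation half of a diffusion model. A real one denoises
# iteratively; the key point is that it is conditioned on the same hidden state.
class TinyDenoiser(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(image_dim + hidden_dim, 1024),
                                 nn.ReLU(),
                                 nn.Linear(1024, image_dim))

    def forward(self, noisy_image, conditioning):
        return self.net(torch.cat([noisy_image, conditioning], dim=-1))

image_head = TinyDenoiser()
predicted_noise = image_head(torch.randn(1, image_dim), h)   # (1, image_dim)
```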

Video Input/Output

You can imagine doing video input by running ViT on each frame of a video to produce a series of frame embeddings, then passing those through another transformer to extract some information about the video (for instance, action recognition). In practice, you can actually do this! It’s also common to “temporally inflate” the ViT patches, meaning the linear projection in the first step of ViT becomes a projection of (for example) a 16x16x3 cube, where the last dimension is time.
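
Here’s what temporal inflation might look like in code - the same as the ViT sketch above, but the patches become 16x16x3 spatio-temporal tubes (all sizes are illustrative):

```python
import torch
import torch.nn as nn

# Toy video: (batch, channels, time, height, width).
video = torch.randn(1, 3, 12, 224, 224)
ps, pt, embed_dim = 16, 3, 768   # 16x16 spatial patches, 3 frames deep

# Carve out non-overlapping 16x16x3 tubes, then flatten each tube into a vector.
tubes = (video.unfold(2, pt, pt)     # -> (1, 3, 4, 224, 224, 3)
              .unfold(3, ps, ps)     # -> (1, 3, 4, 14, 224, 3, 16)
              .unfold(4, ps, ps))    # -> (1, 3, 4, 14, 14, 3, 16, 16)
tubes = tubes.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(1, -1, 3 * pt * ps * ps)

# The "inflated" linear projection: one token per tube, (1, 4 * 14 * 14, 768).
project = nn.Linear(3 * pt * ps * ps, embed_dim)
tokens = project(tubes)
```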

There are also video-based diffusion models (ex: VideoFusion), but they’re currently not as good as the image ones. Still not awful though!

Audio Input/Output

For audio, you could imagine some 1D convolution over the audio stream (which is a series of intensities sampled at a very high rate, like 16 kHz or 44 kHz) to produce the tokenized inputs to the transformer. But in fact, there’s another very clever option, which is to produce a spectrogram of the audio and use ViT again.
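
Here’s a rough sketch of the spectrogram route using torchaudio; the spectrogram parameters are illustrative, not taken from any particular audio model:

```python
import torch
import torch.nn as nn
import torchaudio

waveform = torch.randn(1, 16000 * 5)      # 5 seconds of fake mono audio at 16 kHz
spec = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, hop_length=160, n_mels=128
)(waveform)                               # (1, 128, ~501): an "image" of the audio

# Crop the time axis to a multiple of the patch size, then patchify exactly like ViT.
ps, embed_dim = 16, 768
spec = spec[..., : (spec.shape[-1] // ps) * ps]
patches = spec.unfold(1, ps, ps).unfold(2, ps, ps)    # (1, 8, 31, 16, 16)
patches = patches.reshape(1, -1, ps * ps)             # one row per 16x16 spectrogram patch

tokens = nn.Linear(ps * ps, embed_dim)(patches)       # (1, num_patches, embed_dim)
```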

We’re only just starting to see good audio-based diffusion models (ex: AudioLDM), but we are getting there!

Diffusion Models and DALL-E 2

I’ve been sort of hand-waving the generation aspect of each modality away by saying that diffusion models can do it, without explaining diffusion models. I’ll fix that by explaining DALL-E 2.

CLIP

But first, we have to take a little detour and talk about CLIP! CLIP is a neural network that learns to predict how likely a natural-language caption $c$ is to be paired with an image $x$, which I’ll call $P(c \mid x)$. CLIP is trained in batched mode, so it gets (for example) 10 images and 10 captions and must essentially learn to fill in the full 10x10 matrix of image-caption similarity scores. This turns out to be really useful in inference mode! Have 10 images and want to see which one is most catlike? Have 10 possible captions and want to see which one best fits an image? CLIP can do it. The loss function for learning such a similarity matrix, in case you’re wondering, is the multi-class N-pair loss: a symmetric cross-entropy over the rows and columns of the matrix. CLIP was trained with a variety of architectures, but the best one seems to have used a transformer as the text encoder and (you guessed it) ViT as the image encoder.
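
For concreteness, here’s roughly what that batched contrastive objective looks like (simplified - real CLIP learns the temperature and uses far larger batches):

```python
import torch
import torch.nn.functional as F

def clip_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric cross-entropy over the NxN image-caption similarity matrix."""
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.T / temperature   # entry (i, j) scores image i against caption j
    labels = torch.arange(logits.shape[0])          # the correct pairings lie on the diagonal
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

# e.g. a batch of 10 images and their 10 captions, already encoded by ViT / a text transformer
loss = clip_loss(torch.randn(10, 512), torch.randn(10, 512))
```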

Diffusion

Diffusion models learn to denoise inputs given some conditioning information. For example, you might send in a noisy image $x_{I} = y + z$ with $z \sim N(0, \sigma)$, along with the CLIP embedding $x_{E}$ of the clean image, and the decoder tries to output $z$. In this notation, $y$ is the clean image, so $y = x_{I} - z$. Diffusion models actually take this a little further and denoise progressively, meaning that more and more random noise is added over a sequence of steps until the last noisy input $x_{I}^{(n)}$ basically resembles pure random noise, and at each step the decoder is trying to predict the noise that was just added, $x_{I}^{(n)} - x_{I}^{(n-1)}$. You can apply the concept of diffusion to any number of datatypes, including images, audio, video, and even vectors in an embedding space.
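
Here’s a toy version of one training step for a noise-predicting diffusion model, following the single-step notation above. A real DDPM adds noise over many timesteps according to a fixed schedule and uses a U-Net or transformer rather than this MLP:

```python
import torch
import torch.nn as nn

image_dim, embed_dim = 64 * 64 * 3, 512
# Stand-in denoiser: takes (noisy image, noise scale, conditioning embedding) and predicts the noise.
denoiser = nn.Sequential(nn.Linear(image_dim + 1 + embed_dim, 2048),
                         nn.ReLU(),
                         nn.Linear(2048, image_dim))

y = torch.randn(8, image_dim)        # a batch of (flattened) clean images
x_E = torch.randn(8, embed_dim)      # their CLIP embeddings (the conditioning)

sigma = torch.rand(8, 1)             # random noise scale per example
z = sigma * torch.randn_like(y)      # z ~ N(0, sigma), as in the text
x_I = y + z                          # the noisy images

pred_z = denoiser(torch.cat([x_I, sigma, x_E], dim=-1))
loss = ((pred_z - z) ** 2).mean()    # train the decoder to recover the added noise
loss.backward()
```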

DALL-E 2

Once you have a CLIP network trained, you can train another neural network that takes CLIP text embeddings and uses them to predict the paired CLIP image embeddings. This network is called the prior. In DALL-E 2 the prior is itself a diffusion model - one that diffuses over CLIP embedding vectors rather than pixels - and architecture-wise it’s a transformer.

You also train a network that lets you decode CLIP image embeddings into actual images. Again, this is a diffusion model, and it takes as input a noisy image, the CLIP image embedding predicted by the prior network, and optionally the CLIP text embeddings from the caption. Since you don’t want to do this in a very high-dimensional space, you train this model to output 64x64 images, and then two upsampling models (ADMNets - interestingly, not conditioned on the caption) get you from 64x64 to 256x256 and then to 1024x1024. The decoder network also breaks from the pattern above: it’s not a ViT, but a U-Net-style CNN.
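
Putting the prior and decoder together, generation looks roughly like the sketch below. The functions are hypothetical stand-ins that just return correctly-shaped tensors, not the real DALL-E 2 API:

```python
import torch

# Hypothetical stubs; each would be a trained network in the real system
# (CLIP text encoder, diffusion prior, diffusion decoder, ADM upsamplers).
def clip_text_encoder(caption): return torch.randn(1, 512)
def diffusion_prior(text_emb): return torch.randn(1, 512)
def diffusion_decoder(image_emb, text_emb): return torch.randn(1, 3, 64, 64)
def upsample_256(image): return torch.randn(1, 3, 256, 256)
def upsample_1024(image): return torch.randn(1, 3, 1024, 1024)

caption = "a photo of a corgi wearing a party hat"
text_emb = clip_text_encoder(caption)           # caption -> CLIP text embedding
image_emb = diffusion_prior(text_emb)           # prior: text embedding -> CLIP image embedding
small = diffusion_decoder(image_emb, text_emb)  # decoder: embeddings -> 64x64 image
image = upsample_1024(upsample_256(small))      # 64x64 -> 256x256 -> 1024x1024
```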

ImageBind

ImageBind is a super cool paper out of Facebook research that showed you can actually create a single embedding space for images, text, video, and audio! They use pretty much the techniques described above (specifically, a pretrained DALL-E 2 decoder is used as the diffusion model for images). They don’t bother with video or audio generation. There are also very cool experiments with, for example, latent-space arithmetic, where you can take the embedded vector for an image of a bird, add it to the embedded vector of the sound of waves, and use that to generate an image of the bird standing in water!

Diffusion Policies

Here’s a weird way to use diffusion: imagine you want to train a robot (potentially a high-dimensional one) to perform various tasks that it needs sensors to complete. For example, maybe the robot is a bimanual manipulator, meaning that it’s a system of two robotic arms with grippers attached. It has access to an image of the current workspace and its current joint state, and needs to unscrew the cap from a water bottle. If each arm is 7dof and each gripper is 1dof, then the control space is 16-dimensional. If an “action” is 10 time-sampled control vectors over the next couple of seconds, then the action space is 160-dimensional, which is pretty small relative to other things diffusion models are good at (like 1024x1024 RGB images - a roughly $3 \times 10^{6}$-dimensional space). In fact, it’s also useful in practice to predict both the states and actions simultaneously as in Diffuser; this roughly doubles the size of your output vector (unless, of course, your action space is a trajectory through state space!), but that’s still totally fine.
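
As a shape-level sketch (hypothetical network and sizes, just following the example above): the thing being denoised is now a 10x16 block of future controls, conditioned on an encoding of the camera image and joint state:

```python
import torch
import torch.nn as nn

horizon, act_dim, obs_dim = 10, 16, 512   # 10 future control vectors, 16-D each;
                                          # obs_dim = some encoding of camera image + joint state

# Hypothetical noise-prediction network (the paper uses a 1-D CNN or small transformer
# over the action sequence; a plain MLP stands in here).
denoiser = nn.Sequential(nn.Linear(horizon * act_dim + obs_dim + 1, 1024),
                         nn.ReLU(),
                         nn.Linear(1024, horizon * act_dim))

obs = torch.randn(1, obs_dim)                  # encoded observation
actions = torch.randn(1, horizon * act_dim)    # start from pure noise...

# ...and iteratively denoise it into an executable 160-D action chunk
# (toy update rule; a real sampler uses a proper DDPM/DDIM update).
for step in range(50):
    noise_level = torch.full((1, 1), 1.0 - step / 50)
    pred_noise = denoiser(torch.cat([actions, obs, noise_level], dim=-1))
    actions = actions - 0.1 * pred_noise

action_chunk = actions.reshape(horizon, act_dim)   # 10 control vectors to execute
```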

One thing you might notice about this task is that the space of correct policies is multimodal (at least bimodal, since you can imagine gripping the bottle with either gripper and using the other one to unscrew the cap). An “average” of these modes, where each hand tries to compromise between grasping the bottle and grasping the cap, clearly won’t work. This is the same problem as the most common failure mode of image generation before high-quality diffusion: images tended to look smoothed-out and maybe blurry - a good average representation of the class of thing they were trying to portray, but not like any particular member of that class.

Architectures here tend to be very simple, as usual. 1-D CNNs and small transformers (like 2 layers) are both common. The Diffusion Policy paper reports that the CNN backbone is easier to tune but has an inductive bias toward overly smooth trajectories, while the transformer backbone handles fast-changing actions better at the cost of more hyperparameter sensitivity. I’d like to promote Diffusion Policy as a paper that’s applied this technique very successfully in the real world. Old labmates have used it in manipulation tasks and confirmed that it actually does work quite well.