This post is an attempt to explain Facebook’s Llama LLM to myself.
Part 1: Transformer Models, Architecture
Tokenization and Initial Embedding
First, the input string is split into tokens. These tokens are often “subword tokens,” meaning that word parts (like -ed and -tion) get their own tokens. This is useful when a word is out-of-vocabulary (OOV). If the input is nonsense, the tokens may be individual characters. The tokens are then looked up in a dictionary of embeddings. This dictionary is not static, however! It is typically randomly initialized at the start of training, and the vector associated with each token is then learned as part of the training process.
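As a toy sketch of the lookup: a fake whitespace “tokenizer” stands in for a real subword tokenizer (Llama uses a SentencePiece BPE tokenizer), and PyTorch’s nn.Embedding plays the role of the learnable dictionary.

```python
import torch
import torch.nn as nn

# Toy stand-in for a real subword tokenizer.
vocab = {"<unk>": 0, "the": 1, "cat": 2, "sat": 3}
def tokenize(text):
    return [vocab.get(word, vocab["<unk>"]) for word in text.lower().split()]

d_model = 8                                    # embedding dimension d
embedding = nn.Embedding(len(vocab), d_model)  # randomly initialized, learned during training

token_ids = torch.tensor(tokenize("The cat sat"))
x = embedding(token_ids)                       # shape: (n_tokens, d_model)
```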
Output: $d$-dimensional vector embeddings for $n$ tokens
Positional Encoding
Positional encoding is a way to encode where each token sits in the sequence (and, by extension, how tokens are positioned relative to one another). If embeddings are $d$-dimensional and there are $n$ tokens, then the positional encoding at position $pos$ will also be a $d$-dimensional vector. This vector is found as:
$$PE_{pos}^{(2i)} = \sin \left( \frac{pos}{10000^{\frac{2i}{d}}} \right)$$ $$PE_{pos}^{(2i+1)} = \cos \left( \frac{pos}{10000^{\frac{2i}{d}}} \right)$$
Here, $i$ indexes the dimension pairs of the vector: the even dimensions $2i$ get the sine and the odd dimensions $2i+1$ get the cosine.
Having found the positional encodings, we incorporate them into the word information by elementwise summation. This, combined with the learnable embedding library, lets us learn how tokens tend to relate to tokens in other relative positions.
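A small sketch of the sinusoidal table and the elementwise sum, assuming an even embedding dimension $d$:

```python
import torch

def sinusoidal_positional_encoding(n, d):
    """Return an (n, d) table: sin on even dimensions, cos on odd dimensions."""
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)   # (n, 1) token positions
    two_i = torch.arange(0, d, 2, dtype=torch.float32)        # even dimension indices 2i
    freq = 1.0 / (10000 ** (two_i / d))                       # 1 / 10000^(2i/d)
    pe = torch.zeros(n, d)
    pe[:, 0::2] = torch.sin(pos * freq)
    pe[:, 1::2] = torch.cos(pos * freq)
    return pe

# x: (n, d) token embeddings from the previous step
# x = x + sinusoidal_positional_encoding(x.shape[0], x.shape[1])
```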
Output: $d$-dimensional embeddings (with semantic + positional information) for $n$ tokens
Self-Attention
Each embedding is now used to create a query vector and a key vector. These vectors are found by multiplying the embedding by learned matrices (one matrix for queries, another for keys).
For each token, we compute pairwise attention scores $a_{i,j}$ as the dot product of the current token’s query vector with every other token’s key vector. The attention scores are then normalized using softmax, ensuring that they sum to 1. To keep the gradients stable, it’s common to scale the dot products by $\frac{1}{\sqrt{d_k}}$, where $d_{k}$ is the dimension of the key (and query) vectors; this prevents large values from essentially nulling out the rest of the softmax output. Either way, we now have pairwise attention scores $a_{i,j}$ for all pairs of tokens.
For each token, we also compute a value vector. This is computed similarly to the key and query vectors (the product of a learnable matrix with the embedding). The output vector associated with token $i$ can then be found as the weighted sum of every token’s value vector, where the weights are the pairwise attention scores between token $i$ and each other token:
$$v_{out}^{(i)} = \sum\limits_{j = 1}^{n} a_{i,j} v^{(j)}$$
Typically, this output will then be normalized, passed through a linear layer, etc.
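A minimal single-head sketch of the whole thing, with the learned matrices written as bias-free linear layers (the class and dimension names are mine, not Llama’s):

```python
import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)  # learned query matrix
        self.W_k = nn.Linear(d_model, d_k, bias=False)  # learned key matrix
        self.W_v = nn.Linear(d_model, d_k, bias=False)  # learned value matrix

    def forward(self, x):                               # x: (n, d_model)
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = q @ k.T / math.sqrt(k.shape[-1])       # (n, n) scaled dot products
        a = torch.softmax(scores, dim=-1)               # each row sums to 1
        return a @ v                                    # weighted sums of value vectors
```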
Masked Attention
Masked attention is used to enforce causality on embeddings (that is, each embedding is influenced only by the embeddings that precede it in time). In principle this means you don’t have to compute half of your attention scores; in practice, implementations usually compute the full score matrix and set the “future” entries to $-\infty$ before the softmax, so they get zero weight. Either way, it’s what you want for tasks like next-word prediction (which are causal).
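A sketch of that trick, applied to the score matrix from the single-head example above:

```python
import torch

def causal_mask(scores):
    """Mask an (n, n) score matrix so token i only attends to tokens j <= i."""
    n = scores.shape[-1]
    allowed = torch.tril(torch.ones(n, n, dtype=torch.bool))   # lower triangle = past + self
    return scores.masked_fill(~allowed, float("-inf"))         # -inf becomes weight 0 after softmax

# Inside the SelfAttention sketch above:
#   scores = causal_mask(scores)
#   a = torch.softmax(scores, dim=-1)
```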
Multi-Head Attention
A variant on self-attention simply performs this process multiple times in parallel, with multiple (independent, learnable) $Q$, $K$, and $V$ matrices. Each of these independent self-attention “heads” produces an output vector for each token; the per-head outputs are then concatenated and passed through a final linear layer.
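A standalone sketch of multi-head attention. It uses one big projection per role and then splits it into heads, which is equivalent to keeping independent per-head matrices; again, the names are mine:

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)     # final linear layer

    def forward(self, x):                                      # x: (n, d_model)
        n = x.shape[0]
        q = self.W_q(x).view(n, self.n_heads, self.d_k)        # split into heads
        k = self.W_k(x).view(n, self.n_heads, self.d_k)
        v = self.W_v(x).view(n, self.n_heads, self.d_k)
        scores = torch.einsum("ihd,jhd->hij", q, k) / math.sqrt(self.d_k)
        a = torch.softmax(scores, dim=-1)                      # per-head attention weights
        out = torch.einsum("hij,jhd->ihd", a, v).reshape(n, -1)  # concatenate head outputs
        return self.W_o(out)
```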
Grouped-Query Attention
In normal multi-head attention, every head gets its own query, key, and value projections. In grouped-query attention, you choose a group size $g$ and make each group of $g$ query heads share a single key head and a single value head (multi-query attention is the extreme version where all heads share one). You still compute the full $O(n^2)$ grid of attention scores, but the key and value projections, and more importantly the key/value cache you have to hold onto during generation, shrink by a factor of $g$. That memory saving is what makes longer context windows affordable for the larger Llama models (presumably with a slight performance hit, but maybe not once you consider that sharing key/value heads might actually help prevent overfitting).
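A sketch of the change relative to the multi-head code above: fewer key/value heads, which get repeated so each one serves its group of query heads (dimension names are mine, and the causal mask is omitted):

```python
import math
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):   # n_heads must be a multiple of n_kv_heads
        super().__init__()
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)   # fewer key heads
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)   # fewer value heads
        self.W_o = nn.Linear(n_heads * self.d_k, d_model, bias=False)

    def forward(self, x):                                # x: (n, d_model)
        n = x.shape[0]
        q = self.W_q(x).view(n, self.n_heads, self.d_k)
        k = self.W_k(x).view(n, self.n_kv_heads, self.d_k)
        v = self.W_v(x).view(n, self.n_kv_heads, self.d_k)
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=1)            # each K/V head serves a group of Q heads
        v = v.repeat_interleave(group, dim=1)
        scores = torch.einsum("ihd,jhd->hij", q, k) / math.sqrt(self.d_k)
        a = torch.softmax(scores, dim=-1)
        out = torch.einsum("hij,jhd->ihd", a, v).reshape(n, -1)
        return self.W_o(out)
```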
Rotary Positional Encoding: Positional Encoding for Shape Rotators
RoPE is a variant on positional encoding where consecutive pairs of dimensions in the $k_{i}$ and $q_{i}$ vectors are considered as $(x, y)$ pairs on a plane (for instance, $(k_{i}^{(0)}, k_{i}^{(1)})$ would be considered a point). RoPE’s core concept is that instead of adding a positional encoding to the embedding, you choose a phase shift (or angle of rotation) proportional to the token position, with a different frequency for each dimension pair (much like the sinusoidal encoding above), and literally rotate each “point” in the vector by that amount. Surprisingly, this works really well.
This section lives where it does because RoPE doesn’t encode position by operating directly on the embedding, but rather on the key and query vectors.
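A simplified sketch of the rotation for one (n, d) matrix of query or key vectors. Real implementations (Llama’s included) precompute the angles and often work on a complex-number view, but the arithmetic is the same 2-D rotation:

```python
import torch

def rope(x, base=10000.0):
    """Rotate consecutive dimension pairs of x (shape (n, d), d even) by position-dependent angles."""
    n, d = x.shape
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)                   # (n, 1) token positions
    freqs = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))  # one frequency per pair
    theta = pos * freqs                                                       # (n, d/2) rotation angles
    cos, sin = torch.cos(theta), torch.sin(theta)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]          # the (x, y) coordinates of each "point"
    out = torch.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin       # a standard 2-D rotation of each pair
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

# Applied to the query and key vectors, not the embeddings:
#   q, k = rope(q), rope(k)
```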
Output: A set of $n$ “contextualized embeddings” of dimension $d_{model}$. Unlike the prior embeddings, these are particular to the exact input (every token contributes to the embedding of every other token).
Output Head
Let’s say we’re trying to do next-word prediction. The output head will typically be a linear layer and a softmax, which map each contextualized embedding to a probability distribution over the vocabulary (that is, the key set of the token library).
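A minimal sketch, assuming Llama 2’s 32,000-token vocabulary and the 7B model’s 4096-dimensional width:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 32000, 4096
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# h: (n, d_model) contextualized embeddings from the final block
# logits = lm_head(h)                      # (n, vocab_size)
# probs = torch.softmax(logits, dim=-1)    # per-position distribution over next tokens
```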
SwiGLU
SwiGLU is a type of nonlinearity used in Llama. You might be familiar with applying a ReLU as a nonlinearity:
$$\texttt{FFN}_{\texttt{ReLU}}(x) = \mathrm{ReLU}(xW_{1} + b_{1})W_{2} + b_{2}$$
Instead of using ReLU as an activation function, you can also use it as a gating function (like in LSTMs). This is called ReGLU, and it composes feedforward and activation-esque behavior:
$$\texttt{ReGLU}(x) = \mathrm{ReLU}(xW_{1} + b_{1}) * (xW_{2} + b_{2})$$
Where $*$ here is the elementwise multiplication operator.
Instead of ReLU, there’s a type of function called a “swish”, defined as $\mathrm{swish}(x) = x \cdot \sigma(\beta x)$ (with $\beta = 1$ this is also called SiLU); it’s a smooth approximation of the ReLU with a nonzero gradient for negative values close to zero. If you replace the ReLU in ReGLU with the swish function, you get SwiGLU:
$$\texttt{SwiGLU}(x) = \mathrm{swish}(xW_{1} + b_{1}) * (xW_{2} + b_{2})$$
In the paper “GLU Variants Improve Transformer,” Noam Shazeer shows that GLU-based functions have empirically better performance than FFN functions. Since SwiGLU combines the features of good performance and trainability (thanks to the swish function replacing the ReLU), it is often used as a nonlinearity in transformer models.
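In PyTorch, swish with $\beta = 1$ is available as SiLU, so a minimal SwiGLU sketch is just a couple of lines (Llama drops the bias terms, as in the architecture section below):

```python
import torch
import torch.nn.functional as F

def swiglu(x, W1, b1, W2, b2):
    # swish(x W1 + b1) * (x W2 + b2); F.silu is swish with beta = 1
    return F.silu(x @ W1 + b1) * (x @ W2 + b2)
```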
RMSNorm
RMSNorm is a method of layer normalization. It’s very simple: if a vector $x_{1..n}$ passes through RMSNorm, each item is scaled by the inverse of the root mean square of $x$ (in practice with a small $\epsilon$ added inside the square root for stability, and a learned per-dimension gain applied afterwards):
$$\bar{x}_{i} = \frac{x_{i}}{\sqrt{\frac{1}{n}\sum\limits_{j=1}^{n} x_{j}^{2}}}$$
RMSNorm rescales the vector so its root mean square is 1 (so, before the gain, its norm is a constant $\sqrt{n}$). This helps prevent gradients from vanishing or exploding during backprop.
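A minimal PyTorch sketch; the $\epsilon$ and the learned gain are the two practical details on top of the formula above:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(d))    # learned per-dimension scale

    def forward(self, x):                          # x: (..., d)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * x / rms
```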
Llama Architecture Overview
Llama 2 is scary simple:
- One block contains an attention function and a feedforward function.
- An attention function applies RoPE to the queries and keys and then does normal multi-head self-attention, with grouped-query attention swapped in for the larger model sizes. The attention function has the normal linear layer at the end that you expect from multihead attention.
- A feedforward function is just SwiGLU passed through one additional linear layer, with no bias weights. $$\texttt{FFW}(x) = (swish(xW_1) * (xW_2))W_3$$
- The block is composed like this: $$\begin{align} h &:= x + \texttt{Attn}(\texttt{RMS}(x)) \\ \texttt{Layer}(x) &= h + \texttt{FFW}(\texttt{RMS}(h)) \end{align}$$
- Llama 2 is a stack of these blocks (32 of them in the 7B model, more in the larger ones), then a final RMSNorm and a final linear layer (a rough sketch of one block follows this list)
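Gluing the earlier sketches together, one block looks roughly like this. It reuses the MultiHeadAttention and RMSNorm sketches from above and leaves out RoPE, the causal mask, grouped-query attention, and the KV cache, so treat it as a shape sketch rather than a faithful reimplementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_hidden, bias=False)
        self.W2 = nn.Linear(d_model, d_hidden, bias=False)
        self.W3 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.W3(F.silu(self.W1(x)) * self.W2(x))   # (swish(x W1) * (x W2)) W3

class LlamaBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_hidden):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = FeedForward(d_model, d_hidden)

    def forward(self, x):
        h = x + self.attn(self.attn_norm(x))        # pre-norm attention, plus residual
        return h + self.ffn(self.ffn_norm(h))       # pre-norm feedforward, plus residual
```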
Part 2: Training
Loss Function
Since next-word prediction is a huge classification task, the loss function is cross-entropy loss. Easy!
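As a tiny sketch (random tensors standing in for real model outputs; the one-position shift is what makes it “next”-word prediction):

```python
import torch
import torch.nn.functional as F

vocab_size, n = 32000, 16
logits = torch.randn(n, vocab_size)           # model outputs at each position
token_ids = torch.randint(vocab_size, (n,))   # the input sequence

# Each position predicts the *next* token, so shift the targets by one.
loss = F.cross_entropy(logits[:-1], token_ids[1:])
```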
Data
Facebook is a bit coy about its training data, claiming that they use “a new mix of data from publicly available sources” (though they do emphasize they didn’t use any data from Facebook-owned entities). I assume this means some combination of:
- Facebook used text from illegal/unlicensed sources and are being strategically ambiguous about that so they can’t be sued
- Facebook figured out that your dataset is more important than your architecture and want to preserve some competitive edge
However, the Llama 1 paper does describe its dataset, which is a combination of the Common Crawl project, GitHub, Stack Exchange, Wikipedia, Project Gutenberg, and ArXiv.
Optimizer and Hyperparameters
The optimizer is AdamW, with $\beta_{1} = 0.9$, $\beta_{2} = 0.95$, and $\epsilon = 10^{-5}$. They use a cosine LR schedule with a warmup of 2000 steps, and decay the final LR to 10% of the peak LR. They use weight decay of 0.1 and gradient clipping of 1.0.
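A rough sketch of that recipe in PyTorch. The peak learning rate, total step count, and the stand-in model are placeholder assumptions, and the warmup-plus-cosine schedule is written out by hand rather than tied to a particular scheduler class:

```python
import math
import torch

model = torch.nn.Linear(8, 8)                 # placeholder stand-in for the real model
peak_lr, warmup_steps, total_steps = 3e-4, 2000, 500_000   # placeholder values

optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                              betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1)

def lr_at(step):
    """Linear warmup to peak_lr, then cosine decay down to 10% of peak."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.1 * peak_lr + 0.9 * peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

# In the training loop:
#   for group in optimizer.param_groups: group["lr"] = lr_at(step)
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # gradient clipping at 1.0
```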
Conclusion
One of the craziest things about Llama is how easy it is to make something like it. The data is publicly available. The architecture is dead simple. There are a bunch of random tricks, and maybe you have to hit the right combination of tricks to make the whole thing work, but now that we’ve done it, it’s extremely easy to replicate. This makes me much more glad that places like OpenAI are not open-sourcing their code. Given access to enough GPUs, training something like Llama from scratch is realistic for an ambitious and talented undergrad to do within a year.
This also means that I’m not very impressed with China’s LLM capabilities. Catching up with something like Llama is a question of purchasing GPUs. If anything, the fact that China apparently couldn’t do anything like this until recently is more significant to me - despite the strategic importance of LLMs, no Chinese company seemed to have gotten good at them until the ChatGPT interface proved how concretely and immediately useful they were. That is the behavior of a state that’s willing to lag behind the cutting edge in order to realize the cost efficiencies of letting someone else do your R&D (or worse, has few R&D capabilities in this space to mobilize).
On a more technical level, something this makes me appreciate is how much the whole transformer model relies on the technical difference between “infinite” and “finite but really big.” The token dictionary for GPT is apparently tens of thousands of elements, so your next-token prediction is potentially a 30k-dimensional vector output. I don’t think I ever would have imagined something like this would be feasible. My instinct would have been to have one token per English phoneme or something, plus all letters, numbers, and symbols for edge cases, which is something like 150 dimensions. But there’s an interesting lesson here: complexity was moved from the problem we had to solve into the model itself. Do too much of this and you have a huge model and not enough data to train it. But apparently the right amount of that to do was a few tens of thousands of tokens - 100x what I would have imagined. Cool!