This is the first time I have read the Llama2 code. Many things are still similar to the original Transformer code, but there are also some new pieces. I am documenting some findings here.
Where is the Llama2 Code?
Modeling (training) code is hosted here: https://github.com/facebookresearch/llama/blob/main/llama/model.py
Inference code is hosted here: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
Annotations
There are two online annotations of the Llama2 code which I think are useful:
- [Ref1] Deciphering the LLAMA2 Code: Unraveling the Secrets of a Language AI Marvel: https://www.linkedin.com/pulse/deciphering-llama2-code-unraveling-secrets-language-ai-ayoub-kirouane/
- [Ref2] The Annotated LLaMA: https://medium.com/@nishantbhansali80/the-annotated-llama-fa183943b34b
While the two annotations are useful, I still need some external references for parts I don’t understand:
- `precompute_freqs_cis` is a function for computing rotary embeddings. Ref2 has a better explanation than Ref1.
- The K/V cache (`self.cache_k` and `self.cache_v` in the `Attention` class) is only meaningfully useful in inference (next-token prediction). First, we need to build a mental model of how inference works. Suppose we start inference with a batch of prompts (all of the same length, for simplicity). The transformer model consumes the batch of prompts and generates the first new tokens. Then each subsequent token is generated by consuming the prompts plus the previously generated tokens. Without a K/V cache, you can foresee that K and V will be recomputed over the previously generated sequence every time the next token is predicted, as the sketch below illustrates.
The K/V cache eliminates the need to recompute K/V after predicting every next token. With the K/V cache, `self.cache_k` and `self.cache_v` store the current batch's K/V, and the K/V of the full previously generated sequence is fetched from `self.cache_k` and `self.cache_v` (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L276-L283).
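The toy class below sketches that cache write/read pattern in a self-contained way (single head, no rotary embeddings, no multi-head reshaping, no causal mask, meant to be run under `torch.no_grad()`); it is not the real `Attention` implementation, just the caching idea:

```python
import torch

class ToyCachedAttention(torch.nn.Module):
    """Minimal single-head attention with a K/V cache (illustration only)."""

    def __init__(self, max_batch_size: int, max_seq_len: int, dim: int):
        super().__init__()
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)
        self.wo = torch.nn.Linear(dim, dim, bias=False)
        # Caches are pre-allocated and indexed by absolute token position.
        self.register_buffer("cache_k", torch.zeros(max_batch_size, max_seq_len, dim))
        self.register_buffer("cache_v", torch.zeros(max_batch_size, max_seq_len, dim))

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        bsz, seqlen, dim = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Write K/V of the current tokens at their absolute positions.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Read K/V for everything seen so far (prompt + previously generated tokens).
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        scores = torch.softmax(xq @ keys.transpose(1, 2) / dim ** 0.5, dim=-1)
        return self.wo(scores @ values)
```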
To understand this further, note that the `forward` function of `Attention` accepts `start_pos` as an argument (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L265). After the first batch, which contains the prompts, each following batch contains only the single tokens generated in the previous step. Therefore, after the prompt pass, `start_pos` advances by 1 at every step and every following batch's `seq_len` becomes 1. One can reference the Llama2 inference code, https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L162-L212, for how the model actually gets called at inference time.
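The calling pattern looks roughly like the sketch below (a simplification, not the repo's `generate` function, which also handles per-prompt lengths, sampling, and EOS; `model(tokens, start_pos)` is assumed to return per-position logits, mirroring the shape of `Transformer.forward` in model.py):

```python
import torch

@torch.no_grad()
def greedy_decode_with_cache(model, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    bsz, prompt_len = prompt_ids.shape
    tokens = prompt_ids

    # First call: feed the whole prompt; this fills the K/V cache for positions [0, prompt_len).
    logits = model(tokens, start_pos=0)          # (bsz, prompt_len, vocab_size)
    next_token = logits[:, -1, :].argmax(-1)     # (bsz,)
    tokens = torch.cat([tokens, next_token[:, None]], dim=1)

    # Later calls: feed ONLY the newest token (seq_len == 1) and advance start_pos;
    # K/V of all earlier positions come from the cache.
    for cur_pos in range(prompt_len, prompt_len + max_new_tokens - 1):
        logits = model(tokens[:, cur_pos:cur_pos + 1], start_pos=cur_pos)  # (bsz, 1, vocab_size)
        next_token = logits[:, -1, :].argmax(-1)
        tokens = torch.cat([tokens, next_token[:, None]], dim=1)
    return tokens
```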
A side note is that the K/V cache reduces FLOPs but does not reduce the overall decoding time complexity. Here is a table (similar to this post's table) showing the FLOPs of each sub-step when predicting each new token:
| Sub-step | w/o K/V cache: x has shape (batch_size, seq_len, hidden_dim) | with K/V cache: x has shape (batch_size, 1, hidden_dim) | Note |
| --- | --- | --- | --- |
| Convert x to K/V by xW_K, xW_V | O(batch_size * seq_len * hidden_dim * hidden_dim) | O(batch_size * 1 * hidden_dim * hidden_dim) | The K/V cache only saves this part's FLOPs |
| Convert x[:, -1] to q by x[:, -1]W_Q | O(batch_size * hidden_dim * hidden_dim) | O(batch_size * hidden_dim * hidden_dim) | |
| p = softmax(qK^T / sqrt(d)) | O(batch_size * seq_len * hidden_dim) | O(batch_size * seq_len * hidden_dim) | Overall time complexity is still dominated by this softmax/attention step |
| a = pV | O(batch_size * seq_len * hidden_dim) | O(batch_size * seq_len * hidden_dim) | |
| Convert a to output by aW_O | O(batch_size * hidden_dim * hidden_dim) | O(batch_size * hidden_dim * hidden_dim) | |
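To make the saving concrete, here is a quick back-of-the-envelope count per newly generated token, mirroring the table (the dimensions are made up for illustration, and one multiply-add is counted as one FLOP):

```python
batch_size, seq_len, hidden_dim = 8, 2048, 4096  # made-up example dimensions

kv_proj_no_cache   = batch_size * seq_len * hidden_dim * hidden_dim  # recompute K/V for every position
kv_proj_with_cache = batch_size * 1 * hidden_dim * hidden_dim        # only the newest position
q_proj   = batch_size * hidden_dim * hidden_dim
attn_qk  = batch_size * seq_len * hidden_dim  # q (1 x d) against seq_len cached keys
attn_pv  = batch_size * seq_len * hidden_dim
out_proj = batch_size * hidden_dim * hidden_dim

total_no_cache   = kv_proj_no_cache   + q_proj + attn_qk + attn_pv + out_proj
total_with_cache = kv_proj_with_cache + q_proj + attn_qk + attn_pv + out_proj
print(f"per-token FLOPs without cache: {total_no_cache:.2e}")  # roughly 2.8e11
print(f"per-token FLOPs with cache:    {total_with_cache:.2e}")  # roughly 5.4e8
```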
- The Q/K/V/O linear transformations (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L221-L234) are done with tensor parallelism (`ColumnParallelLinear` or `RowParallelLinear`), which is introduced in https://arxiv.org/pdf/1909.08053.pdf and explained in https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html and https://www.cnblogs.com/rossiXYZ/p/15871062.html#0x04-rowparallellinear. At a high level, tensor parallelism chunks the original large weight matrices into smaller ones, puts them on different GPUs, and collects the results only when needed, so as to speed up the matrix operations. See the sketch below.
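The sketch below illustrates the column/row split in a single process (pure PyTorch; the shards are entries of a Python list here, whereas the real `ColumnParallelLinear` / `RowParallelLinear` layers, which the Llama2 repo takes from fairscale, place each shard on a different GPU rank and use collective communication such as all-gather / all-reduce instead of `torch.cat` / `sum`):

```python
import torch

def column_parallel_linear(x: torch.Tensor, weight: torch.Tensor, n_shards: int) -> torch.Tensor:
    # Column parallelism: split the weight's OUTPUT dimension across shards.
    # Each shard produces a slice of the output; concatenating the slices
    # recovers the full result ("all-gather" in the multi-GPU setting).
    shards = weight.chunk(n_shards, dim=1)            # each (in_dim, out_dim / n_shards)
    partial = [x @ w for w in shards]                 # each (..., out_dim / n_shards)
    return torch.cat(partial, dim=-1)

def row_parallel_linear(x: torch.Tensor, weight: torch.Tensor, n_shards: int) -> torch.Tensor:
    # Row parallelism: split the weight's INPUT dimension across shards.
    # Each shard sees only a slice of the input features; summing the partial
    # products recovers the full result ("all-reduce" in the multi-GPU setting).
    w_shards = weight.chunk(n_shards, dim=0)          # each (in_dim / n_shards, out_dim)
    x_shards = x.chunk(n_shards, dim=-1)              # each (..., in_dim / n_shards)
    return sum(xs @ ws for xs, ws in zip(x_shards, w_shards))

# Sanity check: both splits reproduce the plain matmul.
x = torch.randn(2, 8, 16)
w = torch.randn(16, 32)
assert torch.allclose(column_parallel_linear(x, w, 4), x @ w, atol=1e-4)
assert torch.allclose(row_parallel_linear(x, w, 4), x @ w, atol=1e-4)
```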