Llama code anatomy

This is the first time I have read the Llama 2 code. Much of it is still similar to the original Transformer code, but there are also some new pieces. I am documenting my findings here.

Where is the Llama 2 code?

Modeling (training) code is hosted here: https://github.com/facebookresearch/llama/blob/main/llama/model.py

Inference code is hosted here: https://github.com/facebookresearch/llama/blob/main/llama/generation.py

Annotations

There are two online annotations of the Llama 2 code that I find useful:

  1. [Ref1] Deciphering the LLAMA2 Code: Unraveling the Secrets of a Language AI Marvel: https://www.linkedin.com/pulse/deciphering-llama2-code-unraveling-secrets-language-ai-ayoub-kirouane/
  2. [Ref2] The Annotated LLaMA: https://medium.com/@nishantbhansali80/the-annotated-llama-fa183943b34b

While the two annotations are useful, I still needed some external references for the parts I didn't understand:

  1. precompute_freqs_cis is the function that precomputes the rotary position embedding (RoPE) rotation factors as complex numbers (cis θ = cos θ + i·sin θ). Ref2 has a better explanation than Ref1. A minimal sketch of the function is given after this list (first sketch).

  2. The K/V cache (self.cache_k and self.cache_v in the Attention class) is only meaningful during inference (next-token prediction). First, we need a mental model of how inference works. Suppose we start inference with a batch of prompts (all of the same length, for simplicity). The model consumes the batch of prompts and generates the first new token for each prompt. Each following token is then generated by consuming the prompt plus all previously generated tokens. Without a K/V cache, K and V would be recomputed for the previously generated sequence at every step.
    The K/V cache eliminates this recomputation. With it, self.cache_k and self.cache_v store the K/V of everything the current batch has processed so far, and the K/V of the full preceding sequence is simply fetched from the cache (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L276-L283).
    To see this in the code, note that the forward function of Attention accepts start_pos as an argument (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L265). After the first forward pass, which consumes the prompts, each following pass consumes only the single tokens generated in the previous step. Therefore, start_pos increases by 1 each step and every following pass's seq_len becomes 1. See the second and third sketches after this list, and the Llama 2 inference code (https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L162-L212), for how the model is actually called at inference time.
    A side note is that the K/V cache reduces FLOPs but does not reduce the overall decoding time complexity. Here is a table (similar to this post's table) showing the FLOPs of each sub-step when predicting one new token:

    | Sub-step | w/o K/V cache (x has shape batch_size * seq_len * hidden_dim) | with K/V cache (x has shape batch_size * 1 * hidden_dim) | Note |
    |---|---|---|---|
    | Convert x to K/V by xW_K, xW_V | O(batch_size * seq_len * hidden_dim * hidden_dim) | O(batch_size * 1 * hidden_dim * hidden_dim) | K/V cache only saves this part's FLOPs |
    | Convert x[:, -1] to q by x[:, -1]W_Q | O(batch_size * hidden_dim * hidden_dim) | O(batch_size * hidden_dim * hidden_dim) | |
    | p = softmax(qK^T / sqrt(d)) | O(batch_size * seq_len * hidden_dim) | O(batch_size * seq_len * hidden_dim) | This step still scales with seq_len, so the overall decoding time complexity is not reduced |
    | a = pV | O(batch_size * seq_len * hidden_dim) | O(batch_size * seq_len * hidden_dim) | |
    | Convert a to output aW_O | O(batch_size * hidden_dim * hidden_dim) | O(batch_size * hidden_dim * hidden_dim) | |
     
  3. The Q/K/V/O linear transformations (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L221-L234) are done with tensor parallelism (ColumnParallelLinear or RowParallelLinear), which was introduced in https://arxiv.org/pdf/1909.08053.pdf and is explained in https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html and https://www.cnblogs.com/rossiXYZ/p/15871062.html#0x04-rowparallellinear. At a high level, tensor parallelism chunks a large weight matrix into smaller ones, puts them on different GPUs, and gathers (or reduces) the results only when needed, so that the matrix multiplication is spread across devices. A toy illustration is given in the last sketch after this list.
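
For item 1, here is a minimal sketch of what precompute_freqs_cis computes and how its output is applied, lightly simplified from model.py (the shape comments assume xq/xk of shape (bsz, seq_len, n_heads, head_dim)); it is a sketch for intuition, not a drop-in replacement:

    import torch

    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
        # One frequency per pair of dimensions: theta^(-2i/dim), i = 0 .. dim/2 - 1
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(end, dtype=torch.float32)     # positions 0 .. end-1
        angles = torch.outer(t, freqs)                 # (end, dim/2): position * frequency
        # cis(x) = cos(x) + i*sin(x); torch.polar(1, x) builds that unit complex number
        return torch.polar(torch.ones_like(angles), angles)  # complex64, shape (end, dim/2)

    def apply_rotary_emb(xq, xk, freqs_cis):
        # View the last dimension as dim/2 complex pairs and rotate each pair
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # freqs_cis is the slice for the current positions, shape (seq_len, head_dim/2);
        # reshape it so it broadcasts over batch and heads
        freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)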

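For item 2, the second sketch shows how self.cache_k / self.cache_v interact with start_pos inside a forward pass. It is a simplified paraphrase of the Attention class, assuming a single head and omitting rotary embeddings, grouped-query head repetition, and the causal mask used on the prompt pass:

    import torch

    class CachedAttention(torch.nn.Module):
        # Single-head attention with a K/V cache; max_batch_size and max_seq_len
        # are assumed to be known up front, as in the Llama code.
        def __init__(self, dim, max_batch_size, max_seq_len):
            super().__init__()
            self.wq = torch.nn.Linear(dim, dim, bias=False)
            self.wk = torch.nn.Linear(dim, dim, bias=False)
            self.wv = torch.nn.Linear(dim, dim, bias=False)
            self.wo = torch.nn.Linear(dim, dim, bias=False)
            self.cache_k = torch.zeros(max_batch_size, max_seq_len, dim)
            self.cache_v = torch.zeros(max_batch_size, max_seq_len, dim)

        def forward(self, x, start_pos):
            # x: (bsz, seqlen, dim); seqlen is the prompt length on the first call
            # and 1 on every later call.
            bsz, seqlen, dim = x.shape
            q, k, v = self.wq(x), self.wk(x), self.wv(x)
            # Write only the new positions into the cache ...
            self.cache_k[:bsz, start_pos:start_pos + seqlen] = k
            self.cache_v[:bsz, start_pos:start_pos + seqlen] = v
            # ... and read back everything computed so far.
            keys = self.cache_k[:bsz, :start_pos + seqlen]
            values = self.cache_v[:bsz, :start_pos + seqlen]
            scores = q @ keys.transpose(1, 2) / dim ** 0.5   # (bsz, seqlen, cached_len)
            probs = torch.softmax(scores, dim=-1)
            return self.wo(probs @ values)

The key point is that only the seqlen new positions are projected to K/V and written, while attention always reads the full prefix [0, start_pos + seqlen) back from the cache.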
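
Still for item 2, the third sketch is a minimal decode loop that drives start_pos. It assumes a model(tokens, start_pos) callable that returns logits of shape (bsz, seqlen, vocab_size) and does greedy decoding only; the real loop in generation.py additionally handles unequal prompt lengths, temperature, and top-p sampling:

    import torch

    @torch.inference_mode()
    def greedy_generate(model, prompt_tokens, max_gen_len):
        # prompt_tokens: (bsz, prompt_len) LongTensor; all prompts padded to the
        # same length for simplicity.
        bsz, prompt_len = prompt_tokens.shape
        tokens = prompt_tokens

        # First pass: feed the whole prompt; K/V for positions [0, prompt_len)
        # get written into the cache.
        logits = model(tokens, start_pos=0)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # (bsz, 1)
        tokens = torch.cat([tokens, next_token], dim=1)

        # Later passes: feed only the newly generated token; start_pos tells the
        # attention layers where to write it in cache_k / cache_v.
        for cur_pos in range(prompt_len + 1, prompt_len + max_gen_len):
            logits = model(next_token, start_pos=cur_pos - 1)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens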
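
For item 3, the last sketch is a toy demo of why splitting a weight matrix by columns or by rows gives the same result as the full matrix multiply. Everything runs on one device with manual chunking; the real ColumnParallelLinear / RowParallelLinear (imported from fairscale in model.py) place the shards on different GPUs and add the gather / all-reduce collectives:

    import torch

    # Toy demo: y = x @ W can be assembled from shards of W.
    torch.manual_seed(0)
    x = torch.randn(2, 8)   # (batch, in_dim)
    W = torch.randn(8, 6)   # (in_dim, out_dim)

    # Column parallelism: split W along the output dimension.
    # Each "GPU" computes a slice of the output; slices are concatenated.
    W_cols = W.chunk(2, dim=1)                         # two shards of shape (8, 3)
    y_col = torch.cat([x @ shard for shard in W_cols], dim=1)

    # Row parallelism: split W along the input dimension (and x to match).
    # Each "GPU" computes a partial output; partials are summed (all-reduce).
    W_rows = W.chunk(2, dim=0)                         # two shards of shape (4, 6)
    x_parts = x.chunk(2, dim=1)
    y_row = sum(xp @ wr for xp, wr in zip(x_parts, W_rows))

    assert torch.allclose(y_col, x @ W, atol=1e-5)
    assert torch.allclose(y_row, x @ W, atol=1e-5)

In model.py, wq/wk/wv are ColumnParallelLinear and wo is RowParallelLinear, so each GPU attends over its own subset of heads and the per-GPU outputs are combined by a single all-reduce at the output projection.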