New Model Architectures

There have been many recent advancements in model architectures in the AI domain. In this post, I give an overview of a few of them.

 

Linear Compression Embedding

LCE [1] simply uses a learned matrix to project one embedding matrix to another: AW = B, where W \in \mathbb{R}^{M \times N}, A \in \mathbb{R}^{B \times M}, and B \in \mathbb{R}^{B \times N}.

Related architectures in this space include pyramid networks, Inception networks, DHEN, and LCE itself.
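To make the idea concrete, here is a minimal PyTorch sketch of the projection AW = B; the class name and dimensions are my own illustrative choices, not the implementation referenced in [1].

import torch
import torch.nn as nn

class LinearCompressionEmbedding(nn.Module):
    """Illustrative sketch: project a (B x M) embedding matrix A to a (B x N) matrix via AW = B."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # W in R^{M x N}, learned jointly with the rest of the model
        self.W = nn.Parameter(torch.randn(in_dim, out_dim) * in_dim ** -0.5)

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        # A: (batch_size, M) -> (batch_size, N)
        return A @ self.W

# usage
lce = LinearCompressionEmbedding(in_dim=256, out_dim=64)
A = torch.randn(32, 256)   # a batch of 32 input embeddings
B = lce(A)                 # compressed embeddings of shape (32, 64)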

 

Perceiver and Perceiver IO

Perceiver-based architectures [5, 6] solve a pain point of traditional transformers: computation that is quadratic in sequence length when computing Q/K/V attention, where Q/K/V all have shape batch_size x seq_len x emb_dim. Perceiver-based architectures instead create a learnable latent array whose size is independent of the input sequence length; cross-attention between this latent array (as queries) and the input (as keys/values) reduces the attention computation to linear in the sequence length. Thanks to the linear attention cost, they can also stack more attention blocks within the same computation budget to improve performance. There are three web tutorials that are helpful: [2, 3, 4].
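As a rough illustration of why the cost becomes linear, here is a minimal PyTorch sketch of Perceiver-style cross-attention. It is a simplification of the papers: the latent size, head count, and dimensions are illustrative, and the real models stack many such blocks plus latent self-attention.

import torch
import torch.nn as nn

class PerceiverCrossAttention(nn.Module):
    """Sketch of Perceiver-style cross-attention: a fixed-size learnable latent array
    queries the (arbitrarily long) input, so the cost is O(num_latents * seq_len)
    instead of O(seq_len^2)."""
    def __init__(self, emb_dim: int, num_latents: int = 64):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, emb_dim))
        self.attn = nn.MultiheadAttention(emb_dim, num_heads=4, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, emb_dim) -> (batch_size, num_latents, emb_dim)
        q = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)  # queries come from the latents
        out, _ = self.attn(query=q, key=x, value=x)               # keys/values come from the input
        return out

# usage: the input sequence length can vary, while the latent size stays fixed
layer = PerceiverCrossAttention(emb_dim=128, num_latents=64)
x = torch.randn(2, 10_000, 128)
z = layer(x)  # (2, 64, 128)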

Since Perceiver has done a good job of reducing attention computation to linear time (in terms of sequence length), I was wondering why recent LLMs have not, as far as I can tell, adopted this architecture. A Reddit discussion sheds some light on this [7].

 

Mamba

Mamba [9] is a new sequence modeling architecture that also achieves linear time, based on the so-called “selective structured state space model”. It currently achieves SOTA on many tasks. The official implementation can be found here. There is not much material that clearly articulates the theory behind Mamba; I have only found pieces of information here and there [8, 10]. The way I understand Mamba after watching [8] is:

  1. Mamba enjoys the best of both worlds of RNNs and Transformers:
     RNN: training is O(N) but cannot be parallelized because states must be rolled out sequentially; inference is O(1) per token.
     Transformer: training is O(N^2) but can be parallelized; inference is O(N) per token (see the time analysis in [12]).
     Mamba: training is O(N) and can be parallelized; inference is O(1) per token.
  2. According to the basic definition of structured state space sequence (S4) models [10], Mamba represents the state transition as differential equations:
     h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)
     y(t) = \mathbf{C}h(t)
     Note that h'(t) is the change (i.e., the derivative) of the hidden state.
  3. The differential equations describe a continuous-time system, but in sequence modeling we are interested in sequences with discrete time steps (i.e., tokens). “Discretization” in the paper means using approximation methods to derive new matrices \bar{\mathbf{A}} and \bar{\mathbf{B}} when the input function x(t) is replaced by the sequence \{x(t\Delta)\}_{t=0,1,\cdots}, where \Delta is the step size.
     There are two methods to do the discretization. [8] introduces Euler’s method, which is simple for exposition; the paper uses the zero-order hold, which is explained in more detail in [13]. Following [8], we use Euler’s method for illustration:
     1. By the definition of the derivative, we know that h(t+\Delta) \approx \Delta h'(t) + h(t).
     2. Substitute the state equation for h'(t) into step 1:
        h(t+\Delta) \approx \Delta\left(\mathbf{A}h(t) + \mathbf{B}x(t)\right) + h(t)
                    = \Delta\mathbf{A}h(t) + \Delta\mathbf{B}x(t) + h(t)
                    = \left(\mathbf{I} + \Delta\mathbf{A}\right)h(t) + \Delta\mathbf{B}x(t)
                    = \bar{\mathbf{A}}h(t) + \bar{\mathbf{B}}x(t),
        which gives \bar{\mathbf{A}} = \mathbf{I} + \Delta\mathbf{A} and \bar{\mathbf{B}} = \Delta\mathbf{B}.
  4. State transitions during training can be parallelized without waiting for a sequential rollout, because the state transition model can be converted into a convolutional formulation that depends only on the input sequence: y = x * \bar{\mathbf{K}} with kernel \bar{\mathbf{K}} = \left(\mathbf{C}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}, \cdots, \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}}\right), where L is the sequence length (see the numerical sketch after this list).
  5. An S4 model is used for each input dimension. So if a token has an embedding size of 512, there will be 512 S4 models running in parallel. In comparison, transformers use attention heads, where each attention head consumes all dimensions (or a group of dimensions) of a token.
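To make points 3 and 4 concrete, here is a small NumPy sketch (with made-up \mathbf{A}, \mathbf{B}, \mathbf{C}, and \Delta rather than Mamba’s learned, input-dependent parameters) that applies the Euler discretization above and checks that the recurrent rollout and the convolutional formulation produce the same outputs.

import numpy as np

# Toy 1-D S4-style SSM (one input channel), using the Euler discretization above.
state_dim, seq_len, delta = 4, 8, 0.1
rng = np.random.default_rng(0)
A = -np.eye(state_dim)                    # state transition (kept stable for the demo)
B = rng.normal(size=(state_dim, 1))
C = rng.normal(size=(1, state_dim))
x = rng.normal(size=seq_len)              # one scalar input per "token"

# Discretize: A_bar = I + delta*A, B_bar = delta*B (Euler's method)
A_bar = np.eye(state_dim) + delta * A
B_bar = delta * B

# 1) Recurrent view: roll the state out token by token (what inference does)
h = np.zeros((state_dim, 1))
y_recurrent = []
for t in range(seq_len):
    h = A_bar @ h + B_bar * x[t]
    y_recurrent.append((C @ h).item())

# 2) Convolutional view: y = x * K_bar with kernel K_bar[j] = C A_bar^j B_bar
#    (what lets training run in parallel over the whole sequence)
K_bar = np.array([(C @ np.linalg.matrix_power(A_bar, j) @ B_bar).item()
                  for j in range(seq_len)])
y_conv = [sum(K_bar[j] * x[t - j] for j in range(t + 1)) for t in range(seq_len)]

print(np.allclose(y_recurrent, y_conv))   # True: both views give the same outputs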

 

References

  1. LCE explanation: https://fb.workplace.com/groups/watchrankingexperiments/permalink/2853044271454166/ (internal) and https://www.dropbox.com/scl/fi/0ygejam1mg2jhocfmgoba/watch-_-chaining-ranking-experiment-review-_-group-_-workplace.pdf?rlkey=4euhhg6f0h838ythg1na1amz8&dl=0 (external)
  2. https://medium.com/@curttigges/the-annotated-perceiver-74752113eefb
  3. https://medium.com/ml-summaries/perceiver-paper-s-summary-3c589ec74238
  4. https://medium.com/ml-summaries/perceiver-io-paper-summary-e8f28e451d21
  5. Perceiver: General Perception with Iterative Attention: https://arxiv.org/abs/2103.03206
  6. Perceiver IO: A General Architecture for Structured Inputs & Outputs: https://arxiv.org/abs/2107.14795
  7. https://www.reddit.com/r/MachineLearning/comments/tx7e34/d_why_arent_new_llms_using_the_perceiver/
  8. https://www.youtube.com/watch?v=8Q_tqwpTpVU
  9. Mamba: Linear-Time Sequence Modeling with Selective State Spaces: https://arxiv.org/abs/2312.00752
  10. https://srush.github.io/annotated-s4/?ref=blog.oxen.ai
  11. https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad
  12. Llama code anatomy: https://czxttkl.com/2024/01/17/llama-code-anatomy/
  13. https://en.wikipedia.org/wiki/Discretization#Derivation

 
