Learn GPU Optimization

It has been a while since I last studied GPU programming. I will keep updating this post with recent materials as I ramp my GPU knowledge back up.

FlashAttention [1]

We start by recapping the standard self-attention mechanism, which is computed in three passes:
Q, K, V \in \mathbb{R}^{N \times d}
S = QK^T \;\;\quad shape=(N \times N)
P=softmax(S)\;\;\quad shape=(N \times N)
O=PV\;\;\quad shape=(N \times d)

Notes:

  1. The shapes N and d represent the sequence length and the internal dimension, respectively. Q=W_q X, K=W_k X, V=W_v X, which transform the token embeddings X into a projected latent space.
  2. P is the attention matrix. P_{i,j} represents how much attention token i should pay to token j. In typical LLM tasks, we apply a causal mask so that only P_{i,j} \;(i \geq j) is valid, because a token can only attend to itself and the previous tokens. The softmax() operator is applied per row of S.
  3. The output O contains, for each position, the weighted values gathered from the other tokens. O is then fed into the feed-forward layers to be transformed into the output space (from shape (N\times d) to (N \times {vocab\_size})). We omit the part after self-attention.
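The three-pass computation above can be sketched with NumPy. This is a minimal illustration for clarity only; like the formulas above, it omits the usual 1/\sqrt{d} scaling and any batching or multi-head structure:

```python
import numpy as np

def softmax(x, axis=-1):
    # Safe softmax: subtract the row-wise max to avoid overflow
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V, causal=True):
    """Pass 1: S = Q K^T; pass 2: P = softmax(S); pass 3: O = P V."""
    N = Q.shape[0]
    S = Q @ K.T                                                # (N, N)
    if causal:
        # Token i may only attend to tokens j <= i
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)
        S = np.where(mask, -np.inf, S)
    P = softmax(S, axis=-1)                                    # (N, N)
    return P @ V                                               # (N, d)
```

With the causal mask, the first row of P is [1, 0, ..., 0], so the first output row equals the first row of V.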

How GPUs work: a GPU has a massive number of threads that execute an operation (called a kernel). Each kernel loads inputs from HBM into registers and SRAM, computes, then writes outputs back to HBM. HBM has larger capacity but is much slower than SRAM. In practice, the number of HBM accesses plays a non-negligible role in GPU execution speed. Standard self-attention requires O(Nd+N^2) HBM accesses, because we need to load Q, K, and V from HBM and to read/write S and P, each of shape (N\times N). A detailed breakdown of standard self-attention HBM accesses is as below (illustrated by Gemini):

Let’s dive into the step of computing softmax, which itself has many details. First, in the real world we usually need to compute a safe version of softmax, where we subtract the maximum input from every input to avoid potential overflow. (Diagrams from [3]) The vanilla implementation needs three passes: the first pass computes the maximum of the input, the second pass computes the denominator of the softmax, and the third pass computes the actual softmax values. Asymptotically these three steps take O(N^2) time, but in reality the constant factor of the N^2 computation also matters. SRAM is typically too small to hold the N^2-sized S=QK^T result, so we either need to write S to HBM in the first pass and load it back in the following two passes, or even recompute QK^T, depending on whichever is faster. That means this three-pass safe softmax algorithm really does cost 3N^2 in computation time / HBM accesses.

It turns out that we can turn the three-pass algorithm into a two-pass one, known as online softmax. As we traverse each x_i, we keep a running maximum and a running accumulative softmax denominator. Whenever a new maximum is found at x_i, the denominator accumulated up to x_{i-1} can easily be rescaled.
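To make this concrete, here is a minimal pure-Python sketch of the two-pass online softmax, plus (looking ahead) the one-pass recurrence that also folds in the output accumulation, which is the core idea behind FlashAttention. These are illustrative re-implementations of mine, not the actual kernels:

```python
import math

def online_softmax_2pass(x):
    # Pass 1: fuse the running maximum m and running denominator d
    m, d = float('-inf'), 0.0
    for xi in x:
        m_new = max(m, xi)
        d = d * math.exp(m - m_new) + math.exp(xi - m_new)  # rescale old denominator
        m = m_new
    # Pass 2: produce the normalized outputs
    return [math.exp(xi - m) / d for xi in x]

def one_pass_attention_row(q, K, V):
    # Fold the output accumulation into the same loop: one pass over the keys
    m, d = float('-inf'), 0.0
    o = [0.0] * len(V[0])
    for k, v in zip(K, V):
        s = sum(qj * kj for qj, kj in zip(q, k))            # score s_j = q . k_j
        m_new = max(m, s)
        scale, p = math.exp(m - m_new), math.exp(s - m_new)
        d = d * scale + p
        o = [oj * scale + p * vj for oj, vj in zip(o, v)]   # rescale running output
        m = m_new
    return [oj / d for oj in o]                             # normalize once at the end
```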
With this two-pass online softmax, self-attention can also be computed in two passes. However, we can do better by observing that \mathbf{o}_i can also be computed “online” together with the running maximum m_i and the running accumulative softmax denominator d_i'. That is how we end up with FlashAttention, which requires only one pass to compute the output vectors! In practice, a further optimization is to load Q/K/V in blocks so that the blocks of Q, K, V, and O together occupy roughly the full SRAM at one time. That is why the block sizes are set to B_c=\lceil \frac{M}{4d} \rceil and B_r=\min\left( \lceil \frac{M}{4d}\rceil, d\right) in the algorithm.

Now we examine how FlashAttention does its backward computation, starting from the standard attention backward pass. (To clarify the notation, dO is the gradient of the loss \mathcal{L} with respect to the attention output matrix, i.e., dO=\frac{\partial \mathcal{L}}{\partial O}, which has the same shape as O; similarly for dQ, dK, and dV. We also assume dO is already available, since backpropagation has already been computed from \mathcal{L} down to O.) Let’s try to understand this standard backward pass:

  1. With O = PV, due to matrix calculus rules, we have dV = P^T dO.
  2. We need to compute and record dP because it will be used in computing dS (and hence dQ and dK). Due to matrix calculus rules, we have dP=dO V^T.
  3. dS = \frac{\partial \mathcal{L}}{\partial S}= \frac{\partial \mathcal{L}}{\partial P} \frac{\partial P}{\partial S}. There is a well-known mathematical result stating that the Jacobian matrix of softmax is J(P_i) = J\left(softmax(S_i)\right)=\left[ \frac{\partial P_i}{\partial s_{i1}}, \cdots, \frac{\partial P_i}{\partial s_{iN}} \right]=diag(P_i) - P_i P_i^T, where P_i is a row of P. With some simplification, we reach d S_i = d P_i^T J(P_i)= P_i \odot dP_i - \left(dP_i^T P_i\right) \cdot P_i. [Note: while P_i is a row of P (and dP_i is a row of dP), it is still treated as a column vector following the convention of linear algebra. So \left(dP_i^T P_i\right) is actually a scalar, an inner product, while P_i P_i^T is a matrix, an outer product, of the two vectors.]
  4. With S=QK^T and matrix calculus rules, we obtain dQ=dSK and dK=dS^T Q.
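The four steps above can be checked numerically. Below is a small NumPy sketch (shapes N=4, d=3 are arbitrary); the simplified row formula for dS_i is cross-checked against the explicit softmax Jacobian diag(P_i) - P_i P_i^T:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 4, 3
Q, K, V = rng.standard_normal((3, N, d))
dO = rng.standard_normal((N, d))                 # assumed given by upstream backprop

S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)               # row-wise softmax

dV = P.T @ dO                                    # step 1
dP = dO @ V.T                                    # step 2
# Step 3, simplified row formula: dS_i = P_i ⊙ dP_i - (dP_i^T P_i) P_i
dS = P * dP - (dP * P).sum(axis=-1, keepdims=True) * P
for i in range(N):                               # cross-check with the full Jacobian
    J = np.diag(P[i]) - np.outer(P[i], P[i])
    assert np.allclose(dS[i], J @ dP[i])
dQ, dK = dS @ K, dS.T @ Q                        # step 4
```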

As we can see, the standard attention backward pass requires loading Q, K, V, dO, dS, P, and dP, and writing dS, dP, dQ, dV, and dK. However, FlashAttention does not store P, so in its backward pass we also need to recompute (blocks of) P on the fly. Moreover, dP_i^T P_i can be precomputed from dO_i and O_i and stored in HBM:

(1)   \begin{align*} dP_i^T P_i &= (dO_i^T V^T) P_i \\ &=dO_i^T (V^T P_i) \\ &=dO_i^T (P_i^T V)^T \\ &=dO_i^T O_i \end{align*}
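Identity (1) can be sanity-checked numerically per row (the shapes below are arbitrary; nothing here is specific to the actual FlashAttention kernel):

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 5, 4
Q, K, V = rng.standard_normal((3, N, d))
dO = rng.standard_normal((N, d))

S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)               # row-wise softmax
O = P @ V
dP = dO @ V.T

lhs = (dP * P).sum(axis=-1)   # dP_i^T P_i for every row i
rhs = (dO * O).sum(axis=-1)   # dO_i^T O_i for every row i
assert np.allclose(lhs, rhs)  # so the scalar can be precomputed from dO and O
```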

Now let’s summarize IO complexity (i.e., HBM accesses) for standard attention and FlashAttention.

  1. Forward pass:
    1. standard attention: O(Nd+N^2) as discussed above
    2. FlashAttention: O(N^2 d^2 M^{-1}), because one full inner loop (starting from line 7 in Algorithm 1) needs to load the full Q from HBM (O(Nd)), and the outer loop (line 5 in Algorithm 1) runs T_c = O(Nd M^{-1}) times. Another good reference for the IO complexity analysis is [4]
  2. Backward pass:
    1. standard attention: O(Nd+N^2)
    2. FlashAttention: O(N^2 d^2 M^{-1}); see the analysis in Theorem 5 of [1]

In practice, when will FlashAttention outperform standard attention? Comparing O(N^2 d^2 M^{-1}) against O(Nd+N^2), this is when d^2 is much smaller than M. First of all, we should clarify that M represents the on-chip SRAM (shared memory / L1 cache) of one streaming multiprocessor (SM) in a GPU. Taking A100 as an example, one SM has 192KB of L1 cache. So FlashAttention will probably have an IO-complexity advantage when d = 64, 128, or 256, but no advantage when d \geq 512.

Let’s also summarize the total memory footprint required by standard attention vs. FlashAttention:

  1. standard attention: it needs to store the K, V, and O matrices, which takes O(Nd) space. However, the most costly part comes from storing P and S, which takes O(N^2) space. Overall, the memory footprint is O(Nd+N^2).
  2. FlashAttention: it still needs to store the K, V, and O matrices, which takes O(Nd) space. But it does not store P and S; instead it maintains the running maximum values and running accumulative softmax denominators in the forward pass so that they can be reused in the backward pass, which takes O(N) space. Overall, the memory footprint of FlashAttention is only O(Nd).
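This back-of-the-envelope comparison can be played with numerically. The sketch below is mine, not from [1]: it drops all constant factors and assumes an A100-like 192 KB of SRAM per SM holding fp16 elements (2 bytes each):

```python
import math

def io_ratio(N, d, M):
    """Ratio of FlashAttention to standard-attention HBM accesses, constants dropped."""
    standard_io = N * d + N * N          # O(Nd + N^2)
    flash_io = N * N * d * d / M         # O(N^2 d^2 / M)
    return flash_io / standard_io

M = 192 * 1024 // 2                      # assumed SRAM size in fp16 elements
N = 4096                                 # an arbitrary sequence length
for d in (64, 128, 256, 512):
    B_c = math.ceil(M / (4 * d))
    B_r = min(B_c, d)
    print(f"d={d}: B_c={B_c}, B_r={B_r}, IO ratio ~ {io_ratio(N, d, M):.2f}")
```

Under these assumptions the ratio stays below 1 for d = 64, 128, and 256 and crosses above 1 at d = 512, matching the rough cutoff discussed above.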

Now let’s move on to an engineering test. I picked one simple online implementation of FlashAttention and compared it with PyTorch’s scaled dot-product attention (with the memory-efficient backend) and a naive implementation of attention. Note that FlashAttention can be thought of as a combination of reducing memory footprint and reducing HBM accesses; the PyTorch memory-efficient backend essentially only reduces the memory footprint, not HBM accesses. The comparison script is here (python benchmark_flash_triton.py).

The comparison result is shown below; the numbers are measured TFLOPs. We test both forward and backward passes under the setting num_heads = 48 and num_dim = 64. When an OOM arises, we mark the result as -1. We can see that the FlashAttention version clearly achieves higher FLOPs and can handle much longer sequences than the other two versions.

As a technique that reduces both memory footprint and HBM accesses, FlashAttention is useful in training and in the pre-filling stage of inference, both of which require computing the attention outputs of full given sequences. So FlashAttention can help reduce time-to-first-token (TTFT). However, the actual decoding phase needs more specialized variants (like Flash-Decoding). We will cover that later.

Torch.compile / CudaGraph

CUDA Graphs is an NVIDIA optimization designed to reduce CPU-to-GPU launch overhead: sometimes your computation speed is bottlenecked by how fast the CPU can send kernel launch commands to the GPU. The CUDA Graphs technique first captures the kernel launch sequence in a graph structure; then, in the replay phase, the CPU can send one single command to the GPU to execute the whole graph, which contains all recorded kernel launches. Torch.compile is a set of optimization techniques that can be automatically applied to Python+PyTorch functions. It records the Python code logic in computation graphs and then relies on TorchInductor to conduct optimizations:

  1. fuse operations/kernels
  2. auto-tune kernel configurations like block sizes
  3. Choose different backends for matmul and perform prologue and epilogue fusion (TODO: understand these fusions)
  4. use CUDA Graphs to cache and replay kernel launches efficiently

As you can see, torch.compile is a superset of optimization techniques like CUDA Graphs. Below, we show a very basic example of how torch.compile accelerates a function. We test two modes of torch.compile, “max-autotune” and “max-autotune-no-cudagraphs”. While max-autotune gives torch.compile full autonomy to compile your function, it may sometimes over-engineer: for example, applying CUDA Graphs can be overkill for small functions like the one in the example below.

import torch
import time
import torch._dynamo as dynamo

# Ensure you have a CUDA-enabled GPU for the best results
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# 1. Define a simple function with a loop
def my_function(x):
    for _ in range(10):
        # A series of simple operations
        x = torch.sin(x) + torch.cos(x)
    return x

# 2. Create some input data
input_tensor = torch.randn(1024, 1024, device=device)

# 3. Compile the function
# The 'mode' can be tuned for different trade-offs. 'max-autotune' is aggressive.
compiled_function = torch.compile(my_function, mode="max-autotune")

# 4. Compile the function without cuda graph
# The 'mode' can be tuned for different trade-offs. 'max-autotune' is aggressive.
compiled_function_no_cuda_graph = torch.compile(my_function, mode="max-autotune-no-cudagraphs")

# --- Benchmarking ---

# WARM-UP RUNS: The first run has a one-time compilation overhead.
# We run each version once to get this out of the way for fair timing.
print("\nStarting warm-up runs...")
_ = my_function(input_tensor)
_ = compiled_function(input_tensor)
_ = compiled_function_no_cuda_graph(input_tensor)
if device == 'cuda':
    torch.cuda.synchronize()
print("Warm-up complete.")


# Time the original EAGER mode function
print("\nBenchmarking Eager Mode...")
start_time = time.perf_counter()
for _ in range(100):
    _ = my_function(input_tensor)
# Make sure all GPU operations are finished before stopping the timer
if device == 'cuda':
    torch.cuda.synchronize()
eager_time = time.perf_counter() - start_time


# Time the COMPILED function
print("Benchmarking Compiled Mode...")
start_time = time.perf_counter()
for _ in range(100):
    _ = compiled_function(input_tensor)
if device == 'cuda':
    torch.cuda.synchronize()
compiled_time = time.perf_counter() - start_time


# Time the COMPILED function with no cuda graph
print("Benchmarking Compiled Mode with no Cuda Graphs...")
start_time = time.perf_counter()
for _ in range(100):
    _ = compiled_function_no_cuda_graph(input_tensor)
if device == 'cuda':
    torch.cuda.synchronize()
compiled_time_no_cuda_graph = time.perf_counter() - start_time


# --- Results ---
print("\n--- Results ---")
print(f"Eager mode total time:    {eager_time:.4f} seconds")
print(f"Compiled mode total time: {compiled_time:.4f} seconds")
print(f"Compiled mode with no cuda graph total time: {compiled_time_no_cuda_graph:.4f} seconds")
print(f"\nSpeedup: {eager_time / compiled_time:.2f}x faster! 🎉")
print(f"Speedup with no cuda graph: {eager_time / compiled_time_no_cuda_graph:.2f}x faster! 🎉")


# Use explain to see how Dynamo will handle the function
# It will print out a detailed report
explanation = dynamo.explain(my_function)(input_tensor)
print("\n" + "="*50)
print("Dynamo Explanation Output Summary:")
print(explanation)

 
Result (consistent with the finding in [6] that enabling CUDA Graphs actually hurts performance for small functions):
Eager mode total time:    0.0157 seconds
Compiled mode total time: 0.0856 seconds
Compiled mode with no cuda graph total time: 0.0025 seconds

Speedup: 0.18x faster! 
Speedup with no cuda graph: 6.32x faster! 

 
 
Other learning resources:

Stanford CS336 course: https://stanford-cs336.github.io/spring2025/ https://www.youtube.com/playlist?list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_

Learn Triton: https://www.google.com/search?q=book+learn+triton

Different parallelisms explained: https://ailzhang.github.io/posts/distributed-compute-in-transformer/
 

References

[1] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness: https://arxiv.org/abs/2205.14135
[2] [Attention Optimization] Principles: From Online-Softmax to FlashAttention V1/V2/V3: https://zhuanlan.zhihu.com/p/668888063
[4] "How does FlashAttention achieve its speedup?" – Civ's answer – Zhihu: https://www.zhihu.com/question/611236756/answer/3132304304
