Links
- Lecture video: https://youtu.be/6OBtO9niT00
- Course materials: lecture 5.pdf
Overview: From Magic to Mechanics
This lecture demystifies the GPU, moving beyond its perception as a "magic box" for deep learning to explain the underlying hardware architecture and programming principles that enable its massive performance. The central theme is that compute power has scaled far more rapidly than memory bandwidth, creating a "memory wall." The primary goal of high-performance GPU programming is therefore to minimize data movement and maximize the utilization of the GPU's powerful compute units. The lecture explains how to achieve this through techniques like tiling, fusion, and recomputation, culminating in a deep dive into FlashAttention, an algorithm that masterfully applies these principles to make the core mechanism of Transformers memory-efficient.
Part 1: Anatomy of a GPU
A GPU is fundamentally different from a CPU. While a CPU is optimized for low latency (executing single threads as fast as possible), a GPU is optimized for high throughput (executing thousands of threads in parallel).
Core Architectural Components
- Streaming Multiprocessors (SMs): A GPU is composed of many SMs (e.g., an A100 has 108). Each SM is an independent core that can execute its own instructions.
- CUDA Cores / Streaming Processors (SPs): Within each SM are many smaller processing units (ALUs) that perform the actual calculations.
- Tensor Cores: Specialized hardware units introduced in the Volta architecture (and enhanced since) that are designed exclusively to accelerate matrix multiplications, particularly in mixed precision (e.g., multiplying two `bfloat16` matrices and accumulating the result in `float32`). These cores provide a >10x speedup for `matmul` operations compared to standard floating-point operations. A minimal PyTorch illustration of this mixed-precision pattern follows this list.
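As a small illustration of the mixed-precision pattern above, here is a hedged PyTorch sketch (it assumes a CUDA GPU with tensor cores; the matrix sizes are arbitrary). On supported hardware, a `bfloat16` matmul dispatches to tensor-core kernels, which accumulate in `float32` internally even though the returned tensor is `bfloat16`.

```python
import torch

# Assumes a CUDA GPU with tensor cores; sizes are arbitrary.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

c = a @ b            # tensor-core matmul; output dtype is bfloat16
c_fp32 = c.float()   # upcast if downstream code expects float32
```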
The GPU Memory Hierarchy
The key to GPU performance is understanding its memory hierarchy. Accessing memory is the biggest bottleneck, and latency increases dramatically as you move further from the compute units.
- Registers (Fastest): Per-thread, on-chip memory. Extremely fast, but very small.
- Shared Memory / L1 Cache (SRAM): Per-SM, on-chip memory. Much faster than global memory and can be shared between threads within the same block. This is the key to performance optimization.
- L2 Cache (SRAM): A larger on-chip cache shared by all SMs.
- Global Memory (DRAM/HBM): The main GPU memory (e.g., 80GB on an H100). It's large but has significantly higher latency and lower bandwidth than on-chip memory. All communication between different thread blocks must go through global memory.
The CUDA Execution Model
- Threads: The most basic unit of execution. All threads run the same code, a model known as SIMT (Single Instruction, Multiple Thread).
- Warps: Threads are executed in groups of 32 called a warp. All 32 threads in a warp execute the same instruction at the same time. If a conditional statement causes threads to take different paths (warp divergence), the paths are executed serially, killing parallelism and performance.
- Blocks: Threads are grouped into blocks. All threads within a block are executed on the same SM and can communicate via its fast shared memory.
- Grid: A grid is composed of all the blocks for a given operation (a "kernel").
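To make the grid/block/thread decomposition above concrete, here is a plain-Python sketch (not real CUDA; the function name `launch_saxpy` and the numbers are made up) that mimics how a 1-D kernel launch assigns one array element to each thread using the standard `blockIdx * blockDim + threadIdx` index calculation.

```python
import math

def launch_saxpy(n, a, x, y, block_dim=256):
    """Simulate a 1-D kernel launch computing out[i] = a * x[i] + y[i]."""
    grid_dim = math.ceil(n / block_dim)          # number of blocks in the grid
    out = [0.0] * n
    for block_idx in range(grid_dim):            # blocks are scheduled onto SMs
        for thread_idx in range(block_dim):      # threads execute in warps of 32
            i = block_idx * block_dim + thread_idx
            if i < n:                            # bounds check, as in a real kernel
                out[i] = a * x[i] + y[i]
    return out

print(launch_saxpy(5, 2.0, [1, 2, 3, 4, 5], [10, 10, 10, 10, 10]))
# [12.0, 14.0, 16.0, 18.0, 20.0]
```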
Part 2: Principles of High-Performance GPU Programming
The massive gap between compute speed and memory speed is visualized by the Roofline Model. An operation is either memory-bound (limited by memory bandwidth) or compute-bound (limited by the GPU's peak FLOP/s). The goal is to increase operational intensity (the ratio of FLOPs to bytes of data moved) to push an operation into the compute-bound region.
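The arithmetic behind the roofline is simple enough to show directly. The sketch below uses approximate A100 peak numbers (rough figures of my own, not taken from the lecture slides) to compare the machine balance against the operational intensity of a `bfloat16` matmul and a `float32` elementwise op.

```python
# Approximate A100 peak figures, used only to illustrate the arithmetic.
PEAK_FLOPS = 312e12   # ~312 TFLOP/s bf16 tensor-core peak (approx.)
PEAK_BW    = 1.5e12   # ~1.5 TB/s HBM bandwidth (approx.)
machine_balance = PEAK_FLOPS / PEAK_BW   # ~208 FLOPs per byte moved

def matmul_intensity(n, bytes_per_elem=2):
    """Operational intensity of an n x n x n matmul in bf16."""
    flops = 2 * n**3                         # multiply-adds
    bytes_moved = 3 * n**2 * bytes_per_elem  # read A, read B, write C
    return flops / bytes_moved               # = n / 3 for bf16

def elementwise_intensity(bytes_per_elem=4):
    """Operational intensity of y = x + 1 in fp32: one FLOP per 8 bytes."""
    return 1 / (2 * bytes_per_elem)

print(machine_balance)          # intensity above this -> compute-bound
print(matmul_intensity(4096))   # ~1365 FLOPs/byte -> compute-bound
print(elementwise_intensity())  # 0.125 FLOPs/byte -> badly memory-bound
```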
The following techniques are used to achieve this:
- Operator Fusion: Combining multiple sequential, element-wise operations (e.g., `LayerNorm -> Dropout -> Add`) into a single GPU kernel. This avoids writing intermediate results to slow global memory, keeping all data in fast registers and on-chip memory. Compilers like PyTorch's `torch.compile` can do this automatically (see the fusion sketch after this list).
- Recomputation (Activation Checkpointing): Instead of storing all activations during the forward pass (which consumes vast amounts of memory), this technique stores only a fraction of them. During the backward pass, the missing activations are recomputed on-the-fly. This trades compute (which is abundant) for memory (which is scarce); see the checkpointing sketch after this list.
- Memory Coalescing: Structuring memory access patterns so that all 32 threads in a warp access a contiguous block of global memory. This allows the hardware to satisfy the 32 individual requests with a single, wide memory transaction, effectively maximizing memory bandwidth.
- Tiling (The Most Important Technique): This is the core strategy for making memory-bound operations like matrix multiplication compute-bound. The large input matrices in global memory are broken down into smaller "tiles" that fit entirely within an SM's fast shared memory (SRAM). The computation is then performed tile-by-tile, maximizing data reuse from fast memory and dramatically reducing the number of slow reads/writes to global memory (see the tiled matmul sketch after this list).
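Fusion sketch referenced above: a minimal example, assuming PyTorch 2.x on a CUDA device, of handing a `LayerNorm -> Dropout -> Add` sequence to `torch.compile`, which can fuse the elementwise work into fewer kernels instead of round-tripping every intermediate through global memory.

```python
import torch

def norm_dropout_add(x, residual, weight, bias, p=0.1):
    # Three elementwise-ish ops that would each launch a kernel in eager mode.
    h = torch.nn.functional.layer_norm(x, x.shape[-1:], weight, bias)
    h = torch.nn.functional.dropout(h, p=p, training=True)
    return h + residual

fused = torch.compile(norm_dropout_add)  # compilation/fusion happens on first call

x = torch.randn(8, 1024, 1024, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(1024, device="cuda")
bias = torch.zeros(1024, device="cuda")
out = fused(x, residual, weight, bias)
```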
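Checkpointing sketch referenced above: PyTorch exposes recomputation directly through `torch.utils.checkpoint`. Wrapping a sub-module means its intermediate activations are not saved during the forward pass and are recomputed during the backward pass (a minimal sketch, assuming a CUDA device).

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d), torch.nn.GELU(), torch.nn.Linear(4 * d, d)
        )

    def forward(self, x):
        # Activations inside self.ff are discarded after the forward pass
        # and recomputed during backward; use_reentrant=False is the
        # recommended mode in recent PyTorch versions.
        return checkpoint(self.ff, x, use_reentrant=False)

block = Block(1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)
block(x).sum().backward()   # the feed-forward intermediates are recomputed here
```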
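Tiled matmul sketch referenced above: a plain-PyTorch loop over output tiles (the function name `tiled_matmul` and the tile size are arbitrary). A real kernel would stage each tile in shared memory; the point here is the data-reuse pattern, where each loaded tile of A and B participates in roughly `tile` multiply-adds per element loaded, which is what raises operational intensity.

```python
import torch

def tiled_matmul(A, B, tile=128):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = torch.zeros(M, N, device=A.device, dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            # Accumulator for one output tile (kept "on chip" in a real kernel).
            acc = torch.zeros(min(tile, M - i), min(tile, N - j),
                              device=A.device, dtype=A.dtype)
            for k in range(0, K, tile):
                # Each loaded (tile x tile) block of A and B feeds a full
                # tile-by-tile matmul before being discarded.
                acc += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
            C[i:i+tile, j:j+tile] = acc
    return C

A = torch.randn(512, 512)
B = torch.randn(512, 512)
assert torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-3)
```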
Part 3: Case Study - FlashAttention
FlashAttention is a perfect application of these principles to solve a major bottleneck in the Transformer architecture.
- The Problem with Standard Attention: The standard attention implementation is memory-bound. It materializes the full `N x N` attention score matrix in slow global memory (HBM). For a sequence length of `N = 4096`, this matrix is `4096 x 4096`, and the memory traffic to read and write it dominates the runtime (a naive version is sketched after this list).
- The FlashAttention Solution:
- Tiling: FlashAttention breaks the input Q, K, and V matrices into smaller blocks, or tiles. It loads a block of Q and a block of K into the fast on-chip SRAM of an SM.
- Fusion: It performs the matrix multiplication, masking, and softmax operations entirely within SRAM for that tile, without ever writing the full `N x N` attention matrix to global memory. A key innovation is an "online" softmax algorithm that allows the softmax to be computed correctly in a streaming, block-by-block fashion (see the forward-pass sketch after this list).
- Recomputation: For the backward pass, instead of reading a saved `N x N` attention matrix (which was never stored), it recomputes the necessary parts on-the-fly from the Q, K, and V blocks loaded into SRAM.
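For reference, here is the naive single-head attention referred to in the first bullet, written in plain PyTorch (no masking or dropout; the function name `naive_attention` and the shapes are illustrative) so that the full `N x N` score and probability matrices are explicitly materialized.

```python
import torch

def naive_attention(Q, K, V):
    d = Q.shape[-1]
    S = (Q @ K.T) / d**0.5        # (N, N) scores: O(N^2) memory traffic
    P = torch.softmax(S, dim=-1)  # (N, N) probabilities, also materialized
    return P @ V                  # (N, d) output

N, d = 4096, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = naive_attention(Q, K, V)    # S and P each hold 4096 * 4096 floats
```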
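And here is a simplified single-head sketch of the forward pass described in the Fusion bullet, again in plain PyTorch rather than a fused kernel (no masking, no dropout, Q not tiled; `flash_attention_forward` is my own illustrative name, not a library API). The online-softmax statistics `m` (running row max) and `l` (running denominator) are what let each K/V block be processed and discarded without ever forming the `N x N` matrix.

```python
import torch

def flash_attention_forward(Q, K, V, block_size=256):
    N, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)                  # running (unnormalized) output, (N, d)
    m = torch.full((N, 1), float("-inf"))    # running row-wise max of the scores
    l = torch.zeros(N, 1)                    # running softmax denominator
    for j in range(0, N, block_size):
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        S = (Q @ Kj.T) * scale                                 # (N, block) scores
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                           # rescale old statistics
        P = torch.exp(S - m_new)                               # unnormalized probabilities
        l = alpha * l + P.sum(dim=-1, keepdim=True)
        O = alpha * O + P @ Vj
        m = m_new
    return O / l                                               # final normalization

N, d = 4096, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
ref = torch.softmax((Q @ K.T) / d ** 0.5, dim=-1) @ V          # naive reference
assert torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-4)
```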
By applying tiling, fusion, and recomputation, FlashAttention transforms the attention mechanism from a memory-bound operation into a compute-bound one, resulting in significant speedups (e.g., 7.6x faster runtime on GPT-2) and a massive reduction in memory usage. It is an exact attention algorithm, not an approximation, and has become a standard component in virtually all modern LLM training and inference frameworks.