Links
- Lecture video: https://youtu.be/msHyYioAyNE
- Course materials: lecture_02.py
Overview and Mindset
This lecture dives into the fundamental building blocks of training a neural network using PyTorch. Going beyond just the code, the core theme is developing a mindset of resource accounting. To build and train large models efficiently, it's crucial to understand and quantify the two primary resources consumed at every step: memory and compute. The lecture provides the tools for "napkin math" to estimate training time and hardware requirements.
1. Resource Accounting: The Two Pillars
Pillar 1: Memory (GB)
Memory is required to store everything the model needs during a training step. The total memory footprint is a sum of several components:
- Model Parameters: The weights and biases of the model.
- Gradients: The gradients of the loss with respect to each parameter, required for updates. These are the same size as the parameters.
- Optimizer States: Optimizers like Adam maintain running averages (first and second moments) for each parameter, so the optimizer state alone takes twice as much memory as the parameters when stored in the same data type.
- Activations: The intermediate outputs of each layer, which are saved during the forward pass to be used in the backward pass. This component's size grows with the batch size and sequence length (as well as the model's width and depth).
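As a napkin-math sketch, the function below tallies the persistent memory for fp32 training with Adam. The function name and its defaults are illustrative assumptions: 4 bytes per fp32 parameter, 4 bytes for its gradient, and 8 bytes for Adam's two fp32 moment estimates. Activations are deliberately left out because they depend on batch size, sequence length, and architecture.

```python
def training_memory_bytes(num_params: int,
                          param_bytes: int = 4,       # fp32 weights
                          grad_bytes: int = 4,        # one fp32 gradient per weight
                          optimizer_bytes: int = 8):  # Adam: two fp32 moments per weight
    """Return bytes for parameters, gradients, and optimizer state.

    Activations are excluded: they scale with batch size and sequence length.
    """
    return {
        "parameters": num_params * param_bytes,
        "gradients": num_params * grad_bytes,
        "optimizer_states": num_params * optimizer_bytes,
    }


# Hypothetical 7B-parameter model trained in fp32 with Adam.
breakdown = training_memory_bytes(7 * 10**9)
total_gb = sum(breakdown.values()) / 1e9
print(breakdown, f"total approx {total_gb:.0f} GB")
```

At 16 bytes per parameter, such a model needs about 112 GB before counting any activations, which is already more than a single 80 GB GPU holds.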
Floating Point Data Types: A Key Trade-off
The choice of data type is critical for managing memory:
- float32 (fp32): The default. It uses 4 bytes per number, offering high precision and a wide dynamic range. It's safe but memory-intensive.
- float16 (fp16): Uses only 2 bytes, halving memory usage. However, it has a very limited dynamic range, making it prone to underflow (small numbers becoming zero) and overflow (values above 65,504 becoming infinity), which can lead to training instability.
- bfloat16 (bf16): The modern standard for deep learning. It also uses 2 bytes but gives more of its bits to the exponent, so it has the same wide dynamic range as float32 at the cost of lower precision. This trade-off suits deep learning well: it prevents underflow while still halving memory.
- fp8: An even newer 1-byte format supported on the latest hardware (such as H100s) for further acceleration, though it requires careful handling.
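The trade-offs above can be checked directly with torch.finfo and a couple of conversions. This is a small illustrative sketch (fp8 is omitted because its dtypes depend on the PyTorch version); the specific test values 1e-8 and 70000 are arbitrary choices picked to sit just outside fp16's range.

```python
import torch

# Bytes per element and numeric limits for each 16/32-bit format.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} bytes={info.bits // 8} max={info.max:.3g} "
          f"smallest_normal={info.tiny:.3g} eps={info.eps:.3g}")

# Underflow: 1e-8 is below fp16's smallest subnormal (~6e-8), so it rounds to zero,
# while bf16 shares fp32's exponent range and keeps a nonzero value.
tiny_fp16 = torch.tensor(1e-8, dtype=torch.float16)
tiny_bf16 = torch.tensor(1e-8, dtype=torch.bfloat16)

# Overflow: fp16 tops out at 65504, so 70000 becomes inf; bf16 stays finite (but rounded).
big_fp16 = torch.tensor(70000.0, dtype=torch.float16)
big_bf16 = torch.tensor(70000.0, dtype=torch.bfloat16)
print(tiny_fp16.item(), tiny_bf16.item(), big_fp16.item(), big_bf16.item())
```

The printed eps values also show the other side of the trade: bf16's spacing near 1.0 is coarser than fp16's, which is the precision it gives up for range.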
Pillar 2: Compute (FLOPs)
Compute is measured in FLOPs (floating-point operations). The vast majority of computation in a Transformer is spent on matrix multiplications (matmul).
The Rule of Thumb for Training FLOPs
A simple and powerful approximation for the total FLOPs required for a training run is:
Total FLOPs ≈ 6 × Number of Parameters (N) × Dataset Size in Tokens (D)
This breaks down as follows:
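- Forward Pass: Costs roughly 2 × N × D FLOPs. For a matrix multiplication Y = XA, where X holds D tokens (shape D × d_in) and A holds N = d_in × d_out parameters, each of the D × d_out outputs needs d_in multiplications and d_in additions, so the cost is 2 × D × N.
- Backward Pass: Costs roughly 4 × N × D FLOPs, about twice the forward pass, because each matmul requires two products of the same size: one for the gradient with respect to its input X and one for the gradient with respect to its weights A.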
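The counting rules can be written out as two small helpers. The function names and the 70B-parameter / 15T-token run are hypothetical, chosen only to show the arithmetic.

```python
def matmul_flops(num_tokens: int, d_in: int, d_out: int) -> int:
    """FLOPs for Y = X @ A with X: (num_tokens, d_in) and A: (d_in, d_out).

    Each of the num_tokens * d_out outputs needs d_in multiplies and d_in adds.
    """
    return 2 * num_tokens * d_in * d_out


def training_flops(num_params: int, num_tokens: int) -> int:
    """Rule-of-thumb training cost: 2ND forward + 4ND backward = 6ND."""
    forward = 2 * num_params * num_tokens
    backward = 4 * num_params * num_tokens  # grads w.r.t. inputs and w.r.t. weights
    return forward + backward


# A single 4096 x 4096 weight matrix applied to 1024 tokens costs 2 * D * N.
layer_flops = matmul_flops(1024, 4096, 4096)

# Illustrative (hypothetical) run: 70B parameters trained on 15T tokens.
total = training_flops(70 * 10**9, 15 * 10**12)
print(f"{layer_flops:.3e} FLOPs for one layer, {total:.2e} FLOPs for training")
```

Python integers keep the counts exact, which avoids float rounding when comparing against the 6ND formula.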
Model FLOPs Utilization (MFU)
This metric measures how efficiently you are using the hardware. It's the ratio of the actual FLOPs per second your code achieves versus the theoretical peak FLOPs per second of the GPU.
MFU = Actual FLOP/s / Promised FLOP/s
An MFU of 50% or higher is considered very good. Achieving high MFU depends on keeping the GPU's compute units busy, which is easiest when performing large matrix multiplications. Using lower precision like bfloat16 significantly increases the promised FLOP/s of the hardware, so MFU should be computed against the peak for the data type actually in use.
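A minimal sketch of measuring MFU for a single matmul and turning a FLOP budget into wall-clock time. All function names are hypothetical, and the promised peak must be looked up on your hardware's spec sheet for the dtype you use; the 1e12 below is a placeholder, not a real device figure.

```python
import time
import torch


def measured_matmul_flops_per_sec(dim: int = 1024, dtype=torch.float32,
                                  device: str = "cpu", iters: int = 10) -> float:
    """Time a square matmul and return achieved FLOP/s (2 * dim^3 FLOPs per matmul)."""
    a = torch.randn(dim, dim, dtype=dtype, device=device)
    b = torch.randn(dim, dim, dtype=dtype, device=device)
    a @ b  # warm-up so one-time setup costs aren't timed
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # GPU kernels are async; wait before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return iters * 2 * dim**3 / elapsed


def mfu(actual_flops_per_sec: float, promised_flops_per_sec: float) -> float:
    return actual_flops_per_sec / promised_flops_per_sec


def training_days(total_flops: float, promised_flops_per_sec: float,
                  mfu_value: float, num_gpus: int) -> float:
    """Wall-clock days if every GPU sustains mfu_value of its promised peak."""
    seconds = total_flops / (promised_flops_per_sec * mfu_value * num_gpus)
    return seconds / 86_400


# Placeholder peak of 1e12 FLOP/s; substitute the spec-sheet value for your GPU and dtype.
achieved = measured_matmul_flops_per_sec(dim=512)
print(f"achieved {achieved:.3e} FLOP/s, MFU vs placeholder peak = {mfu(achieved, 1e12):.1%}")
```

Feeding the 6ND total into training_days with a realistic MFU (say 0.5) is the napkin math for how long a run will take.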
2. The Training Pipeline in PyTorch
The lecture walks through building the entire training pipeline from the ground up.
Tensors: The Basic Building Block
- Creation & Memory: Tensors are multi-dimensional arrays. Their memory usage is num_elements × bytes_per_element. They are created on the CPU by default and must be moved to the GPU (.to('cuda')) for fast computation.
- Views vs. Copies: Many operations in PyTorch (such as basic slicing, .view(), and .transpose()) return "views" that don't create new data in memory; they only change the metadata (shape and stride) used to access the original storage. This is fast, but it means modifying the view modifies the original. Operations that need a contiguous block of memory (like a further .view() on a transposed tensor) may fail unless you first call .contiguous(), which does create a copy and so uses extra memory and compute.
- einops: A library for tensor manipulation that uses named axes (e.g., 'batch seq hidden'), making code more readable and less error-prone than referring to dimensions by index.
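The memory formula and view semantics can be seen in a few lines of plain PyTorch (einops is left out so the sketch needs only torch). data_ptr() is used here to check whether two tensors share the same underlying storage.

```python
import torch

x = torch.zeros(4, 8, dtype=torch.float32)
print(x.numel() * x.element_size())  # 32 elements * 4 bytes = 128 bytes

# A view shares storage with the original: writing through it changes x.
y = x.view(8, 4)
y[0, 0] = 1.0
print(x[0, 0].item(), y.data_ptr() == x.data_ptr())  # 1.0 True

# transpose() is also a view, but its strides are no longer contiguous...
t = x.transpose(0, 1)
print(t.is_contiguous())  # False

# ...so flattening it with view() fails; .contiguous() first copies into fresh memory.
try:
    t.view(-1)
except RuntimeError as err:
    print("view failed:", err)
flat = t.contiguous().view(-1)
print(flat.data_ptr() == x.data_ptr())  # False: new storage
```

.reshape() would do the same job in one call, copying only when a view is impossible, but writing .contiguous() explicitly makes the extra allocation visible.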
Models (nn.Module)
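- Structure: Models are built as classes inheriting from torch.nn.Module. Learnable parameters are defined using nn.Parameter.
- Initialization: Proper weight initialization is critical for stable training. A standard technique is Xavier-style initialization, which scales unit-variance random weights by 1/sqrt(input_dim) (this fan-in-only variant is also known as LeCun initialization; the original Glorot/Xavier scheme uses both fan-in and fan-out). With unit-variance inputs, each output is a sum of input_dim terms whose variances add up to about 1, so the output variance stays roughly constant instead of growing with input_dim, preventing activations and gradients from exploding or vanishing.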
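A minimal sketch of a custom layer built this way. ScaledLinear is a hypothetical name (bias omitted for brevity), and the comparison with an unscaled weight matrix shows why the 1/sqrt(input_dim) factor matters.

```python
import math
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical bias-free linear layer with 1/sqrt(input_dim) weight scaling."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # nn.Parameter registers the tensor as learnable (it shows up in .parameters()).
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim) / math.sqrt(input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


torch.manual_seed(0)
x = torch.randn(64, 4096)          # unit-variance inputs
scaled = ScaledLinear(4096, 256)
unscaled = torch.randn(4096, 256)  # same shape, no scaling
print(f"scaled output std:   {scaled(x).std().item():.2f}")       # close to 1
print(f"unscaled output std: {(x @ unscaled).std().item():.2f}")  # close to sqrt(4096) = 64
```

Stacking many unscaled layers would multiply that factor of 64 at every layer, which is exactly the explosion the scaling prevents.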
Data Loading
- Process: The training data is a long sequence of token IDs. For efficiency with large datasets, np.memmap is used to load the data from disk lazily, only bringing the necessary chunks into RAM.
- Batching: A get_batch function randomly samples starting positions in the dataset to create a batch of sequences of shape (batch_size, sequence_length).
- Pinned Memory: Calling .pin_memory() on a CPU tensor allows it to be copied to the GPU asynchronously (non_blocking=True), which helps overlap data loading with GPU computation for better performance.
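One possible get_batch, sketched under assumptions that go beyond the notes: tokens are stored as uint16 in a flat binary file, the signature is hypothetical, and the function also returns next-token targets y (the inputs shifted by one), a common language-modeling convention.

```python
import os
import tempfile

import numpy as np
import torch


def get_batch(data: np.ndarray, batch_size: int, seq_len: int, device: str = "cpu"):
    """Sample random windows of token IDs; targets are the inputs shifted by one."""
    # Highest start index still leaves room for the one-token-shifted target.
    starts = torch.randint(0, len(data) - seq_len, (batch_size,)).tolist()
    # Slicing a memmap reads only these windows from disk.
    x = torch.stack([torch.from_numpy(data[s:s + seq_len].astype(np.int64)) for s in starts])
    y = torch.stack([torch.from_numpy(data[s + 1:s + 1 + seq_len].astype(np.int64)) for s in starts])
    if device.startswith("cuda"):
        # Pinned (page-locked) host memory enables an asynchronous host-to-device copy.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    return x, y


# Usage: write a toy token file, then memory-map it instead of loading it into RAM.
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(10_000, dtype=np.uint16).tofile(path)
data = np.memmap(path, dtype=np.uint16, mode="r")
x, y = get_batch(data, batch_size=4, seq_len=16)
print(x.shape)  # torch.Size([4, 16])
```

Because the toy file holds consecutive integers, every target row is simply its input row plus one, which makes the shift easy to verify.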
The Optimizer
- Role: An optimizer (like SGD or Adam) updates the model's parameters using the computed gradients.
- State: Modern optimizers like Adam are stateful. They store additional data for each parameter (e.g., momentum and variance moving averages), which can significantly increase memory usage (e.g., adding 8 bytes per parameter for Adam in fp32).
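The 8-bytes-per-parameter figure can be confirmed by inspecting torch.optim.Adam's state after one step. In PyTorch's implementation the two moment buffers are stored under the keys exp_avg and exp_avg_sq; the helper function name is hypothetical.

```python
import torch
import torch.nn as nn

model = nn.Linear(1000, 1000, bias=False)  # 1,000,000 fp32 parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Adam creates its per-parameter state lazily, on the first step().
model(torch.randn(4, 1000)).sum().backward()
optimizer.step()


def adam_state_bytes(opt: torch.optim.Optimizer) -> int:
    """Bytes held in Adam's first- and second-moment buffers."""
    total = 0
    for state in opt.state.values():
        for key in ("exp_avg", "exp_avg_sq"):  # running mean of grads and of squared grads
            total += state[key].numel() * state[key].element_size()
    return total


num_params = sum(p.numel() for p in model.parameters())
print(adam_state_bytes(optimizer) / num_params)  # 8.0 bytes per parameter
```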
The Full Training Loop
The process for a single training step is:
- Zero Gradients: Clear old gradients (optimizer.zero_grad()).
- Get Data: Load a batch of data.
- Forward Pass: Run the data through the model to get predictions and compute the loss.
- Backward Pass: Call loss.backward() to compute gradients for all parameters.
- Optimizer Step: Call optimizer.step() to update the model parameters based on the gradients.
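The five steps map one-to-one onto a loop. This is a toy sketch under stated assumptions: a linear model fitting synthetic data from a random "true" weight vector, with a hypothetical get_data helper standing in for a real data loader.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
true_w = torch.randn(16, 1)  # hypothetical target the model should recover


def get_data(batch_size: int = 32):
    """Hypothetical stand-in for a data loader: synthetic regression batches."""
    x = torch.randn(batch_size, 16)
    return x, x @ true_w


losses = []
for step in range(100):
    optimizer.zero_grad()               # 1. clear gradients from the previous step
    x, y = get_data()                   # 2. load a batch
    loss = F.mse_loss(model(x), y)      # 3. forward pass and loss
    loss.backward()                     # 4. backward pass fills p.grad for every parameter
    optimizer.step()                    # 5. update parameters from their gradients
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.5f}")
```

Zeroing gradients first matters because backward() accumulates into p.grad rather than overwriting it.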
Essential Best Practices
- Checkpointing: Periodically save the model's state_dict() and the optimizer's state_dict() to disk. Training runs are long and can crash; checkpointing ensures you don't lose all your progress.
- Mixed Precision Training: To get the best of both worlds (the speed of low precision and the stability of high precision), use techniques like PyTorch's Automatic Mixed Precision (AMP). This automatically runs operations like matmul in bfloat16 for speed while keeping critical components, such as the loss calculation and the stored weights and their updates, in float32 for stability.
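A sketch combining both practices, assuming a toy nn.Linear model and hypothetical helper names (train_step, save_checkpoint, load_checkpoint). torch.autocast runs the matmul in bfloat16 while the parameters stay float32; the loss is computed outside the autocast region on a float32 copy of the predictions. Using float16 instead would additionally call for a gradient scaler to avoid underflow, which bf16 generally does not need.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = nn.Linear(64, 64).to(device)  # float32 master weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_step(x, y):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        pred = model(x)                   # matmul runs in bfloat16
    loss = F.mse_loss(pred.float(), y)    # loss computed in float32
    loss.backward()                       # gradients land on the float32 parameters
    optimizer.step()                      # update applied in float32
    return pred, loss


def save_checkpoint(path, step):
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, path)


def load_checkpoint(path):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["step"]


x = torch.randn(8, 64, device=device)
y = torch.randn(8, 64, device=device)
pred, loss = train_step(x, y)

path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
saved_weight = model.weight.detach().clone()
save_checkpoint(path, step=1)
train_step(x, y)                      # keep training; the weights change
resumed_step = load_checkpoint(path)  # roll back to the saved state
```

Saving the optimizer state alongside the weights is what makes resuming exact: without Adam's moment buffers, the first steps after a restart would behave differently from an uninterrupted run.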