Lecture 6 — Kernels, Triton

Overview: From Theory to Practice

This lecture transitions from the high-level architecture of GPUs (Lecture 5) to the practical, hands-on process of writing and optimizing code that runs on them. The central goal is to understand why some GPU code is fast and some is slow, and how to bridge that gap. The lecture introduces two essential diagnostic tools, benchmarking and profiling, and then dives into three ways of writing custom GPU operations ("kernels"): raw CUDA, the Python-based Triton, and the automated torch.compile. The unifying principle is that performance comes from minimizing data movement between slow global memory (DRAM) and fast on-chip memory (SRAM) through techniques like kernel fusion.


1. Understanding Performance: Benchmarking and Profiling

Before optimizing, one must measure. There are two indispensable tools for this:

  1. Benchmarking: Measures the total wall-clock time an operation takes to complete. It answers the question, "How long did it take?" It's essential for comparing the end-to-end performance of different implementations and understanding how an algorithm scales with input size. The process involves:

    • Warmup runs: Executing the code a few times before timing so that one-time costs, such as JIT compilation, are excluded from the measurement.
    • Multiple trials: Running the code several times to get an average and a sense of the variance.
    • Synchronization: Calling torch.cuda.synchronize() before stopping the timer. CUDA kernels are launched asynchronously, so without it the timer can stop while the GPU is still working.
  2. Profiling: Breaks down the total time from benchmarking to show where that time was spent. It answers, "What happened during that time?" The PyTorch profiler is a powerful tool that reveals:

    • Which specific GPU kernels were launched (e.g., cutlass_..._gemm for matrix multiplication).
    • How much time was spent on the CPU dispatching operations versus on the GPU executing them.
    • That different input shapes can trigger different, specialized kernels, explaining non-obvious performance variations.
    • The composition of complex PyTorch operations (e.g., torch.cdist is composed of matmul, pow, sum, etc.).
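
The benchmarking recipe above (warmup, multiple trials, synchronization) can be sketched in a few lines. This is a minimal illustration rather than the lecture's own harness: the function name benchmark and its parameters are assumptions, and the sync argument is a placeholder for whatever call waits for the device to finish.

```python
import statistics
import time


def benchmark(fn, num_warmups=3, num_trials=5, sync=lambda: None):
    """Return (mean, stdev) of wall-clock seconds per call of fn().

    `sync` is a hypothetical hook: it should block until all queued
    device work has finished. The default does nothing (CPU-only code).
    """
    # Warmup: absorb one-time costs (compilation, caches) outside the timings.
    for _ in range(num_warmups):
        fn()
    sync()  # make sure warmup work is done before the first timed trial

    times = []
    for _ in range(num_trials):
        start = time.perf_counter()
        fn()
        sync()  # wait for the work to finish before stopping the timer
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)


# Example on a CPU-only workload.
mean_s, stdev_s = benchmark(lambda: sum(i * i for i in range(100_000)))
print(f"{mean_s * 1e3:.3f} ms +/- {stdev_s * 1e3:.3f} ms")
```

When timing GPU code with PyTorch, one would pass sync=torch.cuda.synchronize so that each timed interval covers the kernel's execution and not just its launch.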

Key Insight: Profiling reveals that a sequence of simple PyTorch operations can launch many individual kernels, each requiring a costly round-trip to global memory.


2. Kernel Fusion: The Key to Performance

The most important optimization principle is kernel fusion. The lecture uses the "warehouse and factory" analogy:

  • Global Memory (DRAM): The large but slow warehouse.
  • On-Chip Memory (SRAM) & Compute Units: The small but extremely fast factory.
  • Memory Bandwidth: The slow conveyor belt between them.

A non-fused operation (e.g., y = a + b; z = y * c) is like sending materials from the warehouse to the factory, processing them, sending the intermediate product back to the warehouse, then fetching it again for the next step. A fused kernel performs the entire chain of operations (z = (a+b)*c) in one trip: the intermediate result stays in the fast "factory" memory and never returns to the warehouse, which sharply reduces the memory-bandwidth bottleneck.
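
A back-of-the-envelope count makes the saving concrete. The sketch below tallies how many bytes cross the "conveyor belt" for the example above, assuming float32 elements, one million elements per tensor, and no caching between kernels. The helper traffic_bytes is a hypothetical name used only for this illustration.

```python
def traffic_bytes(kernels, n, bytes_per_elem=4):
    """Total global-memory traffic for a list of elementwise kernels.

    Each kernel is described as (tensors_read, tensors_written);
    every tensor has n elements.
    """
    return sum((reads + writes) * n * bytes_per_elem for reads, writes in kernels)


n = 1_000_000  # assumed tensor size for this example

# Unfused: y = a + b (read a, b; write y), then z = y * c (read y, c; write z).
unfused = traffic_bytes([(2, 1), (2, 1)], n)

# Fused: z = (a + b) * c (read a, b, c; write z); y never leaves the chip.
fused = traffic_bytes([(3, 1)], n)

print(unfused, fused, unfused / fused)  # 24000000 16000000 1.5
```

Under these assumptions, fusing just two operations already cuts memory traffic by a third; longer chains of elementwise operations save proportionally more.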

Case Study: Fused GeLU

The lecture demonstrates this with the GeLU activation function.

  • Manual (Non-Fused) GeLU: Implementing the GeLU formula with standard PyTorch operations (*, +, tanh, etc.) is slow because each operation launches a separate kernel that reads its input from, and writes its output to, global memory.
  • PyTorch (Fused) GeLU: torch.nn.functional.gelu is much faster because it calls a single, pre-written, fused CUDA kernel that performs the entire computation in one go.
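
For reference, here is a scalar sketch of the formula involved. The notes do not spell it out, so the code assumes the widely used tanh approximation of GeLU and compares it against the exact erf-based definition. The function names are illustrative.

```python
import math


def gelu_tanh(x):
    """Tanh approximation of GeLU (assumed to be the formula in question)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_exact(x):
    """Exact GeLU: x times the standard normal CDF at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    print(f"x={x:+.1f}  tanh approx={gelu_tanh(x):+.6f}  exact={gelu_exact(x):+.6f}")
```

Written with tensor operations and run eagerly, each multiply, add, power, and tanh in gelu_tanh could become its own kernel launch with its own pass over global memory, which is the overhead a fused kernel avoids.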

3. Writing Your Own Kernels: Three Levels of Abstraction

When a fused kernel doesn't exist or you need custom logic, you can write your own.

Level 1: Raw CUDA (in C++)

  • What it is: The lowest-level, most powerful way to program NVIDIA GPUs, using an extension of C++. The programmer has full control over threads, blocks, and memory management.
  • Execution Model: You write a function from the point of view of a single thread, and every thread runs that same function. CUDA launches a grid of thread blocks, and each thread combines its block and thread IDs (blockIdx, threadIdx) to determine which piece of data to work on.
  • Pros: Maximum control and potential for performance.
  • Cons: Very complex, verbose, and requires deep knowledge of the hardware.

The lecture shows how to implement a simple GeLU kernel in CUDA.
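
To illustrate the execution model without a GPU, the sketch below simulates it in plain Python. It is not CUDA code and not the lecture's kernel: gelu_kernel and launch are hypothetical stand-ins, the threads run one after another instead of in parallel, and the tanh approximation of GeLU is assumed. The index arithmetic, however, mirrors the usual blockIdx.x * blockDim.x + threadIdx.x pattern.

```python
import math


def gelu_scalar(x):
    # Tanh approximation of GeLU (an assumption for this sketch).
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_kernel(block_idx, thread_idx, block_dim, x, out, n):
    """Body executed by ONE simulated thread."""
    i = block_idx * block_dim + thread_idx  # global element index
    if i < n:  # threads past the end of the array do nothing
        out[i] = gelu_scalar(x[i])


def launch(kernel, n, block_dim, *args):
    """Sequentially simulate a 1-D grid that covers n elements."""
    num_blocks = (n + block_dim - 1) // block_dim  # ceiling division
    for block_idx in range(num_blocks):
        for thread_idx in range(block_dim):
            kernel(block_idx, thread_idx, block_dim, *args)
    return num_blocks


x = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
out = [0.0] * len(x)
num_blocks = launch(gelu_kernel, len(x), 4, x, out, len(x))
print(num_blocks)  # 3: three blocks of 4 threads cover 10 elements
```

The bounds check matters because the grid is rounded up to a whole number of blocks: here 12 threads are launched for 10 elements, and the last two must not touch memory.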

Level 2: Triton (in Python)

  • What it is: A Python-based language and compiler developed by OpenAI to make GPU programming more accessible.
  • Key Abstraction: Instead of thinking about individual threads, you write code that operates on blocks of data. Triton's compiler then automatically handles the low-level details of generating efficient code, including:
    • Memory Coalescing: Ensuring that memory reads/writes are efficient.
    • Shared Memory Management: Using tiling to maximize data reuse.
  • Pros: Dramatically simpler than CUDA, written in Python, and often generates code that is as fast as, or faster than, hand-tuned libraries.
  • Cons: Less fine-grained control than raw CUDA.
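
The block-level style can be previewed with a plain-Python stand-in. This is not Triton code: fused_block_program and run_grid are hypothetical names, and Python lists play the role of blocks of values. What it is meant to show is the shape of a typical Triton kernel: each program instance computes a block of offsets, builds a mask for the ragged last block, loads whole blocks, computes on them, and stores the result.

```python
def fused_block_program(pid, a, b, c, out, n, block_size):
    """One simulated program instance: computes (a + b) * c for one block."""
    start = pid * block_size
    offsets = [start + k for k in range(block_size)]
    mask = [off < n for off in offsets]  # guards the last, partial block

    # Masked block loads: out-of-range lanes get a dummy value.
    a_blk = [a[o] if m else 0.0 for o, m in zip(offsets, mask)]
    b_blk = [b[o] if m else 0.0 for o, m in zip(offsets, mask)]
    c_blk = [c[o] if m else 0.0 for o, m in zip(offsets, mask)]

    # Block-wide computation; the intermediate a + b is never stored.
    z_blk = [(av + bv) * cv for av, bv, cv in zip(a_blk, b_blk, c_blk)]

    # Masked block store.
    for o, m, z in zip(offsets, mask, z_blk):
        if m:
            out[o] = z


def run_grid(a, b, c, block_size=4):
    n = len(a)
    out = [0.0] * n
    num_programs = (n + block_size - 1) // block_size
    for pid in range(num_programs):
        fused_block_program(pid, a, b, c, out, n, block_size)
    return out


a = [float(i) for i in range(10)]
b = [1.0] * 10
c = [2.0] * 10
print(run_grid(a, b, c))
```

Compared with the per-thread view, the program never reasons about individual threads inside a block; in real Triton, mapping the block onto threads is the compiler's job.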

Level 3: torch.compile (Automated Fusion)

  • What it is: The highest-level approach. You write standard PyTorch code in Python, and torch.compile automatically analyzes the computation graph, fuses compatible operations together, and generates highly optimized Triton kernels in the background.
  • Pros: Easiest to use, requires no knowledge of CUDA or Triton, and can provide significant speedups with almost no code changes, at the cost of some compilation time on the first call.
  • Cons: The compiler may not be able to fuse all patterns optimally; sometimes manual kernel writing is still necessary for peak performance.

Conclusion on Kernels: The compiled and Triton versions of GeLU achieve performance very close to the native PyTorch implementation, and all are significantly faster than the non-fused manual version.


4. Advanced Case Study: Fused Softmax in Triton

The lecture concludes by demonstrating how to write a fused kernel for a more complex, non-elementwise operation: softmax.

  • Naive Implementation: A manual softmax involves multiple steps (find max per row, subtract, exponentiate, sum per row, divide), resulting in at least four separate passes over the data in global memory.
  • Fused Triton Implementation: A single Triton kernel can perform the entire operation in one pass, provided a row is small enough to fit in on-chip memory. Each row of the input matrix is handled by a single thread block. The block:
    1. Loads the entire row into fast on-chip SRAM.
    2. Performs all the necessary computations (max, exp, sum, etc.) entirely within SRAM.
    3. Writes the final, normalized row back to global memory once.
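
The three steps above can be sketched in plain Python, with one loop iteration standing in for one thread block. This is an illustration of the row-per-block structure, not Triton code, and the function names are assumptions. A local copy of the row plays the role of SRAM.

```python
import math


def softmax_row(row):
    """Numerically stable softmax of one row, computed on local data only."""
    row_max = max(row)                             # max per row
    exps = [math.exp(v - row_max) for v in row]    # subtract, exponentiate
    total = sum(exps)                              # sum per row
    return [e / total for e in exps]               # divide


def fused_softmax(matrix):
    out = []
    for row in matrix:            # one simulated block per row
        local = list(row)         # 1. load the whole row once
        result = softmax_row(local)  # 2. all computation on the local copy
        out.append(result)        # 3. write the normalized row once
    return out


probs = fused_softmax([[1.0, 2.0, 3.0], [1000.0, 1001.0, 1002.0]])
for row in probs:
    print([round(p, 4) for p in row], round(sum(row), 6))
```

Subtracting the row maximum before exponentiating keeps exp from overflowing, which is why the second row (values around 1000) produces the same probabilities as the first.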

This approach drastically reduces memory traffic and is significantly faster than the naive version, again illustrating the core principle: keep data in the GPU's fast on-chip memory for as much of the computation as possible.

Copyright 2025, Ran Ding