GPU MODE Lecture 5: Going Further with CUDA for Python Programmers
- GPU MODE Lecture Notes: My notes from the GPU MODE reading group lectures run by Andreas Kopf and Mark Saroufim.
- Introduction and Overview
- Resources and Setup
- Matrix Multiplication Example
- Optimizing with Shared Memory
- Implementing Tiling with Numba
- Q&A Session
- YouTube Recording: Lecture 5: Going Further with CUDA for Python Programmers
- Jupyter Notebook: lecture_005/matmul_l5.ipynb
- utils.py: utils.py
Introduction and Overview
- Going Further with CUDA for Python Programmers: This lecture builds upon the foundational knowledge presented in “Getting Started with CUDA for Python Programmers” and focuses on optimizing CUDA code for performance by leveraging fast memory.
- Prerequisites: Familiarity with basic CUDA concepts and Python programming, including thread utilization.
- Recommended Resources:
- “Programming Massively Parallel Processors” (book), Chapter 5.
- CUDA Mode lecture by Thomas Viehmann (covers Chapter 4 & 5).
- Lecture Focus: Utilizing shared memory, a faster memory type within the GPU, to improve performance.
- Memory Hierarchy:
- Global Memory: Default memory used in CUDA, relatively fast but not the fastest.
- Accessed by all threads.
- (e.g., with tensor.cuda() in PyTorch)
- Shared Memory: Significantly faster than global memory (about 10x).
- Accessible only by threads within a specific block (on a streaming multiprocessor).
- Importance of Memory Access Speed: Due to the high processing speed of GPUs, memory access becomes a performance bottleneck. Utilizing shared memory effectively is crucial for optimization.
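To make the memory bottleneck concrete, here is a rough, cache-ignoring count of global-memory traffic for the lecture’s 5120x256 by 256x5120 matmul (a back-of-the-envelope sketch, not a measurement):

```python
# Back-of-the-envelope global-memory traffic for C = A @ B with
# A: 5120x256, B: 256x5120 (float32). A rough model that ignores caches.
h, k, w = 5120, 256, 5120
bytes_per_float = 4

# Naive kernel: every output element re-reads a full row of A and a full column of B.
naive_reads = h * w * 2 * k * bytes_per_float
print(f"naive global reads : {naive_reads / 1e9:.1f} GB")       # ~53.7 GB

# Tiled kernel with tile width 16: each tile loaded into shared memory is reused
# by 16 threads, so global reads drop by roughly a factor of the tile width.
tw = 16
print(f"tiled global reads : {naive_reads / tw / 1e9:.1f} GB")   # ~3.4 GB

# The input matrices themselves are tiny by comparison.
print(f"input data         : {(h * k + k * w) * bytes_per_float / 1e6:.1f} MB")  # ~10.5 MB
```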
Resources and Setup
Repository: CUDA Mode lectures repository, specifically lecture 5 notebook.
- GitHub Repository: https://github.com/cuda-mode/lectures
utils.py: Contains helper functions (e.g., ceiling division, CUDA code loading, prefix for CUDA code).
dim3: Python namedtuple representing a 3D grid (x, y, z) for blocks and threads, mirroring CUDA’s dim3 structure.
Debugging Tools: Wurlitzer for printing from CUDA kernels, CUDA_LAUNCH_BLOCKING for debugging.
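For reference, the cdiv helper mentioned above is plain integer ceiling division; a minimal equivalent (an assumption about utils.py, which may differ in details) is:

```python
# Hypothetical minimal version of the ceiling-division helper in utils.py.
def cdiv(a, b):
    "Integer ceiling division of a by b, used to compute how many blocks cover a dimension."
    return (a + b - 1) // b

print(cdiv(5120, 16))  # 320 blocks of 16 threads cover a 5120-wide dimension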
Setup Code:
import os                                # Operating system interfaces
import math                              # Mathematical functions
import sys                               # System-specific parameters and functions
import torch                             # PyTorch library for tensor computations and neural networks
import re                                # Regular expression operations
import numpy as np                       # NumPy library for numerical computations
from types import SimpleNamespace as ns  # Allows creation of attribute-accessible objects
from collections import namedtuple       # Factory function for creating tuple subclasses with named fields

# Define a custom 3D dimension namedtuple with default values
dim3 = namedtuple('dim3', ['x', 'y', 'z'], defaults=(1, 1))

# Create a 2D dimension instance
d = dim3(2, 3)
# Display the full dimension object
d
dim3(x=2, y=3, z=1)
# Display x and y components of the dimension
d.x, d.y
(2, 3)
# Configure NumPy print options for cleaner output
np.set_printoptions(precision=2, linewidth=140)
# Configure PyTorch print options for cleaner output and disable scientific notation
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

# Import utility functions
from utils import show_img, load_cuda, cuda_begin, cdiv

# Load the wurlitzer IPython extension for capturing C-level output
%load_ext wurlitzer

# Set a random seed for reproducibility
torch.manual_seed(42)
<torch._C.Generator at 0x728ffff23630>
Matrix Multiplication Example
Problem: Multiplying a 5120x256 matrix (M1) by a 256x5120 matrix (M2).
# Create a large random tensor (5120x256)
m1 = torch.rand(5120, 256)
# Extract the first 4 rows of m1
m1s = m1[:4]
# Create another large random tensor (256x5120)
m2 = torch.rand(256, 5120)
# Extract the first 4 columns of m2
m2s = m2[:, :4]
Previous Approaches (Recap)
Naive Matrix Multiplication Kernel:
- Calculates dot product for each element in the output matrix.
- Accesses global memory repeatedly within the inner loop, leading to performance issues.
Pure Python Baseline: Extremely slow, uses a small sample of the matrices (4x4) for demonstration.
def blk_kernel2d(f, blocks, threads, *args):
    """
    Simulate a 2D GPU kernel execution on CPU.

    This function emulates the behavior of a 2D GPU kernel by iterating over
    blocks and threads in a nested loop structure.

    Args:
        f (function): The kernel function to be executed.
        blocks (dim3): The number of blocks in x and y dimensions.
        threads (dim3): The number of threads per block in x and y dimensions.
        *args: Additional arguments to be passed to the kernel function.

    Returns:
        None
    """
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x):
                    # Execute the kernel function for each thread
                    f(dim3(i1, i0), dim3(j1, j0), threads, *args)
def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
    """
    Perform matrix multiplication for a single element in the output matrix.

    This function calculates one element of the output matrix by multiplying
    a row from the first matrix with a column from the second matrix.

    Args:
        blockIdx (dim3): The current block index.
        threadIdx (dim3): The current thread index within the block.
        blockDim (dim3): The dimensions of the block.
        m (Tensor): Flattened first input matrix.
        n (Tensor): Flattened second input matrix.
        out (Tensor): Flattened output matrix.
        h (int): Height of the output matrix.
        w (int): Width of the output matrix.
        k (int): Common dimension of input matrices.

    Returns:
        None
    """
    # Calculate global thread indices
    r = blockIdx.y * blockDim.y + threadIdx.y
    c = blockIdx.x * blockDim.x + threadIdx.x

    # Check if the thread is within the output matrix dimensions
    if (r >= h or c >= w): return

    # Perform dot product of row from m and column from n
    o = 0.
    for i in range(k):
        o += m[r*k+i] * n[i*w+c]

    # Store the result in the output matrix
    out[r*w+c] = o
def matmul_2d(m, n):
    """
    Perform matrix multiplication using a simulated 2D GPU kernel.

    This function sets up the execution configuration and launches the
    matrix multiplication kernel.

    Args:
        m (Tensor): First input matrix.
        n (Tensor): Second input matrix.

    Returns:
        Tensor: Result of matrix multiplication.

    Raises:
        AssertionError: If the inner dimensions of input matrices don't match.
    """
    h, k = m.shape
    k2, w = n.shape
    assert k == k2, "Size mismatch!"

    # Initialize output matrix
    output = torch.zeros(h, w, dtype=m.dtype)

    # Set up thread and block dimensions
    tpb = dim3(16, 16)                             # Threads per block
    blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))  # Number of blocks

    # Launch the kernel
    blk_kernel2d(matmul_bk, blocks, tpb, m.flatten(), n.flatten(), output.flatten(), h, w, k)

    return output
# Verify the result by comparing with PyTorch's built-in matrix multiplication
torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()
tensor(True)
- Simple Kernel Runner: Iterates through simulated blocks and threads, calling a kernel function (not a real CUDA kernel).
CUDA Kernel Runner: Similar to the simple kernel runner but uses CUDA’s syntax for launching kernels (triple angle brackets).
# CUDA kernel definition and PyTorch C++ extension implementation
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    // Calculate global thread indices
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    // Check if thread is within matrix bounds
    if (r >= h || c >= w) return;

    // Perform dot product for this element
    float o = 0;
    for (int i = 0; i < k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    // Define thread block and grid dimensions
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));

    // Launch CUDA kernel
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
fname = 'matmul'
def get_sig(fname, src):
    """
    Extract the function signature from the source code.

    Args:
        fname (str): The name of the function to extract.
        src (str): The source code to search.

    Returns:
        str: The function signature with a semicolon appended, or None if not found.
    """
    res = re.findall(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
    return res[0]+';' if res else None
cpp_src = get_sig(fname, cuda_src)
cpp_src
'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'
# Load the CUDA module
module = load_cuda(cuda_src, cpp_src, [fname])

# Move tensors to GPU and ensure they are contiguous
m1c, m2c = m1.contiguous().cuda(), m2.contiguous().cuda()

# Check the shape of the output
module.matmul(m1c, m2c).shape
torch.Size([5120, 5120])
# Verify correctness by comparing with PyTorch's built-in matrix multiplication
torch.isclose(module.matmul(m1c, m2c), m1c@m2c).all()
tensor(True, device='cuda:0')
- CUDA Kernel (Naive): ChatGPT-generated CUDA code based on the naive Python kernel.
Performance: CUDA version is significantly faster than pure Python.
%%timeit -n 10
# Benchmark the custom CUDA matmul implementation
module.matmul(m1c, m2c)
torch.cuda.synchronize()
3 ms ± 177 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
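For a reference point, the built-in matmul on the same tensors can be timed the same way (a sketch; the result depends on your GPU and is not reproduced here):

```python
%%timeit -n 10
# Reference: PyTorch's built-in, cuBLAS-backed matmul on the same tensors
m1c @ m2c
torch.cuda.synchronize()
```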
Implementing Tiling with Numba
Numba: An alternative library for writing CUDA code directly in Python.
pip install numba
pip install -U "numpy<2.1"
from numba import cuda
from numba.cuda import as_cuda_array as ca
CUDA Kernel (Numba): Python code decorated with @cuda.jit to indicate it’s a CUDA kernel.

@cuda.jit
def matmul_k_numba(m, n, out, tw):
    """
    Perform matrix multiplication on GPU using CUDA.

    This kernel function multiplies matrices 'm' and 'n', storing the result in 'out'.
    It uses shared memory and tiling for improved performance.

    Args:
        m (ndarray): First input matrix
        n (ndarray): Second input matrix
        out (ndarray): Output matrix to store the result
        tw (int): Tile width for shared memory optimization

    Note:
        This function is designed to be called from a host function, not directly.
    """
    # Get CUDA thread and block information
    cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    tc, tr = tid.x, tid.y

    # Calculate global row and column indices
    r, c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc

    # Get input matrix dimensions
    h, k = m.shape
    k2, w = n.shape

    # Allocate shared memory for tile-based computation
    shar = cuda.shared.array(0, dtype=np.float32)
    ms, ns = shar[:tw*tw], shar[tw*tw:2*tw*tw]  # Split shared memory for both input matrices

    # Initialize partial sum
    p = np.float32(0.0)

    # Iterate over tiles
    for ph in range(math.ceil(k/tw)):
        idx = ph * tw

        # Load data into shared memory, with boundary checks
        ms[tr*tw+tc] = m[r, tc+idx] if r < h and idx+tc < k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c < w and idx+tr < k else 0.

        # Ensure all threads have loaded data
        cuda.syncthreads()

        # Compute partial dot product for this tile
        for i in range(tw):
            p += ms[tr*tw+i] * ns[i*tw+tc]

        # Ensure all threads have used the data before next iteration
        cuda.syncthreads()

    # Store the result if within output matrix bounds
    if r < h and c < w:
        out[r, c] = p
- Shared Memory: cuda.shared.array creates dynamic shared memory arrays.
- Synchronization: cuda.syncthreads() for thread synchronization.
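As a toy illustration of just these two primitives (not from the lecture), the following kernel stages each 32-element block in shared memory and has every thread read a slot written by a different thread, which is why the barrier is required:

```python
# Toy sketch isolating cuda.shared.array and cuda.syncthreads():
# each 32-thread block reverses its values through shared memory.
import numpy as np
from numba import cuda

@cuda.jit
def reverse_each_block(x, out):
    buf = cuda.shared.array(32, dtype=np.float32)  # statically sized shared memory
    i = cuda.grid(1)                               # global index
    t = cuda.threadIdx.x                           # index within the block
    if i < x.size:
        buf[t] = x[i]
    cuda.syncthreads()                             # wait for all writes to shared memory
    if i < x.size:
        out[i] = buf[31 - t]                       # read a value written by a different thread

x = np.arange(64, dtype=np.float32)
out = np.zeros_like(x)
reverse_each_block[2, 32](x, out)                  # 2 blocks x 32 threads
print(out[:8])                                     # first block reversed: 31, 30, ..., 24
```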
Kernel Launching: Uses square brackets instead of triple angle brackets (e.g., kernel[blocks, threadsperblock, stream, shared_mem_size](...)).

def matmul_2d_numba(m, n, tw=16):
    """
    Perform matrix multiplication using CUDA.

    This function prepares the CUDA kernel call for matrix multiplication.

    Args:
        m (Tensor): First input matrix (PyTorch tensor on CUDA)
        n (Tensor): Second input matrix (PyTorch tensor on CUDA)
        tw (int): Tile width for shared memory optimization (default: 16)

    Returns:
        Tensor: Result of matrix multiplication

    Raises:
        AssertionError: If input matrices have mismatched inner dimensions
    """
    h, k = m.shape
    k2, w = n.shape
    assert k == k2, "Size mismatch!"

    # Initialize output matrix
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)

    # Set up CUDA kernel parameters
    dyn_shared_mem_size = 2 * tw * tw * 4      # Size of shared memory in bytes
    tpb = tw, tw                               # Threads per block
    blocks = cdiv(w, tpb[0]), cdiv(h, tpb[1])  # Calculate grid dimensions

    # Launch CUDA kernel
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw)

    return out
# Verify correctness of the implementation
torch.isclose(matmul_2d_numba(m1c, m2c), m1c@m2c).all()
tensor(True, device='cuda:0')
Performance: The Numba version with dynamic shared memory is slower than the optimized CUDA C version but still provides CUDA-level speed.
%%timeit -n 10
# Benchmark the implementation
matmul_2d_numba(m1c, m2c)
# Ensure all CUDA operations are completed before timing
torch.cuda.synchronize()
7.8 ms ± 80.7 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Benefits:
- Faster compilation times compared to PyTorch’s CUDA C/C++ approach.
- Allows for faster iteration during development.
- No need to flatten tensors (supports multidimensional indexing).
- Access to tensor shape information within the kernel.
CUDA Simulator: Numba provides a built-in CUDA simulator, enabled by setting the environment variable NUMBA_ENABLE_CUDASIM=1.
- Executes CUDA code as pure Python on the CPU, allowing for debugging and experimentation with small datasets.
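A minimal sketch of the simulator workflow; note that the environment variable must be set before numba.cuda is first imported (in a notebook, typically at the very top or before launching):

```python
# Enable Numba's CUDA simulator for CPU-only debugging.
import os
os.environ['NUMBA_ENABLE_CUDASIM'] = '1'   # must happen before importing numba.cuda

import numpy as np
from numba import cuda

@cuda.jit
def add_one(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1

a = np.zeros(8, dtype=np.float32)
add_one[1, 8](a)    # runs as pure Python on the CPU under the simulator
print(a)            # [1. 1. 1. 1. 1. 1. 1. 1.]
```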
Development Workflow:
- Develop and debug CUDA kernels in Numba with the simulator enabled.
- Disable the simulator to run the code on the GPU.
- Optionally, convert the Numba code to CUDA C/C++ using ChatGPT for deployment.
Q&A Session
- Shipping Numba Kernels and AOT Compilation:
- AOT Compilation: Numba’s ahead-of-time (AOT) compilation was discussed as a potential way to simplify deployment.
- AOT Deprecation: Numba’s AOT compilation is deprecated (as of February 2024), with a replacement planned but not yet specified.
- Performance Comparisons and Optimization Opportunities:
- Optimization Tools: TVM and Mojo’s auto-tuning (GPU support expected around late February/March 2024) were mentioned as potential optimization aids.
- PyTorch’s Matrix Multiplication Implementation:
- PyTorch primarily uses cuBLAS.
- Torch Compile and Inductor: torch.compile’s experimental Inductor settings (torch._inductor.config) were mentioned as a potential alternative backend.
- Profiling for Backend Identification: PyTorch’s profiler can reveal the backend used through the kernel and function names it reports (see the sketch at the end of this Q&A section).
- Compilation Speed and Iterative Development:
- Compilation Speed Importance: Fast compilation was emphasized as crucial for iterative development.
- Fast Compilation Benefits: Fast compilation, aided by tools like the CUDA simulator and Numba’s CUDA JIT, enhances productivity and reduces debugging time.
- ChatGPT’s Role in CUDA Development:
- ChatGPT’s Code Generation Capabilities: ChatGPT is useful for code conversion and API usage but less effective for novel algorithms.
- Numba vs. Triton:
- Different Purposes: Numba and Triton were recognized as valuable tools with distinct strengths, suitable for different use cases. Triton’s limitations in expressing certain CUDA constructs (e.g., 4-bit quantization) were noted.
- Complementary Tools: Numba and Triton were seen as complementary, each offering unique advantages.
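Regarding the profiling point above, here is a sketch (assuming a CUDA-capable GPU) of using the PyTorch profiler on a plain matmul; the kernel names in the resulting table typically identify the backend (e.g., cuBLAS GEMM kernels):

```python
# Profile a built-in matmul and inspect the reported CUDA kernel names.
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.rand(5120, 256, device='cuda')
b = torch.rand(256, 5120, device='cuda')

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    c = a @ b
    torch.cuda.synchronize()

# The kernel names in this table indicate which library implementation was dispatched.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```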