CUDA MODE Lecture 5: Going Further with CUDA for Python Programmers

notes
cuda
Lecture #5 explores how to optimize matrix multiplication in CUDA for Python programmers using shared memory and tiling, comparing implementations in pure Python, CUDA C, and the Numba library.
Author

Christian Mills

Published

September 1, 2024

This post is part of the following series:
  • CUDA Mode Lecture Notes: My notes from the CUDA MODE reading group lectures run by Andreas Kopf and Mark Saroufim.
Resource Links:

Introduction and Overview

  • Going Further with CUDA for Python Programmers: This lecture builds upon the foundational knowledge presented in “Getting Started with CUDA for Python Programmers” and focuses on optimizing CUDA code for performance by leveraging fast memory.
  • Prerequisites: Familiarity with basic CUDA concepts and Python programming, including thread utilization.
  • Recommended Resources:
  • Lecture Focus: Utilizing shared memory, a faster memory type within the GPU, to improve performance.
  • Memory Hierarchy:
    • Global Memory: Default memory used in CUDA, relatively fast but not the fastest.
      • Accessed by all threads.
      • (e.g., with tensor.cuda() in PyTorch)
    • Shared Memory: Significantly faster than global memory (about 10x).
      • Accessible only by threads within a specific block (on a streaming multiprocessor).
  • Importance of Memory Access Speed: Due to the high processing speed of GPUs, memory access becomes a performance bottleneck. Utilizing shared memory effectively is crucial for optimization.

Resources and Setup

  • Repository: CUDA Mode lectures repository, specifically lecture 5 notebook.

  • utils.py: Contains helper functions (e.g., ceiling division, CUDA code loading, prefix for CUDA code).

  • dim3: Python namedtuple representing a 3D grid (x, y, z) for blocks and threads, mirroring CUDA’s Dim3 structure.

  • Debugging Tools: Wurlitzer for printing from CUDA kernels, CUDA launch blocking for debugging.

  • Setup Code:

    import os      # Operating system interfaces
    import math    # Mathematical functions
    import sys     # System-specific parameters and functions
    import torch   # PyTorch library for tensor computations and neural networks
    import re      # Regular expression operations
    import numpy as np  # NumPy library for numerical computations
    
    from types import SimpleNamespace as ns  # Allows creation of attribute-accessible objects
    from collections import namedtuple  # Factory function for creating tuple subclasses with named fields
    # Define a custom 3D dimension namedtuple with default values
    dim3 = namedtuple('dim3', ['x', 'y', 'z'], defaults=(1, 1))
    # Create a 2D dimension instance
    d = dim3(2, 3)
    
    # Display the full dimension object
    d
    dim3(x=2, y=3, z=1)
    # Display x and y components of the dimension
    d.x, d.y
    (2, 3)
    # Configure NumPy print options for cleaner output
    np.set_printoptions(precision=2, linewidth=140)
    
    # Configure PyTorch print options for cleaner output and disable scientific notation
    torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
    # Import utility functions
    from utils import show_img, load_cuda, cuda_begin, cdiv
    # Load the wurlitzer IPython extension for capturing C-level output
    %load_ext wurlitzer
    # Set a random seed for reproducibility
    torch.manual_seed(42)
    <torch._C.Generator at 0x728ffff23630>

Matrix Multiplication Example

  • Problem: Multiplying a 5120x256 matrix (M1) by a 256x5120 matrix (M2).

    # Create a large random tensor (5120x256)
    m1 = torch.rand(5120, 256)
    
    # Extract the first 4 rows of m1
    m1s = m1[:4]
    
    # Create another large random tensor (256x5120)
    m2 = torch.rand(256, 5120)
    
    # Extract the first 4 columns of m2
    m2s = m2[:, :4]

Previous Approaches (Recap)

  • Naive Matrix Multiplication Kernel:

    • Calculates dot product for each element in the output matrix.
    • Accesses global memory repeatedly within the inner loop, leading to performance issues.
  • Pure Python Baseline: Extremely slow, uses a small sample of the matrices (4x4) for demonstration.

    def blk_kernel2d(f, blocks, threads, *args):
        """
        Simulate a 2D GPU kernel execution on CPU.
    
        This function emulates the behavior of a 2D GPU kernel by iterating over
        blocks and threads in a nested loop structure.
    
        Args:
            f (function): The kernel function to be executed.
            blocks (dim3): The number of blocks in x and y dimensions.
            threads (dim3): The number of threads per block in x and y dimensions.
            *args: Additional arguments to be passed to the kernel function.
    
        Returns:
            None
        """
        for i0 in range(blocks.y):
            for i1 in range(blocks.x):
                for j0 in range(threads.y):
                    for j1 in range(threads.x):
                        # Execute the kernel function for each thread
                        f(dim3(i1,i0), dim3(j1,j0), threads, *args)
    def matmul_bk(blockIdx, threadIdx, blockDim, m, n, out, h, w, k):
        """
        Perform matrix multiplication for a single element in the output matrix.
    
        This function calculates one element of the output matrix by multiplying
        a row from the first matrix with a column from the second matrix.
    
        Args:
            blockIdx (dim3): The current block index.
            threadIdx (dim3): The current thread index within the block.
            blockDim (dim3): The dimensions of the block.
            m (Tensor): Flattened first input matrix.
            n (Tensor): Flattened second input matrix.
            out (Tensor): Flattened output matrix.
            h (int): Height of the output matrix.
            w (int): Width of the output matrix.
            k (int): Common dimension of input matrices.
    
        Returns:
            None
        """
        # Calculate global thread indices
        r = blockIdx.y * blockDim.y + threadIdx.y
        c = blockIdx.x * blockDim.x + threadIdx.x
    
        # Check if the thread is within the output matrix dimensions
        if (r >= h or c >= w):
            return
    
        # Perform dot product of row from m and column from n
        o = 0.
        for i in range(k):
            o += m[r*k+i] * n[i*w+c]
    
        # Store the result in the output matrix
        out[r*w+c] = o
    def matmul_2d(m, n):
        """
        Perform matrix multiplication using a simulated 2D GPU kernel.
    
        This function sets up the execution configuration and launches the
        matrix multiplication kernel.
    
        Args:
            m (Tensor): First input matrix.
            n (Tensor): Second input matrix.
    
        Returns:
            Tensor: Result of matrix multiplication.
    
        Raises:
            AssertionError: If the inner dimensions of input matrices don't match.
        """
        h, k = m.shape
        k2, w = n.shape
        assert k == k2, "Size mismatch!"
    
        # Initialize output matrix
        output = torch.zeros(h, w, dtype=m.dtype)
    
        # Set up thread and block dimensions
        tpb = dim3(16, 16)  # Threads per block
        blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))  # Number of blocks
    
        # Launch the kernel
        blk_kernel2d(matmul_bk, blocks, tpb,
                     m.flatten(), n.flatten(), output.flatten(), h, w, k)
    
        return output
    # Verify the result by comparing with PyTorch's built-in matrix multiplication
    torch.isclose(matmul_2d(m1s, m2s), m1s@m2s).all()
    tensor(True)
    • Simple Kernel Runner: Iterates through simulated blocks and threads, calling a kernel function (not a real CUDA kernel).
  • CUDA Kernel Runner: Similar to the simple kernel runner but uses CUDA’s syntax for launching kernels (triple angle brackets).

    # CUDA kernel definition and PyTorch C++ extension implementation
    cuda_src = cuda_begin + r'''
    __global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
        // Calculate global thread indices
        int r = blockIdx.y*blockDim.y + threadIdx.y;
        int c = blockIdx.x*blockDim.x + threadIdx.x;
    
        // Check if thread is within matrix bounds
        if (r >= h || c >= w) return;
    
        // Perform dot product for this element
        float o = 0;
        for (int i = 0; i < k; ++i) o += m[r*k+i] * n[i*w+c];
        out[r*w+c] = o;
    }
    
    torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
        CHECK_INPUT(m); CHECK_INPUT(n);
        int h = m.size(0);
        int w = n.size(1);
        int k = m.size(1);
        TORCH_CHECK(k==n.size(0), "Size mismatch!");
        auto output = torch::zeros({h, w}, m.options());
    
        // Define thread block and grid dimensions
        dim3 tpb(16,16);
        dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    
        // Launch CUDA kernel
        matmul_k<<<blocks, tpb>>>(
            m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return output;
    }
    '''
    fname = 'matmul'
    def get_sig(fname, src):
        """
        Extract the function signature from the source code.
    
        Args:
            fname (str): The name of the function to extract.
            src (str): The source code to search.
    
        Returns:
            str: The function signature with a semicolon appended, or None if not found.
        """
        res = re.findall(rf'^(.+\s+{fname}\(.*?\))\s*{{?\s*$', src, re.MULTILINE)
        return res[0]+';' if res else None
    cpp_src = get_sig(fname, cuda_src)
    cpp_src
    'torch::Tensor matmul(torch::Tensor m, torch::Tensor n);'
    # Load the CUDA module
    module = load_cuda(cuda_src, cpp_src, [fname])
    # Move tensors to GPU and ensure they are contiguous
    m1c, m2c = m1.contiguous().cuda(), m2.contiguous().cuda()
    # Check the shape of the output
    module.matmul(m1c, m2c).shape
    torch.Size([5120, 5120])
    # Verify correctness by comparing with PyTorch's built-in matrix multiplication
    torch.isclose(module.matmul(m1c, m2c), m1c@m2c).all()
    tensor(True, device='cuda:0')
    • CUDA Kernel (Naive): ChatGPT-generated CUDA code based on the naive Python kernel.
  • Performance: CUDA version is significantly faster than pure Python.

    %%timeit -n 10
    # Benchmark the custom CUDA matmul implementation
    module.matmul(m1c, m2c)
    torch.cuda.synchronize()
    3 ms ± 177 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Optimizing with Shared Memory

Tiling

  • Problem: Repeated global memory access in the inner loop of the matrix multiplication kernel.
  • Solution: Tiling – dividing the matrices into smaller tiles and performing the multiplication tile-by-tile.
  • Tile Width (TW): The dimension of a square tile (e.g., 16x16).
  • Process:
    1. Load a tile from m1 and a tile from m2 into shared memory.
    2. Calculate the partial dot products for all elements within the output tile using the shared memory tiles.
    3. Repeat for all tiles, accumulating the partial dot products to get the final result.
  • Benefits:
    • Each input element is read from global memory only once.
    • Dot products are calculated using much faster shared memory.

Implementing Tiling in Python

  • Dynamic Shared Memory Simulation: Using NumPy or PyTorch tensor views to simulate dynamic shared memory allocation in CUDA.

  • Shared Memory Kernel Runner:

    def blk_kernel2d_shar(f, blocks, threads, sh_sz, *args, **kwargs):
        """
        Execute a 2D block kernel with shared memory.
    
        Args:
            f (function): The kernel function to execute
            blocks (dim3): Number of blocks in x and y dimensions
            threads (dim3): Number of threads per block
            sh_sz (int): Size of shared memory
            *args: Additional positional arguments for the kernel function
            **kwargs: Additional keyword arguments for the kernel function
        """
        for i0 in range(blocks.y):
            for i1 in range(blocks.x):
                shared = torch.zeros(sh_sz)
                f(dim3(i1, i0), threads, shared, *args, **kwargs)
    • Iterates through blocks.
    • Creates a simulated shared memory array.
    • Calls the kernel function, passing the shared memory.
  • Tiled Matrix Multiplication Kernel (Python):

    def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
        """
        Perform tiled matrix multiplication using block-wise computation.
    
        Args:
            blockIdx (dim3): Current block index
            blockDim (dim3): Block dimensions
            shared (Tensor): Shared memory tensor
            m (Tensor): First input matrix (flattened)
            n (Tensor): Second input matrix (flattened)
            out (Tensor): Output matrix (flattened)
            h (int): Height of the first matrix
            w (int): Width of the second matrix
            k (int): Shared dimension of the two matrices
            tw (int): Tile width
        """
        shar_sz = tw * tw
        ms, ns = shared[:shar_sz], shared[shar_sz:]  # Split shared memory for both matrices
    
        for ph in range(cdiv(k, tw)):
            idx = ph * tw
            # Fill shared memory with tiles from input matrices
            for tr in range(blockDim.y):
                for tc in range(blockDim.x):
                    r, c = blockIdx.y * blockDim.y + tr, blockIdx.x * blockDim.x + tc
                    ms[tr*tw+tc] = m[tc+idx + r*k] if r < h and idx+tc < k else 0.
                    ns[tr*tw+tc] = n[(tr+idx)*w + c] if c < w and idx+tr < k else 0.
    
            # Compute dot products using shared memory
            for tr in range(blockDim.y):
                for tc in range(blockDim.x):
                    r, c = blockIdx.y * blockDim.y + tr, blockIdx.x * blockDim.x + tc
                    for i in range(tw):
                        if r*w+c < len(out):
                            out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]
    • Fill Shared Memory:
      • Loops through each tile.
      • Calculates the starting index (idx) of the tile in the original matrix.
      • Loops through threads within the tile.
      • Calculates the row (r) and column (c) in the original matrix based on the tile index and thread index.
      • Copies the corresponding elements from the input matrices to the shared memory tiles (ms, ns).
      • Padding: Fills elements outside the matrix boundaries with zeros.
    • Dot Product from Shared Memory:
      • Loops through threads within the tile.
      • Calculates the row and column in the output matrix.
      • Performs the dot product using elements from the shared memory tiles.
    def matmul_2d(m, n, tw=16):
        """
        Perform 2D matrix multiplication using tiled block-wise computation.
    
        Args:
            m (Tensor): First input matrix
            n (Tensor): Second input matrix
            tw (int, optional): Tile width. Defaults to 16.
    
        Returns:
            Tensor: Result of matrix multiplication
        """
        h, k = m.shape
        k2, w = n.shape
        assert k == k2, "Size mismatch!"
    
        output = torch.zeros(h, w, dtype=m.dtype)
        tpb = dim3(tw, tw)  # Threads per block
        blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))  # Number of blocks
    
        blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, tw*tw*2,
                          m.flatten(), n.flatten(), output.flatten(),
                          h, w, k, tw=tw)
        return output
    # Initialize a tensor 'a' with 5 zeros
    a = torch.zeros(5)
    
    # Split 'a' into two parts: 'b' (first 3 elements) and 'c' (last 2 elements)
    b, c = a[:3], a[3:]
    # Modify specific elements in 'b' and 'c'
    b[1] = 2
    c[0] = 6
    # The value of 'a' is now implicitly modified due to tensor slicing
    a
    tensor([0., 2., 0., 6., 0.])
    # Check shapes of matrices m1s and m2s
    m1s.shape, m2.shape
    (torch.Size([4, 256]), torch.Size([256, 5120]))
  • Result: The Python tiled matrix multiplication produces the same result as the previous versions.

    # Verify if the custom matmul_2d function produces the same result as PyTorch's built-in matrix multiplication
    torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()
    tensor(True)

Refactoring the Python Kernel

  • run_threads Function: Introduced to abstract the looping through threads within a tile.

    def run_threads(f, blockDim, *args, **kwargs):
        """
        Simulate thread execution in a 2D block.
    
        Args:
            f (callable): Function to be executed by each thread.
            blockDim (object): Object containing x and y dimensions of the block.
            *args: Variable length argument list to be passed to f.
            **kwargs: Arbitrary keyword arguments to be passed to f.
        """
        for i0 in range(blockDim.y):
            for i1 in range(blockDim.x):
                f(i0, i1, *args, **kwargs)  # Execute function for each thread
  • Refactored Kernel:

    def matmul_tiled_bk(blockIdx, blockDim, shared, m, n, out, h, w, k, tw):
        """
        Perform tiled matrix multiplication for a single block.
    
        Args:
            blockIdx (object): Block index in the grid.
            blockDim (object): Dimensions of the block.
            shared (list): Shared memory for the block.
            m (Tensor): First input matrix.
            n (Tensor): Second input matrix.
            out (Tensor): Output matrix.
            h (int): Height of the output matrix.
            w (int): Width of the output matrix.
            k (int): Common dimension of input matrices.
            tw (int): Tile width.
        """
        shar_sz = tw*tw
        ms, ns = shared[:shar_sz], shared[shar_sz:]  # Split shared memory for matrices m and n
    
        def get_rc(tr, tc):
            """Calculate global row and column indices from thread indices."""
            return blockIdx.y*blockDim.y + tr, blockIdx.x*blockDim.x + tc
    
        def fill_shared_tk(tr, tc, ph):
            """Fill shared memory with a tile of input matrices."""
            r, c = get_rc(tr, tc)
            # Load elements from matrix m, use 0 if out of bounds
            ms[tr*tw+tc] = m[tc + ph*tw + r*k] if r < h and (ph*tw+tc) < k else 0.
            # Load elements from matrix n, use 0 if out of bounds
            ns[tr*tw+tc] = n[(tr + ph*tw)*w + c] if c < w and (ph*tw+tr) < k else 0.
    
        def dotprod_tk(tr, tc):
            """Compute partial dot product for a tile."""
            r, c = get_rc(tr, tc)
            for i in range(tw):
                if r*w+c < len(out):
                    out[r*w+c] += ms[tr*tw+i] * ns[tw*i+tc]  # Accumulate dot product
    
        # Iterate over tiles in the k dimension
        for ph in range(int(math.ceil(k/tw))):
            run_threads(fill_shared_tk, blockDim, ph)  # Load tile into shared memory
            run_threads(dotprod_tk, blockDim)  # Compute partial dot products
    • Uses run_threads to simplify the code and make it more readable.
    • Separates the “fill shared memory” and “dot product” logic into distinct functions.
    def matmul_2d(m, n, tw=16):
        """
        Perform 2D matrix multiplication using tiled algorithm.
    
        Args:
            m (Tensor): First input matrix.
            n (Tensor): Second input matrix.
            tw (int, optional): Tile width. Defaults to 16.
    
        Returns:
            Tensor: Result of matrix multiplication.
        """
        h, k = m.shape
        k2, w = n.shape
        assert k == k2, "Size mismatch!"  # Ensure matrices can be multiplied
    
        output = torch.zeros(h, w, dtype=m.dtype)  # Initialize output matrix
        tpb = dim3(tw, tw)  # Define threads per block
        blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))  # Calculate number of blocks needed
    
        # Launch kernel for tiled matrix multiplication
        blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, tw*tw*2,
                          m.flatten(), n.flatten(), output.flatten(),
                          h, w, k, tw=tw)
        return output
  • Result: The refactored kernel is functionally equivalent to the previous version.

    # Check shapes of input matrices
    m1s.shape, m2s.shape
    (torch.Size([4, 256]), torch.Size([256, 4]))
    # Verify the result of matmul_2d against PyTorch's built-in matrix multiplication
    torch.isclose(matmul_2d(m1s, m2s, tw=16), m1s@m2s).all()
    tensor(True)

CUDA-Like Python Implementation with Threads

  • Motivation: CUDA kernels don’t have explicit loops for threads; threads are executed concurrently.

  • Simulating Concurrent Threads: Python’s threading library is used to simulate concurrent thread execution.

    import threading
    from threading import Barrier, Thread  # For thread synchronization and creation
    from concurrent.futures import ThreadPoolExecutor  # For managing a pool of worker threads
    def g(x, sb):
        """
        A function that prints a number, its negative, and its tenfold value using a synchronization barrier.
    
        Args:
            x (int): The input number to be processed.
            sb (threading.Barrier): A synchronization barrier to coordinate threads.
    
        This function demonstrates the use of a barrier for thread synchronization.
        """
        print(x)
        sb.wait()  # Wait for all threads to reach this point
        print(-x)
        sb.wait()  # Wait again for all threads to reach this point
        print(x*10)
    # Define the number of threads to use
    num = 3
    
    # Create a Barrier object for synchronizing 'num' threads
    sb = Barrier(num)
    
    # Use a ThreadPoolExecutor to manage a pool of worker threads
    with ThreadPoolExecutor(num) as ex:
        # Execute the function g for each number in range(1, num+1) using the thread pool
        # The lambda function is used to pass both the number and the Barrier object to g
        # list() is used to force immediate execution of all tasks
        list(ex.map(lambda i: g(i, sb), range(1, num+1)))
    1
    2
    3
    -3
    -1
    -2
    10
    20
    30
  • Synchronization Barrier: A Barrier object is used to synchronize threads, ensuring that all threads complete the “fill shared memory” step before proceeding to the “dot product” step.

  • Kernel Runner:

    def blk_kernel2d_shar(f, blocks, tpb, sh_sz, *args, **kwargs):
        """
        Execute a 2D block kernel function with shared memory.
    
        Args:
            f (function): The kernel function to be executed.
            blocks (dim3): The number of blocks in x and y dimensions.
            tpb (dim3): Threads per block in x and y dimensions.
            sh_sz (int): Size of shared memory.
            *args: Variable length argument list for the kernel function.
            **kwargs: Arbitrary keyword arguments for the kernel function.
    
        This function creates a grid of threads to execute the given kernel function.
        """
        for i0 in range(blocks.y):
            for i1 in range(blocks.x):
                shar = torch.zeros(sh_sz)  # Shared memory for the block
                syncb = Barrier(tpb.y * tpb.x)  # Synchronization barrier for threads in a block
    
                # Create threads for each element in the block
                threads = [Thread(target=f, args=(dim3(i1,i0), dim3(p,o), tpb, shar, syncb, *args), kwargs=kwargs)
                           for o in range(tpb.y) for p in range(tpb.x)]
    
                # Start and join all threads in the block
                for tr in threads: tr.start()
                for tr in threads: tr.join()
    • Creates a synchronization barrier.
    • Creates a thread for each element within a tile.
    • Passes the block index, thread index, shared memory, synchronization barrier, and kernel arguments to each thread.
  • Kernel (Python with Threads):

    def matmul_tiled_bk(blockIdx, threadIdx, blockDim, shared, syncb, m, n, out, h, w, k, tw):
        """
        Perform tiled matrix multiplication for a single block.
    
        Args:
            blockIdx (dim3): Block index in the grid.
            threadIdx (dim3): Thread index within the block.
            blockDim (dim3): Dimensions of the block.
            shared (torch.Tensor): Shared memory for the block.
            syncb (threading.Barrier): Synchronization barrier for threads in the block.
            m (torch.Tensor): First input matrix (flattened).
            n (torch.Tensor): Second input matrix (flattened).
            out (torch.Tensor): Output matrix (flattened).
            h (int): Height of the first matrix.
            w (int): Width of the second matrix.
            k (int): Shared dimension of the matrices.
            tw (int): Tile width.
    
        This function computes a portion of the matrix multiplication result for a single block.
        """
        tc, tr = threadIdx.x, threadIdx.y
        r = blockIdx.y * blockDim.y + tr
        c = blockIdx.x * blockDim.x + tc
        shar_sz = tw * tw
        ms, ns = shared[:shar_sz], shared[shar_sz:]  # Split shared memory for two matrices
        p = 0.
    
        for ph in range(cdiv(k, tw)):
            # Load data into shared memory
            ms[tr*tw+tc] = m[tc + ph*tw + r*k] if r < h and (ph*tw+tc) < k else 0.
            ns[tr*tw+tc] = n[(tr + ph*tw)*w + c] if c < w and (ph*tw+tr) < k else 0.
            syncb.wait()  # Synchronize threads after loading data
    
            # Compute partial dot product
            for i in range(tw):
                p += ms[tr*tw+i] * ns[tw*i+tc]
            syncb.wait()  # Synchronize threads before next iteration
    
        if (r < h and c < w):
            out[r*w + c] = p  # Store the result in the output matrix
    • Calculates row and column in the output matrix based on block and thread indices.
    • Fills shared memory (same as before).
    • Waits at the synchronization barrier (syncb.wait()).
    • Performs the dot product using shared memory.
    • Waits at the synchronization barrier again.
    def matmul_2d(m, n, tw=16):
        """
        Perform 2D matrix multiplication using tiled algorithm.
    
        Args:
            m (torch.Tensor): First input matrix.
            n (torch.Tensor): Second input matrix.
            tw (int, optional): Tile width. Defaults to 16.
    
        Returns:
            torch.Tensor: Result of matrix multiplication.
    
        This function orchestrates the tiled matrix multiplication using block kernels.
        """
        h, k = m.shape
        k2, w = n.shape
        assert k == k2, "Size mismatch!"
    
        output = torch.zeros(h, w, dtype=m.dtype)
        tpb = dim3(tw, tw)  # Threads per block
        blocks = dim3(cdiv(w, tpb.x), cdiv(h, tpb.y))  # Number of blocks
    
        blk_kernel2d_shar(matmul_tiled_bk, blocks, tpb, tw*tw*2,
                          m.flatten(), n.flatten(), output.flatten(),
                          h, w, k, tw=tw)
        return output
  • Result: The Python implementation using threads simulates CUDA’s concurrent thread execution and produces the same result.

    # Verify the correctness of the implementation
    torch.isclose(matmul_2d(m1s, m2s, tw=8), m1s@m2s).all()
    tensor(True)

Implementing Tiling in CUDA

  • CUDA Kernel (Tiled): ChatGPT-generated CUDA code based on the tiled Python kernel.

    Code auto-generated by ChatGPT 4, using the following prompt:

    Convert the following python code to CUDA C, keeping formatting and variable names the same where possible. You can remove blockIdx, threadIdx, blockDim, shared from the argument list, since they’re already provided by CUDA. Change syncb.wait() to __syncthreads. Use extern __shared__ float shared[] to create the shared array. Use the C ternary operator to replace the Python equivalent where appropriate. If the Python code uses any non-standard functions, you can assume the same functions are also available to the translated C code with the same name and signature.

    The generated code worked first time, although we did some minor cleanups afterwards (e.g. renaming shared to ms).

  • Dynamic Shared Memory Allocation: Uses extern __shared__ float ms[]; to declare shared memory dynamically. The size is specified when launching the kernel.

    # CUDA kernel code for matrix multiplication
    cuda_src = cuda_begin + r'''
    /**
     * @brief CUDA kernel for matrix multiplication.
     * 
     * @param m Pointer to the first input matrix
     * @param n Pointer to the second input matrix
     * @param out Pointer to the output matrix
     * @param h Height of the first matrix
     * @param w Width of the second matrix
     * @param k Width of the first matrix / Height of the second matrix
     * @param tw Tile width for shared memory optimization
     */
    __global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k, int tw) {
        int tc=threadIdx.x, tr=threadIdx.y;
        int r=blockIdx.y*blockDim.y+tr, c=blockIdx.x*blockDim.x+tc;
    
        extern __shared__ float ms[];  // Shared memory for the first matrix
        float *ns = &ms[tw*tw];  // Shared memory for the second matrix
    
        float p = 0.0f;  // Accumulator for the dot product
        for (int ph = 0; ph < cdiv(k,tw); ++ph) {
            int idx = ph*tw;
            // Load data into shared memory, with bounds checking
            ms[tr*tw + tc] = r<h && idx+tc<k ? m[ tc+idx + r*k ] : 0.0f;
            ns[tr*tw + tc] = c<w && idx+tr<k ? n[(tr+idx)*w + c] : 0.0f;
            __syncthreads();  // Ensure all threads have loaded data
            // Compute partial dot product
            for (int i=0; i<tw; ++i) p += ms[tr*tw + i] * ns[tw*i + tc];
            __syncthreads();  // Ensure all threads have finished computation
        }
        // Write result to global memory
        if (r<h && c<w) out[r*w + c] = p;
    }
    '''
    # PyTorch C++ extension for dynamic matrix multiplication
    cuda_src += r'''
    /**
     * @brief Perform matrix multiplication using CUDA.
     * 
     * @param m First input tensor
     * @param n Second input tensor
     * @return torch::Tensor Result of matrix multiplication
     */
    torch::Tensor matmul_dyn(torch::Tensor m, torch::Tensor n) {
        CHECK_INPUT(m); CHECK_INPUT(n);
        int h=m.size(0), w=n.size(1), k=m.size(1);
        TORCH_CHECK(k==n.size(0), "Size mismatch!");
        auto output = torch::zeros({h, w}, m.options());
    
        /*
        // Commented out section demonstrating basic idea of dynamic size calculation
        cudaDeviceProp devProp;
        CUDA_ERR(cudaGetDeviceProperties(&devProp, 0));
        int maxThreads = devProp.maxThreadsPerBlock;
        size_t requiredSize = static_cast<size_t>(maxThreads) * 2 * sizeof(float);
        size_t size = min(devProp.sharedMemPerBlock, requiredSize);
        int TW = std::sqrt(maxThreads);
        */
    
        // Fixed size configuration
        int TW = 16;  // Tile width
        size_t size = TW*TW * 2 * sizeof(float);  // Shared memory size
        dim3 tpb(TW,TW);  // Threads per block
        dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));  // Number of blocks
    
        // Launch CUDA kernel
        matmul_k<<<blocks,tpb,size>>>(
            m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k, TW);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return output;
    }
    '''
    # Name of the function to be called
    fname = 'matmul_dyn'
    # Generate C++ function signature
    cpp_src = get_sig(fname, cuda_src)
    module = load_cuda(cuda_src, cpp_src, [fname], opt=True)
    # Test for correctness by comparing with PyTorch's built-in matrix multiplication
    torch.isclose(module.matmul_dyn(m1c,m2c), m1c@m2c).all()
    tensor(True, device='cuda:0')
  • Static Shared Memory Allocation: Declares shared memory arrays with fixed sizes at compile time (e.g., __shared__ float ms[tw][tw];).

    # CUDA kernel and PyTorch extension for efficient matrix multiplication
    cuda_src = cuda_begin + r'''
    constexpr int tw = 16;  // Tile width for shared memory optimization
    
    /**
     * CUDA kernel for matrix multiplication using shared memory tiling.
     *
     * @param m Pointer to the first input matrix
     * @param n Pointer to the second input matrix
     * @param out Pointer to the output matrix
     * @param h Height of the first input matrix and output matrix
     * @param w Width of the second input matrix and output matrix
     * @param k Width of the first input matrix / Height of the second input matrix
     */
    __global__ void matmul_ks(float *m, float *n, float *out, int h, int w, int k) {
        __shared__ float ms[tw][tw], ns[tw][tw];  // Shared memory for tiling
        int tc = threadIdx.x, tr = threadIdx.y;
        int r = blockIdx.y * blockDim.y + tr, c = blockIdx.x * blockDim.x + tc;
        float p = 0.0f;  // Accumulator for dot product
    
        // Iterate over tiles
        for (int ph = 0; ph < cdiv(k, tw); ++ph) {
            int idx = ph * tw;
            // Load data into shared memory, with bounds checking
            ms[tr][tc] = r < h && idx + tc < k ? m[tc + idx + r * k] : 0.0f;
            ns[tr][tc] = c < w && idx + tr < k ? n[(tr + idx) * w + c] : 0.0f;
            __syncthreads();  // Ensure all threads have loaded data
    
            // Compute partial dot product for this tile
            for (int i = 0; i < tw; ++i) p += ms[tr][i] * ns[i][tc];
            __syncthreads();  // Ensure computation is complete before next iteration
        }
    
        // Write result to global memory
        if (r < h && c < w) out[r * w + c] = p;
    }
    
    /**
     * PyTorch extension for static matrix multiplication using CUDA.
     *
     * @param m First input tensor
     * @param n Second input tensor
     * @return Resulting tensor from matrix multiplication
     */
    torch::Tensor matmul_static(torch::Tensor m, torch::Tensor n) {
        CHECK_INPUT(m); CHECK_INPUT(n);  // Validate input tensors
        int h = m.size(0), w = n.size(1), k = m.size(1);
        TORCH_CHECK(k == n.size(0), "Size mismatch!");  // Ensure matrices can be multiplied
    
        auto output = torch::zeros({h, w}, m.options());  // Initialize output tensor
    
        // Set up CUDA kernel launch parameters
        dim3 tpb(tw, tw);  // Threads per block
        dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));  // Number of blocks
    
        // Launch CUDA kernel
        matmul_ks<<<blocks, tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
        C10_CUDA_KERNEL_LAUNCH_CHECK();  // Check for CUDA errors
    
        return output;
    }
    '''
    # Name of the function to be exported
    fname = 'matmul_static'
    
    # Generate C++ source code for the CUDA extension
    cpp_src = get_sig(fname, cuda_src)
    
    # Load the CUDA module
    module = load_cuda(cuda_src, cpp_src, [fname])
    
    # Verify correctness by comparing with PyTorch's built-in matrix multiplication
    torch.isclose(module.matmul_static(m1c, m2c), m1c @ m2c).all()
    tensor(True, device='cuda:0')
  • Synchronization: __syncthreads(); ensures all threads within a block have finished a step before proceeding to the next.

  • Performance:

    • Dynamic shared memory version is unexpectedly slower than the naive CUDA version.

      %%timeit -n 10
      # Benchmark the custom CUDA matrix multiplication
      module.matmul_dyn(m1c,m2c)
      torch.cuda.synchronize()  # Ensure CUDA operations are completed before timing
      3.2 ms ± 57.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    • Static shared memory version with a fixed tile width is faster.

      %%timeit -n 10
      # Benchmark the custom matrix multiplication
      module.matmul_static(m1c, m2c)
      torch.cuda.synchronize()  # Ensure CUDA operations are complete before timing
      2.1 ms ± 23.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Dynamic Shared Memory Performance Issue and Solution (Update from the Future)

  • Cause: CUDA struggles to optimize dynamic shared memory allocation when the tile width is not known at compile time, leading to slower performance.

  • Solution: Use C++ templates to make the tile width a template parameter, enabling the compiler to generate optimized code for specific tile widths.

  • Implementation:

    # CUDA kernel for matrix multiplication
    cuda_src = cuda_begin + r'''
    template<int tw>
    __global__ void matmul_k(float *m, float *n, float *out, int h, int w, int k) {
        // Thread and block indices
        int tc = threadIdx.x, tr = threadIdx.y;
        int r = blockIdx.y * blockDim.y + tr, c = blockIdx.x * blockDim.x + tc;
    
        // Shared memory allocation
        extern __shared__ float ms[];
        float *ns = &ms[tw*tw];
    
        float p = 0.0f;  // Accumulator for dot product
    
        // Iterate over blocks of the input matrices
        for (int ph = 0; ph < cdiv(k,tw); ++ph) {
            int idx = ph * tw;
    
            // Load data into shared memory
            ms[tr*tw + tc] = r < h && idx+tc < k ? m[tc+idx + r*k] : 0.0f;
            ns[tr*tw + tc] = c < w && idx+tr < k ? n[(tr+idx)*w + c] : 0.0f;
    
            __syncthreads();  // Ensure all threads have loaded data
    
            // Compute partial dot product
            for (int i = 0; i < tw; ++i) {
                p += ms[tr*tw + i] * ns[tw*i + tc];
            }
    
            __syncthreads();  // Ensure all threads have used the data
        }
    
        // Write result to global memory
        if (r < h && c < w) {
            out[r*w + c] = p;
        }
    }
    '''
    # C++ wrapper function for the CUDA kernel
    cuda_src += r'''
    torch::Tensor matmul_dyn1(torch::Tensor m, torch::Tensor n) {
        CHECK_INPUT(m);
        CHECK_INPUT(n);
    
        // Get dimensions of input matrices
        int h = m.size(0), w = n.size(1), k = m.size(1);
    
        // Check if matrices can be multiplied
        TORCH_CHECK(k == n.size(0), "Size mismatch!");
    
        // Create output tensor
        auto output = torch::zeros({h, w}, m.options());
    
        int TW = 16;  // Thread block width (TODO: Calculate this dynamically)
        size_t size = TW*TW*2 * sizeof(float) + 1;  // Shared memory size
    
        // Define thread block and grid dimensions
        dim3 tpb(TW, TW);
        dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    
        // Lambda function to launch kernel
        auto f = [&](auto kf) {
            kf<<<blocks, tpb, size>>>(
                m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k
            );
        };
    
        // Launch kernel based on thread block size
        switch(TW) {
            case 8: f(matmul_k<8>); break;
            case 16: f(matmul_k<16>); break;
            case 32: f(matmul_k<32>); break;
            default: break;
        }
    
        // Check for CUDA errors
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    
        return output;
    }
    '''
    • Define a C++ template function with tile width as a template parameter.
    • Support a fixed set of tile widths and compile a separate kernel version for each.
    • Use a lambda function to call the appropriate kernel version based on the chosen tile width.
  • Benefits: Enables optimized performance while allowing for some flexibility in tile width selection.

    %%time
    # Measure execution time of the following code
    
    # Define function name
    fname = 'matmul_dyn1'
    
    # Generate C++ function signature
    cpp_src = get_sig(fname, cuda_src)
    
    # Load CUDA module with optimization
    module = load_cuda(cuda_src, cpp_src, [fname], opt=True)
    
    # Get the function from the loaded module
    func = getattr(module, fname)
    CPU times: user 49.5 ms, sys: 63.7 ms, total: 113 ms
    Wall time: 41.1 s
    # Verify correctness of the custom matrix multiplication
    torch.isclose(func(m1c, m2c), m1c @ m2c).all()
    tensor(True, device='cuda:0')
    %%timeit -n 10
    # Measure execution time of the custom matrix multiplication
    func(m1c, m2c)
    
    # Ensure all CUDA operations are completed
    torch.cuda.synchronize()
    2.06 ms ± 51.2 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Implementing Tiling with Numba

  • Numba: An alternative library for writing CUDA code directly in Python.

    pip install numba
    pip install -U "numpy<2.1"
    from numba import cuda
    from numba.cuda import as_cuda_array as ca
  • CUDA Kernel (Numba): Python code decorated with @cuda.jit to indicate it’s a CUDA kernel.

    @cuda.jit
    def matmul_k_numba(m, n, out, tw):
        """
        Perform matrix multiplication on GPU using CUDA.
    
        This kernel function multiplies matrices 'm' and 'n', storing the result in 'out'.
        It uses shared memory and tiling for improved performance.
    
        Args:
        m (ndarray): First input matrix
        n (ndarray): Second input matrix
        out (ndarray): Output matrix to store the result
        tw (int): Tile width for shared memory optimization
    
        Note: This function is designed to be called from a host function, not directly.
        """
        # Get CUDA thread and block information
        cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
        tc, tr = tid.x, tid.y
    
        # Calculate global row and column indices
        r, c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    
        # Get input matrix dimensions
        h, k = m.shape
        k2, w = n.shape
    
        # Allocate shared memory for tile-based computation
        shar = cuda.shared.array(0, dtype=np.float32)
        ms, ns = shar[:tw*tw], shar[tw*tw:2*tw*tw]  # Split shared memory for both input matrices
    
        # Initialize partial sum
        p = np.float32(0.0)
    
        # Iterate over tiles
        for ph in range(math.ceil(k/tw)):
            idx = ph * tw
    
            # Load data into shared memory, with boundary checks
            ms[tr*tw+tc] = m[r, tc+idx] if r < h and idx+tc < k else 0.
            ns[tr*tw+tc] = n[tr+idx, c] if c < w and idx+tr < k else 0.
    
            cuda.syncthreads()  # Ensure all threads have loaded data
    
            # Compute partial dot product for this tile
            for i in range(tw):
                p += ms[tr*tw+i] * ns[i*tw+tc]
    
            cuda.syncthreads()  # Ensure all threads have used the data before next iteration
    
        # Store the result if within output matrix bounds
        if r < h and c < w:
            out[r, c] = p
    • Shared Memory: cuda.shared.array creates dynamic shared memory arrays.
    • Synchronization: cuda.syncthreads() for thread synchronization.
  • Kernel Launching: Uses square brackets instead of triple angle brackets (e.g., kernel[blocks, threadsperblock, stream, shared_mem_size](...)).

    def matmul_2d_numba(m, n, tw=16):
        """
        Perform matrix multiplication using CUDA.
    
        This function prepares the CUDA kernel call for matrix multiplication.
    
        Args:
        m (Tensor): First input matrix (PyTorch tensor on CUDA)
        n (Tensor): Second input matrix (PyTorch tensor on CUDA)
        tw (int): Tile width for shared memory optimization (default: 16)
    
        Returns:
        Tensor: Result of matrix multiplication
    
        Raises:
        AssertionError: If input matrices have mismatched inner dimensions
        """
        h, k = m.shape
        k2, w = n.shape
        assert k == k2, "Size mismatch!"
    
        # Initialize output matrix
        out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    
        # Set up CUDA kernel parameters
        dyn_shared_mem_size = 2 * tw * tw * 4  # Size of shared memory in bytes
        tpb = tw, tw  # Threads per block
        blocks = cdiv(w, tpb[0]), cdiv(h, tpb[1])  # Calculate grid dimensions
    
        # Launch CUDA kernel
        matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw)
    
        return out
    # Verify correctness of the implementation
    torch.isclose(matmul_2d_numba(m1c, m2c), m1c@m2c).all()
    tensor(True, device='cuda:0')
  • Performance: The Numba version with dynamic shared memory is slower than the optimized CUDA C version but still provides CUDA-level speed.

    %%timeit -n 10
    # Benchmark the implementation
    matmul_2d_numba(m1c, m2c)
    torch.cuda.synchronize()  # Ensure all CUDA operations are completed before timing
    7.8 ms ± 80.7 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
  • Benefits:

    • Faster compilation times compared to PyTorch’s CUDA C/C++ approach.
      • Allows for faster iteration during development.
    • No need to flatten tensors (supports multidimensional indexing).
    • Access to tensor shape information within the kernel.
  • CUDA Simulator: Numba provides a built-in CUDA simulator by setting the environment variable NUMBA_ENABLE_CUDASIM=1.

    • Executes CUDA code as pure Python on the CPU, allowing for debugging and experimentation with small datasets.
  • Development Workflow:

    1. Develop and debug CUDA kernels in Numba with the simulator enabled.
    2. Disable the simulator to run the code on the GPU.
    3. Optionally, convert the Numba code to CUDA C/C++ using ChatGPT for deployment.

Q&A Session

  • Shipping Numba Kernels and AOT Compilation:
    • AOT Compilation: Numba’s AOT was discussed as a potential deployment simplification solution.
    • AOT Deprecation: Numba’s AOT is deprecated (February 2024), with a replacement planned but unspecified.
  • Performance Comparisons and Optimization Opportunities:
    • Optimization Tools: TVM and Mojo GPU’s auto-tune (expected late February/March 2024) were mentioned as potential optimization aids.
  • PyTorch’s Matrix Multiplication Implementation:
    • PyTorch primarily uses cuBLAS.
    • Torch Compile and Inductor: Torch Compile’s experimental mode (torch.inductor.config) was mentioned as a potential alternative backend.
    • Profiling for Backend Identification: PyTorch’s profiler can reveal the backend used through function signatures.
  • Compilation Speed and Iterative Development:
    • Compilation Speed Importance: Fast compilation was emphasized as crucial for iterative development.
    • Fast Compilation Benefits: Fast compilation, aided by tools like the CUDA simulator and Numba’s CUDA JIT, enhances productivity and reduces debugging time.
  • ChatGPT’s Role in CUDA Development:
    • ChatGPT’s Code Generation Capabilities: ChatGPT is useful for code conversion and API usage but less effective for novel algorithms.
  • Numba vs. Triton:
    • Different Purposes: Numba and Triton were recognized as valuable tools with distinct strengths, suitable for different use cases. Triton’s limitations in expressing certain CUDA constructs (e.g., 4-bit discretization) were noted.
    • Complementary Tools: Numba and Triton were seen as complementary, each offering unique advantages.

About Me:
  • I’m Christian Mills, a deep learning consultant specializing in computer vision and practical AI implementations.
  • I help clients leverage cutting-edge AI technologies to solve real-world problems.
  • Learn more about me or reach out via email at [email protected] to discuss your project.