GEMM Tutorial
This tutorial demonstrates how to implement a high-performance matrix multiplication (GEMM) kernel using Wave. We’ll walk through the implementation step by step, explaining the key concepts and optimizations.
Overview
The GEMM kernel we’ll implement computes C = A @ B.T, where:
A is an M×K matrix in f16
B is an N×K matrix in f16
C is an M×N matrix in f32
We’ll use Wave’s symbolic programming model and hardware-aware abstractions to create an efficient implementation.
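To make the contract concrete, here is a minimal pure-Python reference for C = A @ B.T on nested lists. It is not Wave code, and `gemm_reference` is a hypothetical helper introduced only to spell out the indexing: because B is stored as N×K, each output element is the dot product of a row of A with a row of B.

```python
def gemm_reference(a, b):
    """Reference C = A @ B.T for row-major nested lists.

    a has shape M x K, b has shape N x K; the result has shape M x N.
    """
    k = len(a[0])
    # Both operands must share the reduction dimension K.
    assert all(len(row) == k for row in a + b)
    return [
        [sum(a_row[p] * b_row[p] for p in range(k)) for b_row in b]
        for a_row in a
    ]


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # M=2, K=3
b = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
]  # N=4, K=3
c = gemm_reference(a, b)  # M x N = 2 x 4
```

Note that no explicit transpose is needed: indexing B by row already yields the columns of B.T.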
Implementation
First, we need to import the necessary modules and define our symbolic dimensions:
import torch

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel._support.dtype import f16, f32
from wave_lang.kernel._support.indexing import sym
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config

# Define symbolic dimensions for our matrices
M = sym.M  # Rows of A and C
N = sym.N  # Rows of B and columns of C
K = sym.K  # Columns of A and B

# Define workgroup tile sizes
BLOCK_M = sym.BLOCK_M
BLOCK_N = sym.BLOCK_N
BLOCK_K = sym.BLOCK_K

# Define the address space for our memory
ADDRESS_SPACE = sym.ADDRESS_SPACE
Now, let’s define our GEMM kernel with appropriate constraints:
# Define constraints for the kernel
constraints = [
    # Distribute M dimension across workgroups with tile size BLOCK_M
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    # Distribute N dimension across workgroups with tile size BLOCK_N
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    # Tile the K dimension for reduction
    tkw.TilingConstraint(K, BLOCK_K),
    # Further distribute M among waves with a tile size of BLOCK_M / 2
    tkw.WaveConstraint(M, BLOCK_M / 2),
    # Further distribute N among waves with a tile size of BLOCK_N / 2
    tkw.WaveConstraint(N, BLOCK_N / 2),
    # Hardware-specific constraints
    tkw.HardwareConstraint(
        threads_per_wave=64,
        mma_type=tkw.MMAType.F32_16x16x16_F16,
    ),
]


@tkw.wave(constraints)
def gemm(
    a: Memory[M, K, ADDRESS_SPACE, f16],  # Input matrix A
    b: Memory[N, K, ADDRESS_SPACE, f16],  # Input matrix B
    c: Memory[M, N, GLOBAL_ADDRESS_SPACE, f32],  # Output matrix C
):
    # Initialize the accumulator register with zeros
    c_reg = Register[M, N, f32](0.0)

    # Iterate over the K dimension to compute the dot product
    @tkw.iterate(K, init_args=[c_reg])
    def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]:
        # Load elements from A and B
        a_reg = tkw.read(a)
        b_reg = tkw.read(b)
        # Compute matrix multiplication and accumulate
        acc = tkw.mma(a_reg, b_reg, acc)
        return acc

    # Store the final result to C
    tkw.write(repeat, c)
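The reduction structure of the kernel can be sketched in plain Python. This is an illustrative model only, not how Wave executes the kernel: `gemm_tiled_k` and `block_k` are hypothetical names, and the inner loops stand in for what the hardware MMA instructions do in parallel. The accumulator plays the role of `c_reg`, and each pass of the outer loop corresponds to one step of the iteration over K.

```python
def gemm_tiled_k(a, b, block_k):
    """C = A @ B.T, reducing over K in chunks of block_k."""
    m, n, k = len(a), len(b), len(a[0])
    assert k % block_k == 0, "this sketch assumes K is a multiple of block_k"
    # The accumulator starts at zero, like c_reg in the kernel.
    acc = [[0.0] * n for _ in range(m)]
    # One pass of this loop corresponds to one iteration over K.
    for k0 in range(0, k, block_k):
        for i in range(m):
            for j in range(n):
                # Partial dot product over the current K tile.
                acc[i][j] += sum(
                    a[i][p] * b[j][p] for p in range(k0, k0 + block_k)
                )
    # Only after the full reduction is the result written out.
    return acc


a = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # M=2, K=4
b = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 1.0],
]  # N=3, K=4
c_two_steps = gemm_tiled_k(a, b, block_k=2)
c_one_step = gemm_tiled_k(a, b, block_k=4)
```

For these small integer-valued inputs the tile size along K changes only how the sum is grouped, not the result, which is why BLOCK_K can be tuned freely for performance.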
Testing the Implementation
Let’s create a test function to verify our GEMM implementation:
def test_gemm():
    # Create test matrices
    m, n, k = 128, 256, 128  # Small dimensions for testing

    # Initialize input matrices with random values
    torch.manual_seed(0)
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(n, k, dtype=torch.float16, device="cuda")
    c = torch.zeros(m, n, dtype=torch.float32, device="cuda")

    # Set hyperparameters for compilation
    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        BLOCK_M: 64,
        BLOCK_N: 64,
        BLOCK_K: 32,
        M: m,
        N: n,
        K: k,
    }

    # Compile the kernel
    options = WaveCompileOptions(
        subs=hyperparams,
    )
    options = set_default_run_config(options)
    compiled_gemm = wave_compile(options, gemm)

    # Run the GEMM kernel
    compiled_gemm(a, b, c)

    # Verify the result using PyTorch's matmul
    expected = torch.matmul(a, b.t())

    # Check if results are close (accounting for floating-point precision)
    assert torch.allclose(c.to(torch.float16), expected, rtol=1e-2, atol=1e-2), (
        "GEMM result doesn't match expected output\n"
        f"Max difference: {(c - expected).abs().max()}"
    )
    print("GEMM test passed!")
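The tolerance check in the test can be read as a per-element inequality: `torch.allclose` passes when `|actual - expected| <= atol + rtol * |expected|` holds for every element. The scalar sketch below restates that rule in plain Python; `is_close` is a hypothetical helper, not part of PyTorch or Wave.

```python
def is_close(actual, expected, rtol=1e-2, atol=1e-2):
    """Scalar form of the element-wise test applied by torch.allclose."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Large values are dominated by the relative term (rtol * |expected|).
large_ok = is_close(100.5, 100.0)   # 0.5 <= 0.01 + 1.0
large_bad = is_close(102.0, 100.0)  # 2.0 >  0.01 + 1.0

# Values near zero are dominated by the absolute term (atol).
small_ok = is_close(0.005, 0.0)     # 0.005 <= 0.01
small_bad = is_close(0.05, 0.0)     # 0.05  >  0.01
```

Because the bound scales with the expected value, the same `rtol` tolerates larger absolute errors on large outputs, while `atol` keeps the check meaningful for outputs close to zero.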
Key Components
Memory Types and Data Types:
Memory[M, K, ADDRESS_SPACE, f16] defines a matrix in memory with dimensions M×K
f16 and f32 specify half- and single-precision floating-point types
Different address spaces (shared and global) are used for optimal memory access
Wave Language Features:
The @tkw.wave() decorator with constraints defines the kernel’s execution parameters
@tkw.iterate creates an iteration loop over the K dimension
Register represents values held in registers during computation
tkw.read and tkw.write handle memory operations
tkw.mma performs matrix multiply-accumulate operations
Constraints:
Workgroup Constraints: Distribute computation across workgroups
- M dimension is distributed with tile size BLOCK_M
- N dimension is distributed with tile size BLOCK_N
Wave Constraints: Enable wave-level parallelism
- M and N dimensions are further split within each workgroup, with wave tile sizes BLOCK_M / 2 and BLOCK_N / 2
Hardware Constraints: Specify GPU-specific parameters
- 64 threads per wave
- 2x2x1 waves per workgroup, which follows from the wave tile sizes above rather than from an explicit setting
- F32_16x16x16_F16 matrix multiply-accumulate operation
Memory Hierarchy:
Input matrices (a, b) use shared memory for fast access, because ADDRESS_SPACE is bound to SHARED_ADDRESS_SPACE at compile time
Output matrix (c) is in global memory
Intermediate results are kept in registers
Computation Flow:
Initialize accumulator register with zeros
Iterate over K dimension to perform reduction
Load tiles from shared memory
Perform matrix multiplication and accumulation
Write final result to global memory
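As a back-of-the-envelope check, the arithmetic below estimates how the constraints split up the test problem (m, n, k = 128, 256, 128 with BLOCK_M=64, BLOCK_N=64, BLOCK_K=32). It is a sketch under stated assumptions, not output from the Wave compiler: `launch_geometry` is a hypothetical helper, and the MMA count assumes one instruction per 16x16x16 sub-tile with no extra scheduling by the compiler.

```python
import math


def launch_geometry(m, n, k, block_m, block_n, block_k, mma_shape=(16, 16, 16)):
    """Estimate how a GEMM problem is divided among workgroups and waves."""
    # WorkgroupConstraint: one workgroup per BLOCK_M x BLOCK_N output tile.
    workgroups = (math.ceil(m / block_m), math.ceil(n / block_n))
    # WaveConstraint: each wave owns a (BLOCK_M / 2) x (BLOCK_N / 2) tile.
    wave_m, wave_n = block_m // 2, block_n // 2
    waves_per_workgroup = (block_m // wave_m, block_n // wave_n)
    # TilingConstraint: the reduction loop steps through K in BLOCK_K chunks.
    k_iterations = math.ceil(k / block_k)
    # Assumption: one MMA instruction per 16x16x16 sub-tile of the wave tile.
    mma_m, mma_n, mma_k = mma_shape
    mmas_per_wave_per_iteration = (
        (wave_m // mma_m) * (wave_n // mma_n) * (block_k // mma_k)
    )
    return {
        "workgroups": workgroups,
        "waves_per_workgroup": waves_per_workgroup,
        "k_iterations": k_iterations,
        "mmas_per_wave_per_iteration": mmas_per_wave_per_iteration,
    }


geometry = launch_geometry(128, 256, 128, 64, 64, 32)
```

With these sizes the estimate is a 2x4 grid of workgroups, 2x2 waves in each, four reduction steps over K, and eight MMA operations per wave per step.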
Performance Considerations
Tile Size Selection:
Choose tile sizes that maximize memory locality
Consider hardware constraints (shared memory size, register file size)
Balance between parallelism and resource usage
Example values: BLOCK_M=64, BLOCK_N=64, BLOCK_K=32
Memory Access Patterns:
Use shared memory for frequently accessed data (input matrices)
Minimize bank conflicts in shared memory
Align memory accesses for better coalescing
Consider mixed precision (f16 inputs, f32 accumulation)
Wave Organization:
Distribute work evenly across waves
Use hardware-specific wave sizes (64 threads per wave)
Optimize for the target GPU architecture
Consider wave-level parallelism for both M and N dimensions
Testing and Validation:
Use small test cases for initial verification
Compare against PyTorch’s implementation
Account for floating-point precision differences
Use appropriate error tolerances (rtol=1e-2, atol=1e-2)
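For the tile size guidance above, a rough shared memory estimate helps rule out configurations early. The sketch below counts only the raw A and B tiles in f16 (2 bytes per element) and ignores padding, double buffering, and anything else the compiler may allocate. `shared_tile_bytes` and `fits_budget` are hypothetical helpers, and the 64 KiB budget is an assumed example value; check the limit for your target GPU.

```python
def shared_tile_bytes(block_m, block_n, block_k, bytes_per_element=2):
    """Bytes needed to hold one A tile and one B tile (f16 by default)."""
    a_tile = block_m * block_k * bytes_per_element  # BLOCK_M x BLOCK_K
    b_tile = block_n * block_k * bytes_per_element  # BLOCK_N x BLOCK_K
    return a_tile + b_tile


def fits_budget(block_m, block_n, block_k, budget_bytes):
    """True if the raw tiles fit within the given shared memory budget."""
    return shared_tile_bytes(block_m, block_n, block_k) <= budget_bytes


ASSUMED_BUDGET = 64 * 1024  # Assumption: example budget, not a Wave constant

tutorial_bytes = shared_tile_bytes(64, 64, 32)    # tile sizes from the test
larger_bytes = shared_tile_bytes(256, 256, 128)   # a much larger candidate
tutorial_fits = fits_budget(64, 64, 32, ASSUMED_BUDGET)
larger_fits = fits_budget(256, 256, 128, ASSUMED_BUDGET)
```

The tutorial's tile sizes need 8192 bytes for the two input tiles, well under the assumed budget, while the larger candidate needs 131072 bytes and would not fit.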
For more advanced optimizations and techniques, see the Wave System Architecture documentation.