Fused Softmax Tutorial

This tutorial demonstrates how to implement Fused Softmax using Wave.

Softmax Function

The softmax function is defined as:

\[\mathrm{softmax}(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}\]

Here, \(\mathbf{z}\) is a vector where each value \(z_i\) is the output of a neural network. The softmax function converts these outputs into probabilities by exponentiating each \(z_i\) and dividing by the sum of all exponentiated outputs.

This function is typically used as the final step of a classification neural network to represent the probability distribution over predicted classes.

Note: In the input matrix, each row is one sample and each column is one class, so softmax is applied independently to each row.
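As a minimal sketch of the definition above, the following plain-Python function transcribes the formula directly. It runs on the CPU with ordinary lists rather than GPU tensors, and the name `naive_softmax` is illustrative only; it is not part of Wave.

```python
import math

def naive_softmax(z):
    """Direct transcription of the formula: exp(z_i) / sum_j exp(z_j)."""
    exps = [math.exp(v) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# One row = one sample; each column = one class score.
logits = [[1.0, 2.0, 3.0],
          [0.0, 0.0, 0.0]]
probs = [naive_softmax(row) for row in logits]
```

Each output row sums to 1, and equal inputs (the second row) yield a uniform distribution.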

Fused Softmax

A fused softmax implementation combines the exponentiation, summation, and normalization into a single kernel. This avoids writing intermediate results to memory and reading them back between steps, which reduces memory traffic and latency.
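The difference can be sketched with a CPU analogy, assuming plain Python lists stand in for GPU buffers. In the unfused version, every stage materializes a full intermediate result, much as separate GPU kernels would each write to and re-read from global memory. In the fused version, the intermediates for a row live only inside one loop iteration. Both function names are hypothetical.

```python
import math

def unfused_softmax(matrix):
    # Stage 1: exponentiate everything and store a full-size buffer.
    exps = [[math.exp(v) for v in row] for row in matrix]
    # Stage 2: re-read that buffer to build a buffer of per-row sums.
    sums = [sum(row) for row in exps]
    # Stage 3: re-read both buffers to normalize.
    return [[e / s for e in row] for row, s in zip(exps, sums)]

def fused_softmax(matrix):
    out = []
    for row in matrix:
        # All intermediates for this row stay local to the iteration.
        exps = [math.exp(v) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

data = [[0.5, 1.5, -1.0], [2.0, 2.0, 2.0]]
unfused = unfused_softmax(data)
fused = fused_softmax(data)
```

The two functions compute the same values; only the lifetime of the intermediate data differs.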

The function:

test_fused_softmax(...)

is defined in wave_e2e_test.py. It loads the softmax kernel with the correct constraints and compares the result with PyTorch’s built-in softmax function. Input shapes are defined in test_param.json.

Key GPU Programming Terms in Wave

  • Thread: The smallest unit of execution on a GPU.

  • Wave: A group of threads (typically 32 or 64) that processes a tile. Equivalent to a CUDA warp.

  • Workgroup: A group of waves. The grid is divided into workgroups, each of which may include one or more waves.
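To make the hierarchy concrete, here is a small sketch that splits a flat thread index within a workgroup into a wave index and a lane (the thread's position inside its wave). It assumes threads are numbered linearly and grouped into waves in order, which is the common convention but is an assumption here; `locate_thread` is a hypothetical helper, not a Wave API.

```python
WAVE_SIZE = 64  # threads per wave, matching wave_size in this tutorial

def locate_thread(thread_id, wave_size=WAVE_SIZE):
    """Split a flat thread index within a workgroup into (wave index, lane)."""
    return thread_id // wave_size, thread_id % wave_size

# A hypothetical workgroup of 4 waves holds 4 * 64 = 256 threads.
first = locate_thread(0)
last_of_wave0 = locate_thread(63)
first_of_wave1 = locate_thread(64)
last = locate_thread(255)
```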

Constraint Creation

M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

wave_size = 64
BLOCK_M = 1

constraints: list[tkw.Constraint] = [
    tkw.HardwareConstraint(
        threads_per_wave=wave_size,
        vector_shapes={M: BLOCK_M, N: N},
    )
]

constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]

Explanation:

  • \(M\) and \(N\) are symbolic dimensions representing the number of rows and columns of the grid.

  • HardwareConstraint specifies:
      - 64 threads per wave
      - 1 wave per block in each dimension
      - each wave handles a tile of shape [1, N] (that is, 1 row)

  • WorkgroupConstraint(M, BLOCK_M, 1) partitions the M dimension (the rows) across workgroups in tiles of BLOCK_M rows. The final argument, 1, maps M to workgroup dimension 1, the Y-axis.

  • WaveConstraint(M, BLOCK_M) further specifies that each wave within a workgroup covers a tile of BLOCK_M rows along M, which here is 1 row.

Note that here each workgroup contains one wave and both operate on the same partition of data, but that will not always be the case.
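The partitioning these constraints describe can be sketched with a little arithmetic. The helpers below are hypothetical and assume conventional block partitioning (ceiling division, contiguous row ranges); how Wave actually derives its launch grid is not shown in this tutorial.

```python
def workgroup_grid(m, block_m):
    """Workgroup counts along (X, Y, Z) when M is tiled along dimension 1 (Y)."""
    num_y = -(-m // block_m)  # ceiling division
    return (1, num_y, 1)

def rows_owned(workgroup_y, m, block_m):
    """Row indices covered by the workgroup at the given Y coordinate."""
    start = workgroup_y * block_m
    return range(start, min(start + block_m, m))

# With BLOCK_M = 1 as in this tutorial, and an assumed M of 256 rows:
grid = workgroup_grid(m=256, block_m=1)
rows = list(rows_owned(workgroup_y=5, m=256, block_m=1))
```

Under these assumptions there is one workgroup per row, and the workgroup at Y coordinate 5 owns exactly row 5.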

Kernel Creation

Now that constraints are defined, we implement the kernel to match them.

Step 1: Define the kernel

@tkw.wave(constraints)
def test(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32]
):

Each wave executes this function independently on its designated tile.

Note

Note 1: Kernels written in Wave are generally executed at the wave level, meaning each wave processes its own tile independently. However, Wave provides certain functionalities that allow communication between waves within the same workgroup when needed.

Note

Note 2: The a and b tile inputs are typed as [M, N], but inside the kernel these M and N symbols represent the dimensions of the tile assigned to each wave, not the dimensions of the full input grid (even though the symbols share the same names). For example, suppose the original input grid has M = 256 and you set a WaveConstraint with BLOCK_M = 64. Each wave then receives a tile with M = 64, so the value of M seen by the kernel is 64, not the full grid size of 256.

In this particular case, the N dimension remains unchanged between the grid and the tile, meaning each wave processes all columns.
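The tile shapes in this example can be checked with a short sketch. `tile_shapes` is a hypothetical helper, and N = 128 is an assumed value chosen only for illustration; the tutorial does not fix N.

```python
def tile_shapes(m, n, block_m):
    """Shape (rows, cols) of each wave's tile when only M is partitioned."""
    return [(min(block_m, m - start), n) for start in range(0, m, block_m)]

# The example from Note 2: M = 256 split into tiles of BLOCK_M = 64 rows.
shapes = tile_shapes(m=256, n=128, block_m=64)

# The configuration used in this tutorial: BLOCK_M = 1, one row per wave.
row_shapes = tile_shapes(m=256, n=128, block_m=1)
```

In both cases the column count of every tile equals the full N, which is why each wave can reduce over an entire row on its own.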

Step 2: Write the kernel body

val = tkw.read(a)
row_max = tkw.max(val, dim=N)
row_max_bcast = tkw.broadcast(row_max, [M, N])
val -= row_max_bcast
val = tkw.exp(val)
denominator = tkw.sum(val, dim=N)
denom_broadcast = tkw.broadcast(denominator, [M, N])
val = val / denom_broadcast
tkw.write(val, b)

Explanation:

  • read: loads a row from memory.

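  • max: computes the maximum value across the row, which is then subtracted from every element in that row. This step is not part of the softmax definition. It is added for numerical stability: it keeps every exponent at or below zero, so exp cannot overflow, and it leaves the result unchanged.

  • broadcast: replicates the per-row max or sum across the row so that it can be combined elementwise with val.

  • exp: applies exponentiation.

  • sum: computes the denominator by summing the exponentials across the row.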

  • write: stores the result to output buffer.
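That subtracting the row maximum does not change the result follows from a short derivation. Writing \(m = \max_j z_j\) (the symbol \(m\) is introduced here for convenience and does not appear in the kernel):

\[\frac{e^{z_i - m}}{\sum_{j=1}^{N} e^{z_j - m}} = \frac{e^{-m}\, e^{z_i}}{e^{-m} \sum_{j=1}^{N} e^{z_j}} = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}} = \mathrm{softmax}(\mathbf{z})_i\]

Since \(z_i - m \le 0\) for every \(i\), each exponential lies in \((0, 1]\), and the denominator is at least 1 because the maximum element contributes \(e^{0} = 1\).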

Each wave performs softmax on its assigned row.
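For comparison, the same sequence of steps can be written as a plain-Python reference for a single row. This is only a CPU sketch to illustrate the order of operations and the effect of the max subtraction; `softmax_row` is a hypothetical name, and the comments indicate which kernel line each step mirrors.

```python
import math

def softmax_row(row):
    row_max = max(row)                      # tkw.max(val, dim=N)
    shifted = [v - row_max for v in row]    # val -= row_max_bcast
    exps = [math.exp(v) for v in shifted]   # tkw.exp(val)
    denominator = sum(exps)                 # tkw.sum(val, dim=N)
    return [e / denominator for e in exps]  # val / denom_broadcast

# Large inputs: math.exp(1000.0) alone would raise OverflowError,
# but after the shift the largest exponent is 0.
stable = softmax_row([1000.0, 1001.0, 1002.0])

# The same row shifted down by 1000 gives the same probabilities.
small = softmax_row([0.0, 1.0, 2.0])
```

Without the shift, exponentiating the first row directly would overflow, which is the failure the max step guards against.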