Wave Ops

class wave_lang.kernel.ops.wave_ops.Abs(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'abs'
class wave_lang.kernel.ops.wave_ops.Add(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'add'
class wave_lang.kernel.ops.wave_ops.Allocate(shape: tuple[~sympy.core.expr.Expr], distributed_shape: tuple[~sympy.core.expr.Expr], dtype: ~wave_lang.kernel._support.dtype.DataType, address_space: ~wave_lang.kernel.lang.kernel_buffer.AddressSpace = $SHARED_ADDRESS_SPACE, padding: int = 0, parent: ~torch.fx.node.Node | None = None, offset: ~sympy.core.expr.Expr | None = None, tail_padding: int = 0)

Represents an allocation in an address space (such as shared memory).

address_space: AddressSpace = $SHARED_ADDRESS_SPACE
property allocation_size: Expr

Returns the full size of the allocation in bytes including all padding.

distributed_shape: tuple[Expr]
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

offset: Expr | None = None
padding: int = 0
parent: Node | None = None
shape: tuple[Expr]
tail_padding: int = 0
tkw_op_name: str = 'allocate'
property type: Memory
property unpadded_dims: dict[Symbol, Expr]
property unpadded_shape: tuple[Expr]
class wave_lang.kernel.ops.wave_ops.And(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'and_'
class wave_lang.kernel.ops.wave_ops.ApplyExpr(register_: 'fx.Proxy | Sequence[fx.Proxy]', expr: 'Callable')
expr: Callable
property indexing_dims: list[Symbol]
register_: Proxy | Sequence[Proxy]
tkw_op_name: str = 'apply_expr'
property type: Register
class wave_lang.kernel.ops.wave_ops.Atan2(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'atan2'
class wave_lang.kernel.ops.wave_ops.AtomicAddOp(lhs: 'Any', rhs: 'Any', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = ())
tkw_op_name: str = 'atomic_add'
class wave_lang.kernel.ops.wave_ops.AtomicOp(lhs: Any, rhs: Any, elements_per_thread: Any | None = None, mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = ())

Represents an atomic operation in the graph. Takes a Register and a Memory as inputs and writes the modified value back to the buffer. The mapping attribute maps the index from the wave kernel to the shared memory index the wavegroup operates on.

elements_per_thread: Any | None = None
property has_side_effects: bool
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
property memory_type: Memory
tkw_op_name: str = 'atomic_min'
class wave_lang.kernel.ops.wave_ops.BinaryOpBase(lhs: Any, rhs: Any)

Represents an elementwise binary python operator.

DTYPE requirement: lhs and rhs need to have the same dtype.

Shape requirement: lhs and rhs must either have the same shape or have shapes that are broadcastable to one another.
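The shape requirement can be pictured with a small stand-alone check. This is only a sketch: it assumes that "broadcastable" means one operand's set of symbolic dimensions contains the other's, which the text above does not spell out. The helper name broadcast_shape and the use of strings for dimensions are hypothetical.

```python
def broadcast_shape(lhs, rhs):
    """Return the wider symbolic shape, or raise if neither contains the other.

    Assumption: shapes are tuples of dimension names, and two shapes are
    broadcastable when one set of dimensions is a superset of the other.
    """
    lhs_dims, rhs_dims = set(lhs), set(rhs)
    if lhs_dims >= rhs_dims:
        return tuple(lhs)
    if rhs_dims >= lhs_dims:
        return tuple(rhs)
    raise ValueError(f"shapes {lhs} and {rhs} are not broadcastable")


same = broadcast_shape(("M", "N"), ("M", "N"))  # identical shapes
wider = broadcast_shape(("M",), ("M", "N"))     # rhs is the wider shape
```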

property indexing_dims: list[Symbol]
infer_shape() Any
lhs: Any
property py_operator: str
rhs: Any
class wave_lang.kernel.ops.wave_ops.BinaryPyOp(lhs: 'Any', rhs: 'Any')
infer_type(*args)

Infer the type of this operator using the types of its arguments.

class wave_lang.kernel.ops.wave_ops.BitcastOp(arg: Node, dtype: DataType)

Represents a bitcast operation.

arg: Node
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property scale_factor
tkw_op_name: str = 'bitcast'
class wave_lang.kernel.ops.wave_ops.BoundsCheck(index_exprs: dict[Expr, IndexSequence], bounds: dict[Symbol, Expr], mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = (), mask_bounds: dict[Symbol, Expr] | None = None)

Checks if a dimension is within a bound.

bounds: dict[Symbol, Expr]
property has_side_effects: bool
index_exprs: dict[Expr, IndexSequence]
mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
mask_bounds: dict[Symbol, Expr] | None = None
tkw_op_name: str = 'bounds_check'
class wave_lang.kernel.ops.wave_ops.Broadcast(arg: Node, target_shape: Sequence[Symbol] = None)

Represents a Broadcast operation.

Parameters:
  • arg – Source tensor/value to broadcast.

  • target_shape – Symbolic target broadcast shape.

arg: Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

target_shape: Sequence[Symbol] = None
tkw_op_name: str = 'broadcast'
class wave_lang.kernel.ops.wave_ops.CastOp(arg: Node, dtype: DataType)

Represents a cast operation.

arg: Node
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'cast'
class wave_lang.kernel.ops.wave_ops.Cbrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'cbrt'
class wave_lang.kernel.ops.wave_ops.ComparisonPyOp(lhs: 'Any', rhs: 'Any')
infer_type(*args)

Infer the type of this operator using the types of its arguments.

class wave_lang.kernel.ops.wave_ops.Conditional(condition: Proxy | Expr, subgraph_name: str, implicit_captures: Sequence[Proxy], else_return: Sequence[Register] | None = None)

The optional else_return argument must match the type of any return statements in the conditional body. It provides the default return values for the else branch.

condition: Proxy | Expr
else_return: Sequence[Register] | None = None
implicit_captures: Sequence[Proxy]
property indexing_dims: list[Symbol] | list[list[Symbol]]
property init_args

Alias for else_return to match Iterate interface for shared processing code.

subgraph_name: str
tkw_op_name: str = 'conditional'
class wave_lang.kernel.ops.wave_ops.Cos(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'cos'
class wave_lang.kernel.ops.wave_ops.Cumsum(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘Optional[fx.Node]’ = None, dim: ‘Optional[IndexSymbol]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'cumsum'
class wave_lang.kernel.ops.wave_ops.CustomOp

Base class for all custom fx nodes.

Fields with compare=False are infrastructure or scheduling artifacts and do not participate in semantic equality used for trace equivalence.
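The effect of compare=False can be seen with a plain Python dataclass. The class ToyOp and its fields below are hypothetical stand-ins, not the real CustomOp definition; the sketch only shows how the standard dataclasses machinery leaves such fields out of equality.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    # Semantic fields: these take part in the generated __eq__.
    lhs: str
    rhs: str
    # Bookkeeping field (hypothetical): excluded from the generated __eq__.
    fx_node: object = field(default=None, compare=False)


a = ToyOp("x", "y", fx_node="node_1")
b = ToyOp("x", "y", fx_node="node_2")  # differs only in the bookkeeping field
c = ToyOp("x", "z", fx_node="node_1")  # differs in a semantic field
```

Here a == b holds even though the fx_node values differ, while a != c because a semantic field differs.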

add_to_graph(region_graph: RegionGraph, type: Any = None, loc: FileLineColInfo | StackTraceInfo | None = None, tag: str | None = None) Node
copy(new_name: str | None = None, new_graph: ~torch.fx.graph.Graph | None = None, arg_transform: ~typing.Callable[[~typing.Any], ~typing.Any] | None = <function CustomOp.<lambda>>, anchor: ~torch.fx.node.Node | None = None) Self

Returns a duplicate of this node.

copy_core_attributes(new_node: Node)

Copy core attributes from the current node to the new node.

classmethod create(graph: Graph, *args, type: Any = None, extra_attrs: dict[str, Any] | None = None, **kwargs) CustomOpT

Create a CustomOp instance and its FX node together.

This classmethod properly instantiates the CustomOp with its semantic fields, creates the FX node, and links them together. Useful when building FX graphs programmatically (e.g., from MLIR).

Parameters:
  • graph – FX graph to add the node to.

  • *args – Positional arguments for the dataclass constructor.

  • type – Optional type to set on the node.

  • extra_attrs – Additional attributes to set on the fx_node after creation (e.g., index, vector_shapes). These are not passed to the dataclass constructor.

  • **kwargs – Keyword arguments forwarded to the dataclass constructor.

Returns:

The created CustomOp instance with fx_node and graph fields populated.

Example

    register = NewRegister.create(
        graph, dims, dtype, init_value,
        type=Register[(M, N, f32)],
        extra_attrs={"vector_shapes": {M: 16, N: 16}},
    )

custom_string(value_map: dict[str, str]) str
erase()

Erase the current node from the graph where it exists.

property expanded_dims: dict[Symbol, int]

During expansion each node is expanded along its indexing dimensions. The expanded_dims property stores the dimensions along which the node has been expanded as well as the scaling along that dimension.

For example, a node with indexing dimensions [M, N] with dimensional scaling {M: 2, N: 2}, will be expanded to 4 nodes, with each expanded node mapping to the following expanded_dims {M: 0, N: 0}, {M: 0, N: 1}, {M: 1, N: 0}, {M: 1, N: 1}.
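The enumeration in the example above can be reproduced with a short stand-alone sketch. The function name enumerate_expanded_dims is hypothetical, and plain strings stand in for the symbolic dimensions.

```python
from itertools import product


def enumerate_expanded_dims(scaling):
    """List one {dim: position} mapping per expanded node.

    `scaling` maps each indexing dimension to its scaling factor.
    """
    dims = list(scaling)
    ranges = [range(scaling[d]) for d in dims]
    return [dict(zip(dims, combo)) for combo in product(*ranges)]


expanded = enumerate_expanded_dims({"M": 2, "N": 2})
```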

classmethod from_fx_node(node: Node) CustomOpT
fx_node: Node | None = None
get_node_arg_index(arg: CustomOp) CustomOp | list[CustomOp] | None
graph: Graph | None = None
classmethod handle(graph: RegionGraph, *args, **kwargs) Node
property has_side_effects: bool
property index: dict[Symbol, IndexSequence] | None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property location: FileLineColInfo | StackTraceInfo | None
property name: str
property node_args: dict[int, Any]

Returns the args to this custom op using subclasses of CustomOp if possible.

property pre_expansion_id: int
replace_all_uses_with(new_node: CustomOp | Node)

Replace all uses of the current node with the new node.

replace_all_uses_with_except(new_node: CustomOp | Node, except_nodes: list[CustomOp])

Replace all uses of the current node with the new node except for the nodes in except_nodes.

replace_uses_with(new_node: CustomOp | Node, *, graph: Graph | None = None, propagate_location: bool = True) None

Replace uses of this node, optionally restricted to one graph.

When graph is None, this replaces every use of the current node. When graph is provided, only uses in that specific fx.Graph are replaced.

This matters for nested regions: the same FX node may be referenced from multiple graphs at the same time, for example from a nested subgraph and from sibling or outer graphs. A graph-scoped replacement therefore does not imply that the current node becomes globally dead. After the call, the node may still have uses in other graphs that were intentionally left untouched.
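The difference between a global and a graph-scoped replacement can be modelled without FX at all. The sketch below is a toy: graphs are dictionaries from user name to operand list, and replace_uses is a hypothetical helper, not the method documented here.

```python
def replace_uses(users_by_graph, old, new, graph=None):
    """Rewrite operands equal to `old`; return how many uses were left untouched."""
    remaining = 0
    for name, users in users_by_graph.items():
        for operands in users.values():
            for i, operand in enumerate(operands):
                if operand != old:
                    continue
                if graph is None or name == graph:
                    operands[i] = new
                else:
                    remaining += 1  # use lives in a graph outside the scope
    return remaining


users_by_graph = {
    "outer": {"add": ["a", "b"]},
    "loop_body": {"mul": ["a", "c"]},
}
# Replace "a" only inside the nested graph; the outer use is kept.
left = replace_uses(users_by_graph, "a", "a2", graph="loop_body")
```

After the scoped call the old value still has one use in the outer graph, so it is not dead.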

replacement_location_propagate(new_node: CustomOp | Node | list[Node])

Set the new_node location if it doesn’t have one.

property rrt
property scheduling_parameters
property tag: set[str] | None
tkw_op_name: str = 'unknown'
transform_index(index: dict[Symbol, IndexSequence]) dict[Symbol, IndexSequence]

Transform the index of the node based on the provided mapping.

transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]

Transform the index of the node when propagating index backwards, i.e. from node to its arguments.

property type: Any
property unroll_iteration: int | None

Returns the unroll iteration index for this node, or None if not unrolled.

update_arg(idx_or_name: int | str | Node, value: CustomOp | Node)

Update an operand or named field while keeping the underlying fx.Node consistent for both positional arguments and keyword arguments.

property users: list[Any]

Returns the users of this custom op using subclasses of CustomOp if possible.

property vector_shapes: dict[Symbol, int]
class wave_lang.kernel.ops.wave_ops.DebugLog(register_: ~torch.fx.proxy.Proxy, label: str | None = None, extra_iteration_dimensions: list[tuple[~sympy.core.symbol.Symbol, ~sympy.core.symbol.Symbol, int]] | None = None, mapping: ~wave_lang.kernel.lang.wave_types.IndexMapping | None = None, mapping_dynamic_vals: tuple[~torch.fx.node.Node, ...] = (), printer: ~typing.Callable[[str, ~typing.Any], ~typing.Any] | None = <built-in function print>, handler: ~typing.Callable[[dict[str, ~typing.Any]], ~typing.Any] | None = None)

An op for debugging. Represents a write to an implicit global memory location. The kernel implicitly gains an extra memory input that is injected by the Python kernel launcher. The memory can be accessed by passing an extra keyword argument, debug_logs, with an empty dictionary to the kernel invocation. The dictionary will be mutated, adding the labels as keys that map to nested dictionaries. Each nested dictionary has a value field with the log tensor, plus other keys (e.g. symbolic_shape).

That is, the debug_logs dictionary will look like: {LABEL: {"value": LOG_TENSOR, (other metadata keys) …}}
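A dictionary of that shape can be walked like any nested mapping. In the sketch below the labels, the nested lists standing in for log tensors, and the helper summarize are all made up for illustration; only the "value" and symbolic_shape keys come from the description above.

```python
# Hypothetical contents after a kernel run; nested lists stand in for tensors.
debug_logs = {
    "acc": {"value": [[1.0, 2.0], [3.0, 4.0]], "symbolic_shape": ("M", "N")},
    "row_max": {"value": [2.0, 4.0], "symbolic_shape": ("M",)},
}


def summarize(logs):
    """Map each label to the symbolic shape recorded next to its value."""
    return {label: entry.get("symbolic_shape") for label, entry in logs.items()}


summary = summarize(debug_logs)
first_acc_row = debug_logs["acc"]["value"][0]
```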

Note that the logs collected in the debug_logs field, or handled by printer or handler represent a global view of the log after all writes, not limited to any one wave or loop iteration.

The optional printer argument should be a function that accepts a string (the log’s label) and the value of the log itself (a Torch tensor). The default value for this is print, though note that it will probably print an abbreviated view of the global tensor.

The optional handler argument should be a function that accepts the whole debug_logs object (i.e. all logs, not just one). The handler function gives a way to specify something like a viewer for all logs, but to specify it inline among print functions rather than separately.

The optional extra_iteration_dimensions argument allows you to add extra dimensions to capture values from multiple iterations of a loop. It takes a list of tuples, where each tuple contains (dimension_name, iteration_axis, max_iterations). The dimension_name must be a unique symbol, and will be the name of the dimension in the symbolic shape of the output. The iteration_axis must be the axis of an Iterate operation, i.e. the dimension being reduced in the iteration. The max_iterations argument is a positive integer that represents the maximum number of iterations to store. Note that if max_iterations is too small, later iterations will currently overwrite the final iteration slot.
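The overwrite behaviour noted at the end of the paragraph can be pictured as clamping the slot index. This is one reading of that note, not the actual lowering; record_iterations is a hypothetical helper.

```python
def record_iterations(values, max_iterations):
    """Store one value per iteration, clamping late iterations to the last slot."""
    slots = [None] * max_iterations
    for iteration, value in enumerate(values):
        # Assumption: iterations past the capacity land in the final slot.
        slots[min(iteration, max_iterations - 1)] = value
    return slots


enough_room = record_iterations([10, 20, 30], max_iterations=3)
too_few = record_iterations([10, 20, 30, 40], max_iterations=2)
```

With too few slots, only the first iteration and the most recent one survive in this model.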

The API and semantics of this operation are not yet stable, but since it is just a debugging tool, you will want to remove any debug logging from your kernel before shipping it anyway.

extra_iteration_dimensions: list[tuple[Symbol, Symbol, int]] | None = None
handler: Callable[[dict[str, Any]], Any] | None = None
property has_side_effects: bool
infer_type(*args)

Infer the type of this operator using the types of its arguments.

label: str | None = None
mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
property memory: Proxy | None
printer: Callable[[str, Any], Any] | None = print

register_: Proxy
tkw_op_name: str = 'debug_log'
class wave_lang.kernel.ops.wave_ops.Eq(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'eq'
class wave_lang.kernel.ops.wave_ops.Exp(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'exp'
class wave_lang.kernel.ops.wave_ops.Exp2(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'exp2'
class wave_lang.kernel.ops.wave_ops.Extract(register_: Proxy, offset: Expr | int)

Op Rationale:

Extract represents extracting a scalar from TKW's 1-D vector at the specified index.

This can also be viewed as indexing/slicing on the fastest dimension. Hence, the semantics of this op are designed to treat it as a reduction on the indexed/fastest dimension.

infer_type(*args)

Infer the type of this operator using the types of its arguments.

offset: Expr | int
register_: Proxy
tkw_op_name: str = 'extract'
class wave_lang.kernel.ops.wave_ops.ExtractSlice(register_: 'fx.Proxy', offset: 'tuple[IndexExpr]', size: 'tuple[IndexExpr]', stride: 'tuple[IndexExpr]')
offset: tuple[Expr]
property rank: int
register_: Proxy
size: tuple[Expr]
stride: tuple[Expr]
tkw_op_name: str = 'extract_slice'
property type: Register
class wave_lang.kernel.ops.wave_ops.Floordiv(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'floordiv'
class wave_lang.kernel.ops.wave_ops.GatherToLDS(src: Memory, dst: Memory, src_index: dict[Symbol, IndexSequence], dst_index: dict[Symbol, IndexSequence], dtype: DataType, elements_per_thread: Expr | int | None, src_mapping: IndexMapping | None, dst_mapping: IndexMapping | None, src_bounds: dict[Symbol, Expr] | None, src_mapping_dynamic_vals: tuple[Node, ...] = (), dst_mapping_dynamic_vals: tuple[Node, ...] = ())

Represents an instruction that performs a direct load from global memory to LDS. The source node points to the global memory to load from, and the destination node points to shared memory.

dst: Memory
dst_index: dict[Symbol, IndexSequence]
dst_mapping: IndexMapping | None
dst_mapping_dynamic_vals: tuple[Node, ...] = ()
dtype: DataType
elements_per_thread: Expr | int | None
property has_side_effects: bool
src: Memory
src_bounds: dict[Symbol, Expr] | None
src_index: dict[Symbol, IndexSequence]
src_mapping: IndexMapping | None
src_mapping_dynamic_vals: tuple[Node, ...] = ()
tkw_op_name: str = 'gather_to_lds'
class wave_lang.kernel.ops.wave_ops.Ge(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'ge'
class wave_lang.kernel.ops.wave_ops.GetResult(value: 'fx.Node', res_idx: 'int')
property distributed_shape
property index: dict[Symbol, IndexSequence]
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

res_idx: int
tkw_op_name: str = 'get_result'
value: Node
class wave_lang.kernel.ops.wave_ops.Gt(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'gt'
class wave_lang.kernel.ops.wave_ops.Invert(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'invert'
class wave_lang.kernel.ops.wave_ops.IterArg(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)

Represents a specific placeholder node in the graph that is an iter arg of a reduction node. IterArgs can be of type Register or Memory with a Shared memory address space.

property distributed_shape
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property iter_idx
parent_op()
class wave_lang.kernel.ops.wave_ops.Iterate(axis: 'IndexSymbol', init_args: 'Sequence[Any]', subgraph_name: 'str', implicit_captures: 'Sequence[fx.Proxy]', step: 'int' = 1, start: 'Optional[IndexExpr]' = None, condition: 'Optional[IndexExpr]' = None)
axis: Symbol
condition: Expr | None = None
property count: int | None
implicit_captures: Sequence[Proxy]
property index: list[dict[Symbol, IndexSequence] | None]

Collect indices from the subgraph’s output return values.

Always returns a list with one entry per return value. Uses getattr with a None default so this is safe to call on partially-populated graphs where indices have not been propagated yet.
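The getattr pattern described above can be sketched with simple namespace objects. The helper collect_indices and the stand-in return values are hypothetical; the point is only that a missing index attribute yields None instead of raising.

```python
from types import SimpleNamespace


def collect_indices(return_values):
    """Return one entry per value; None where no index has been set yet."""
    return [getattr(value, "index", None) for value in return_values]


propagated = SimpleNamespace(index={"M": 0})  # index already propagated
pending = SimpleNamespace()                   # no index attribute yet
indices = collect_indices([propagated, pending])
```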

property indexing_dims: list[Symbol] | list[list[Symbol]]
init_args: Sequence[Any]
start: Expr | None = None
step: int = 1
subgraph_name: str
tkw_op_name: str = 'iterate'
class wave_lang.kernel.ops.wave_ops.Le(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'le'
class wave_lang.kernel.ops.wave_ops.Log10(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'log10'
class wave_lang.kernel.ops.wave_ops.Log2(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'log2'
class wave_lang.kernel.ops.wave_ops.Lt(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'lt'
class wave_lang.kernel.ops.wave_ops.MMA(lhs: 'fx.Node', rhs: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['MMAType'] | 'GenericDot'" = None)
acc: fx.Node
property acc_index: dict[Symbol, IndexSequence]
property acc_type: Memory
custom_string(value_map: dict[str, str]) str
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

lhs: fx.Node
property lhs_index: dict[Symbol, IndexSequence]
property lhs_type: Memory
mma_type: 'MMAType' | None | 'GenericDot' = None
operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) dict[Symbol, IndexSequence]
property reduction_dim: Symbol
rhs: fx.Node
property rhs_index: dict[Symbol, IndexSequence]
property rhs_type: Memory
tkw_op_name: str = 'mma'
class wave_lang.kernel.ops.wave_ops.MMABase
class wave_lang.kernel.ops.wave_ops.Max(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'max'
class wave_lang.kernel.ops.wave_ops.Maximum(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'maximum'
class wave_lang.kernel.ops.wave_ops.MemoryAccessFlags(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Flags for memory access operations (read/write). Maps to LLVM load/store attributes.

e.g. flags = MemoryAccessFlags.VOLATILE | MemoryAccessFlags.NONTEMPORAL

NONE = 0
NONTEMPORAL = 2
VOLATILE = 1
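Combining and testing these flags works like any Python flag enum. The sketch below defines a local stand-in, ToyAccessFlags, with the same member values, so it runs without wave_lang; that the real class behaves like enum.Flag is an assumption.

```python
from enum import Flag


class ToyAccessFlags(Flag):
    # Local stand-in mirroring the member values listed above.
    NONE = 0
    VOLATILE = 1
    NONTEMPORAL = 2


flags = ToyAccessFlags.VOLATILE | ToyAccessFlags.NONTEMPORAL
is_volatile = ToyAccessFlags.VOLATILE in flags
only_nontemporal = ToyAccessFlags.NONTEMPORAL
```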
class wave_lang.kernel.ops.wave_ops.MemoryCounterWait(load: int | None = None, store: int | None = None, ds: int | None = None, exp: int | None = None)

Wait for the specified counters to be less-than or equal-to the provided values before continuing.

Emits: amdgpu.memory_counter_wait with specified counters

ds: int | None = None
exp: int | None = None
property has_side_effects: bool
load: int | None = None
store: int | None = None
tkw_op_name: str = 'memory_counter_wait'
class wave_lang.kernel.ops.wave_ops.MemoryCounterWaitBarrier(load: int | None = None, store: int | None = None, ds: int | None = None, exp: int | None = None)

Wait for the specified counters to be less-than or equal-to the provided values before continuing, then perform a workgroup barrier.

Emits:
  • amdgpu.memory_counter_wait with specified counters

  • rocdl.s.barrier for workgroup synchronization

ds: int | None = None
exp: int | None = None
property has_side_effects: bool
load: int | None = None
store: int | None = None
tkw_op_name: str = 'memory_counter_wait_barrier'
class wave_lang.kernel.ops.wave_ops.Min(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'min'
class wave_lang.kernel.ops.wave_ops.Minimum(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'minimum'
class wave_lang.kernel.ops.wave_ops.Mod(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'mod'
class wave_lang.kernel.ops.wave_ops.Mul(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'mul'
class wave_lang.kernel.ops.wave_ops.Ne(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'ne'
class wave_lang.kernel.ops.wave_ops.Neg(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'neg'
class wave_lang.kernel.ops.wave_ops.NestedRegionOp
static capture_source(node: Node | CustomOp) Node

Return the defining outer value for a local Placeholder node.

captured_vars(graph: Graph) list[Node]

Return local Placeholder nodes that represent captured outer values.

erase()

Erase the current node from the graph where it exists.

get_capture_bindings(graph: Graph, lookup: tuple[dict[Node, Node], list[Node]] | None = None) list[tuple[Node, Node]]

Return (outer_source, local_region_value) pairs in signature order.

get_captured_fx_node(graph: Graph, outer_node: Node, lookup: tuple[dict[Node, Node], list[Node]] | None = None) Node | None

Return the local representative for outer_node in graph if it exists.

get_root_graph()

Return the root (outermost) layer of the computation graph. This is done by iteratively following parent_graph from the current graph until reaching the root graph, which is the one that has the "subgraph" attribute.
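The walk to the root can be sketched with plain objects. The attribute names parent_graph and subgraphs below are placeholders chosen for the sketch, and find_root_graph is a hypothetical free function rather than the method itself.

```python
from types import SimpleNamespace


def find_root_graph(graph):
    """Follow parent links until reaching the graph that owns the subgraphs."""
    # Assumption: only the root graph carries a `subgraphs` attribute.
    while not hasattr(graph, "subgraphs"):
        graph = graph.parent_graph
    return graph


root = SimpleNamespace(name="root", subgraphs={})
loop_body = SimpleNamespace(name="loop_body", parent_graph=root)
cond_body = SimpleNamespace(name="cond_body", parent_graph=loop_body)
found = find_root_graph(cond_body)
```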

classmethod handle(graph: RegionGraph, *args, **kwargs)

Base handle method for nested region operations. Extracts tag from kwargs and sets it on the node and underlying fx.Node.

infer_type(*args)

Infer the type of this operator using the types of its arguments.

iter_args(graph: Graph | None = None) list[Node]
classmethod materialize_capture_placeholder(graph: Graph, outer_node: Node | CustomOp, location: FileLineColInfo | StackTraceInfo | None = None) Node

Return a lifted placeholder that represents outer_node in graph.

If one already exists in the leading placeholder prefix, it is returned directly. Otherwise a new one is created.

outputs(graph: Graph | None = None) list[Node]
refresh_captures(graph: Graph, lookup: tuple[dict[Node, Node], list[Node]] | None = None) None

Refresh the capture signature from the current graph contents.

class wave_lang.kernel.ops.wave_ops.NewRegister(shape: 'tuple[IndexExpr, ...]', dtype: 'DataType', value: 'float')
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

shape: tuple[Expr, ...]
tkw_op_name: str = 'register'
value: float
class wave_lang.kernel.ops.wave_ops.NewScalar(value: 'float | IndexExpr', dtype: 'DataType')
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'scalar'
value: float | Expr
class wave_lang.kernel.ops.wave_ops.Or(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'or_'
class wave_lang.kernel.ops.wave_ops.Output(return_vals: Sequence[Any])

Represents an output node in the graph, representing the return value of a traced function.

add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) Node
classmethod from_fx_node(node: Node) CustomOpT
property has_side_effects: bool
return_vals: Sequence[Any]
tkw_op_name: str = 'output'
property yielded_values: list[Any]

Yielded values as a list.

class wave_lang.kernel.ops.wave_ops.Permute(arg: Node, target_shape: Sequence[Expr])

Represents a permute operation that permutes arg into the target shape.

arg: Node
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

target_shape: Sequence[Expr]
tkw_op_name: str = 'permute'
transform_index(index: dict[Symbol, IndexSequence]) dict[Symbol, IndexSequence]

The permute operation swaps the strides of the permuted indices. So say we have a permute operation that swaps [B, M, N] to [M, N, B], then we swap the strides of the dimensions.

class wave_lang.kernel.ops.wave_ops.Placeholder(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)

Represents a placeholder node in the graph, i.e. an input to a function.

add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) Node
custom_string(value_map: dict[str, str]) str
erase()

Erase the current node from the graph where it exists.

classmethod from_fx_node(node: Node) PlaceholderT
get_captured_fx_node() Node | None
property index: list[dict[Symbol, IndexSequence]]
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'placeholder'
class wave_lang.kernel.ops.wave_ops.Powf(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'powf'
class wave_lang.kernel.ops.wave_ops.Read(memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, flags: 'MemoryAccessFlags' = MemoryAccessFlags.NONE, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None, _write_dependency: 'Optional[list[fx.Node]]' = None)
bounds: dict[Symbol, Expr] | None = None
property dtype: DataType
elements_per_thread: Any | None = None
flags: MemoryAccessFlags = 0
get_derived_indices() list[tuple[dict[Symbol, IndexSequence], Node]]
has_identity_mapping() bool

Check if mapping between input memory and output register is identity.

property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

is_contiguous_vec(constraints, target: str) bool

Check if the op can be lowered to contiguous vector ops.

If False, it has to be lowered to a gather.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
memory: Proxy
property memory_type: Memory
source: tuple[Expr] | None = None
target: tuple[Expr] | None = None
tkw_op_name: str = 'read'
transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]

Propagate index backwards.

Dynamic values can potentially have a non-identity mapping, so the index needs to be updated when walking from the node to its dynamic value arguments.

E.g. if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting arg index will be $idx // ELEMS_PER_THREAD.
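The arithmetic in that example can be checked with concrete numbers. The value chosen for ELEMS_PER_THREAD and the helper dynamic_val_index are illustrative only; the real code works on symbolic index expressions rather than integers.

```python
ELEMS_PER_THREAD = 4  # illustrative value, not taken from the source


def dynamic_val_index(idx):
    """Apply the mapping j // ELEMS_PER_THREAD to a concrete index."""
    return idx // ELEMS_PER_THREAD


# Eight consecutive element indices collapse onto two dynamic-value indices.
arg_indices = [dynamic_val_index(i) for i in range(8)]
```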

property write_dependency: Node
class wave_lang.kernel.ops.wave_ops.Reciprocal(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'reciprocal'
class wave_lang.kernel.ops.wave_ops.ReduceOp(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

Represents a Reduce computation.

Parameters:
  • arg – Source tensor/value to reduce.

  • init – Init/accumulator for the reduction.

  • dim – Which symbolic dim to reduce.

  • block – When set to true, reduce across the block; otherwise reduce across the warp.
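How init acts as an accumulator can be sketched with a plain fold. This is a scalar model using max as the combining function; reduce_max is a hypothetical helper and says nothing about how the reduction is distributed across a warp or block.

```python
from functools import reduce


def reduce_max(values, init=None):
    """Fold `values` with max, seeding the fold with `init` when given."""
    if init is None:
        return reduce(max, values)
    return reduce(max, values, init)


without_init = reduce_max([3.0, 7.0, 5.0])
with_init = reduce_max([3.0, 7.0, 5.0], init=9.0)
```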

arg: Node | list[Node]
block: bool | None = False
dim: Any | None = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

init: Node = None
property reduction_dim: Symbol
class wave_lang.kernel.ops.wave_ops.Remf(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'remf'
class wave_lang.kernel.ops.wave_ops.Reshape(args: Node | Sequence[Node], target_vector_shape: dict[Symbol, int], logical_slice: int = 0, num_slices: int = 1)

Represents a reshape operation that reshapes vectors along the same dimension.

Conceptually, this either concatenates multiple vectors into a single vector or extracts slices from the vector. Since this operation appears after graph expansion, it never actually has multiple results: each expanded instance of this operation extracts a single slice.
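The slice-extraction case can be sketched on a Python list. The sketch assumes equal, contiguous slices selected by logical_slice out of num_slices, which the text above implies but does not state; extract_slice is a hypothetical helper.

```python
def extract_slice(vector, logical_slice, num_slices):
    """Return slice number `logical_slice` of `vector` cut into equal parts."""
    size = len(vector) // num_slices  # assumes len(vector) divides evenly
    start = logical_slice * size
    return vector[start:start + size]


vector = list(range(8))
first = extract_slice(vector, logical_slice=0, num_slices=2)
second = extract_slice(vector, logical_slice=1, num_slices=2)
```

Each call models one expanded instance of the op, which produces exactly one slice.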

args: Node | Sequence[Node]
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

logical_slice: int = 0
num_slices: int = 1
target_vector_shape: dict[Symbol, int]
tkw_op_name: str = 'reshape'
class wave_lang.kernel.ops.wave_ops.Round(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'round'
class wave_lang.kernel.ops.wave_ops.Roundeven(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'roundeven'
class wave_lang.kernel.ops.wave_ops.Rsqrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'rsqrt'
class wave_lang.kernel.ops.wave_ops.ScaledMMA(lhs: 'fx.Node', lhs_scale: 'fx.Node', rhs: 'fx.Node', rhs_scale: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['ScaledMMAType']" = None)
acc: fx.Node
property acc_index: dict[Symbol, IndexSequence]
property acc_type: Memory
align_index(constraints: list[Constraint]) None
custom_string(value_map: dict[str, str]) str
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

lhs: fx.Node
property lhs_index: dict[Symbol, IndexSequence]
lhs_scale: fx.Node
property lhs_scale_index: dict[Symbol, IndexSequence]
property lhs_scale_type: Memory
property lhs_type: Memory
mma_type: 'ScaledMMAType' | None = None
operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) dict[Symbol, IndexSequence]
property reduction_dim: Symbol
rhs: fx.Node
property rhs_index: dict[Symbol, IndexSequence]
rhs_scale: fx.Node
property rhs_scale_index: dict[Symbol, IndexSequence]
property rhs_scale_type: Memory
property rhs_type: Memory
tkw_op_name: str = 'scaled_mma'
class wave_lang.kernel.ops.wave_ops.ScanOp(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None, block: bool | None = False)

Base class for all scan-style operations (e.g., cumsum).

Parameters:
  • arg – Source tensor/value to scan.

  • init – Optional initial value.

  • dim – Symbolic dimension along which to scan.

  • block – When set to true, scan across the block; otherwise scan across the warp.

arg: Node | list[Node]
block: bool | None = False
dim: Symbol | None = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

init: Node | None = None
property scan_dim: Symbol
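
To make the scan idea concrete, here is a minimal plain-Python sketch of a cumulative sum along one dimension. It assumes an inclusive scan and assumes that `init` simply seeds the running total; the source does not state either detail, and `cumsum_model` is a hypothetical helper rather than part of Wave.

```python
import operator
from itertools import accumulate


def cumsum_model(values, init=None):
    """Inclusive running sum of `values`, optionally seeded with `init`."""
    if init is None:
        return list(accumulate(values, operator.add))
    # With an initial value, accumulate yields the seed first; drop it so
    # the output has the same length as the input.
    return list(accumulate(values, operator.add, initial=init))[1:]


row = cumsum_model([1, 2, 3, 4])
seeded = cumsum_model([1, 2, 3, 4], init=10)
```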
class wave_lang.kernel.ops.wave_ops.ScatterAdd(register_src: Node, register_idx: Node, dim: Expr, memory: Node, mapping: IndexMapping, elements_per_thread: int | None = 1, bounds: dict[Symbol, Expr] | None = None)

ScatterAdd performs element-wise accumulation from a source register into shared memory (LDS), at locations determined by the index register along a specified dimension.

Limitations:
  • Only intra-workgroup scattering is supported (i.e., within shared memory / LDS), assuming a single wave.

  • Multi-wave execution is not guaranteed to be safe: synchronization issues may occur when threads write to the same index. Further investigation is needed.

  • The operation supports multiple elements per thread, assuming the non-scatter dimension is large enough (i.e., > elements_per_thread).

bounds: dict[Symbol, Expr] | None = None
dim: Expr
elements_per_thread: int | None = 1
property has_side_effects: bool
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mapping: IndexMapping
memory: Node
property memory_type: Memory
register_idx: Node
property register_index: dict[Symbol, IndexSequence]
register_src: Node
property register_type: Register
tkw_op_name: str = 'scatter_add'
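
The accumulation pattern described for ScatterAdd can be sketched on nested Python lists. The indexing convention below (the index value replaces the coordinate along the scatter dimension, as in PyTorch's `scatter_add_`) is an assumption, and `scatter_add_model` is a hypothetical helper that ignores `mapping`, `bounds`, and `elements_per_thread`.

```python
def scatter_add_model(memory, src, idx, dim=0):
    """Accumulate `src` into 2-D `memory` in place (toy model).

    For dim == 0: memory[idx[i][j]][j] += src[i][j]
    For dim == 1: memory[i][idx[i][j]] += src[i][j]
    """
    for i, row in enumerate(src):
        for j, value in enumerate(row):
            if dim == 0:
                memory[idx[i][j]][j] += value
            else:
                memory[i][idx[i][j]] += value
    return memory


mem = [[0, 0], [0, 0], [0, 0]]
src = [[1, 2], [3, 4]]
idx = [[2, 0], [2, 1]]
scatter_add_model(mem, src, idx, dim=0)
```

Both `src[0][0]` and `src[1][0]` target `mem[2][0]`, so their values are summed there. On real hardware this is the case where concurrent writes to the same index need care.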
class wave_lang.kernel.ops.wave_ops.SchedulingBarrier(operations: list[Operation])

Represents a scheduling barrier in the graph. Takes in a list of operations that are allowed to cross the barrier.

operations: list[Operation]
tkw_op_name: str = 'scheduling_barrier'
class wave_lang.kernel.ops.wave_ops.SchedulingGroupBarrier(instructions: dict[Operation, int], sync_id: int)

Represents a scheduling group barrier in the graph. The scheduling group barrier defines scheduling groups. Each scheduling group contains different instructions in a specific order. The sync_id identifies scheduling groups that need to be aware of each other.

instructions: dict[Operation, int]
sync_id: int
tkw_op_name: str = 'scheduling_group_barrier'
class wave_lang.kernel.ops.wave_ops.SelectOp(cond: 'fx.Node', if_true: 'fx.Node', if_false: 'fx.Node')
cond: Node
if_false: Node
if_true: Node
property indexing_dims: list[Symbol]
infer_broadcast_shape(cond_type: Register, t_type: Register, f_type: Register)
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'select'
class wave_lang.kernel.ops.wave_ops.SelfIndex(dim: 'IndexSymbol', dtype: 'DataType', elements_per_thread: 'Optional[IndexExpr | int]' = None)
dim: Symbol
dtype: DataType
elements_per_thread: Expr | int | None = None
property indexing_dims: list[Symbol]
tkw_op_name: str = 'self_index'
property type: Register
class wave_lang.kernel.ops.wave_ops.SetSymbol(symbol: 'IndexExpr', register_: 'fx.Proxy')
property has_side_effects: bool
property indexing_dims: list[Symbol]
register_: Proxy
symbol: Expr
tkw_op_name: str = 'set_symbol'
property type: Register
class wave_lang.kernel.ops.wave_ops.SetWavePrio(priority: int)

An op that tells the hardware what priority level a given set of instructions or region has. This is useful for ping-pong scheduling, or more generally whenever two waves share the same SIMD and we want the SIMD to prioritize one wave over the other.

property has_side_effects: bool
priority: int
tkw_op_name: str = 'set_wave_prio'
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrier(wait_async_ops: bool = False, tensor_wait: bool = False)

Represents a shared memory barrier in the graph.

property has_side_effects: bool
tensor_wait: bool = False
tkw_op_name: str = 'shared_memory_barrier'
wait_async_ops: bool = False
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrierSignal(barId: int = 0, tensor_wait: bool = False, ds_wait: bool = True)

Represents a shared memory barrier signal in the graph. (gfx12) The argument specifies which barrier to signal:
  • [1:16]: named barriers

  • 0: NOOP

  • -1: works as s_barrier

  • -2: trap barrier

  • -3: cluster barrier

Parameters:
  • barId – The barrier ID to signal

  • tensor_wait – If True, emit s_wait_tensorcnt(0) before signaling

  • ds_wait – If True, emit s_wait_dscnt(0) before signaling (for non-cluster barriers)

barId: int = 0
ds_wait: bool = True
property has_side_effects: bool
tensor_wait: bool = False
tkw_op_name: str = 'shared_memory_barrier_signal'
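
The barrier ID encoding listed above can be summarized as a small lookup. This is only a reading aid: `describe_bar_id` is a hypothetical helper, and treating `[1:16]` as the inclusive range 1 through 16 is an assumption.

```python
def describe_bar_id(bar_id: int) -> str:
    """Return the documented meaning of a barrier ID (illustrative only)."""
    special = {
        0: "NOOP",
        -1: "works as s_barrier",
        -2: "trap barrier",
        -3: "cluster barrier",
    }
    if bar_id in special:
        return special[bar_id]
    # Assumption: the documented range [1:16] includes both endpoints.
    if 1 <= bar_id <= 16:
        return "named barrier"
    raise ValueError(f"undocumented barrier ID: {bar_id}")


default_meaning = describe_bar_id(0)
```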
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrierWait(barId: int = 0)

Wait for all waves in a WG to signal the barrier before proceeding. (gfx12) Synchronizes waves within a WG. The argument specifies which barrier to wait on:
  • [1:16]: named barriers

  • 0: NOOP

  • -1: works as s_barrier

  • -2: trap barrier

  • -3: cluster barrier

barId: int = 0
property has_side_effects: bool
tkw_op_name: str = 'shared_memory_barrier_wait'
class wave_lang.kernel.ops.wave_ops.ShuffleOp(arg: fx.Node, offset: int, width: int, mode: ShuffleMode)

Represents a shuffle.xor op.

Parameters:
  • arg – Value/vector to shuffle.

  • offset – XOR offset.

  • width – XOR width.

arg: fx.Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mode: ShuffleMode
offset: int
tkw_op_name: str = 'shuffle'
width: int
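
A xor shuffle is commonly used as the building block of a butterfly reduction. The sketch below simulates lanes as list positions; it assumes the usual GPU convention that `width` is a power of two, that lanes are grouped into consecutive blocks of `width`, and that `offset` is smaller than `width`. Both function names are hypothetical and only the xor mode is modeled.

```python
def shuffle_xor(values, offset, width):
    """Each lane reads the value held by lane (lane ^ offset) in its group."""
    out = []
    for lane in range(len(values)):
        base = lane - lane % width  # first lane of this lane's group
        partner = base + ((lane % width) ^ offset)
        out.append(values[partner])
    return out


def butterfly_sum(values, width):
    """All-reduce by repeated xor shuffles; every lane ends with its group sum."""
    offset = width // 2
    while offset >= 1:
        shuffled = shuffle_xor(values, offset, width)
        values = [a + b for a, b in zip(values, shuffled)]
        offset //= 2
    return values


swapped = shuffle_xor([0, 1, 2, 3], offset=1, width=4)
reduced = butterfly_sum([1, 2, 3, 4], width=4)
```

After log2(width) shuffle steps, every lane in a group holds the same total, which is why this pattern suits warp-level reductions.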
class wave_lang.kernel.ops.wave_ops.Sin(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sin'
class wave_lang.kernel.ops.wave_ops.Sinh(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sinh'
class wave_lang.kernel.ops.wave_ops.Softsign(arg: Node, logit_cap: float = 30.0, apply_scaling: bool = False, head_dim: int = None)

NewSubclass(arg: ‘fx.Node’, logit_cap: ‘float’ = 30.0, apply_scaling: ‘bool’ = False, head_dim: ‘int’ = None)

tkw_op_name: str = 'softsign'
class wave_lang.kernel.ops.wave_ops.SoftsignOp(arg: 'fx.Node', logit_cap: 'float' = 30.0, apply_scaling: 'bool' = False, head_dim: 'int' = None)
apply_scaling: bool = False
arg: Node
head_dim: int = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

logit_cap: float = 30.0
class wave_lang.kernel.ops.wave_ops.Sqrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sqrt'
class wave_lang.kernel.ops.wave_ops.Sub(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'sub'
class wave_lang.kernel.ops.wave_ops.Sum(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'sum'
class wave_lang.kernel.ops.wave_ops.Tanh(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'tanh'
class wave_lang.kernel.ops.wave_ops.TanhApprox(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'tanh_approx'
class wave_lang.kernel.ops.wave_ops.TensorCounterWait(count: int = 0)

Wait for the tensor counter to reach the specified value. Generates rocdl.s.wait.tensorcnt instruction.

NOTE: This operation is only supported on gfx1250 targets.

count: int = 0
property has_side_effects: bool
tkw_op_name: str = 'tensor_counter_wait'
class wave_lang.kernel.ops.wave_ops.TensorLoadToLDS(src: 'list[Memory]', dst: 'list[Memory]', element_type: 'DataType', distributed_shape: 'dict[IndexSymbol, IndexExpr]', shared_tile_index: 'dict[IndexSymbol, IndexSequence]', global_tile_index: 'dict[IndexSymbol, IndexSequence]', bounds: 'dict[IndexSymbol, IndexExpr]', multicast_mask: 'Optional[IndexExpr]' = None, input_selector: 'IndexSymbol | int' = 0)
bounds: dict[Symbol, Expr]
distributed_shape: dict[Symbol, Expr]
dst: list[Memory]
element_type: DataType
global_tile_index: dict[Symbol, IndexSequence]
property has_side_effects: bool
input_selector: Symbol | int = 0
multicast_mask: Expr | None = None
shared_tile_index: dict[Symbol, IndexSequence]
src: list[Memory]
tkw_op_name: str = 'tensor_load_to_lds'
class wave_lang.kernel.ops.wave_ops.TopkOp(arg: Node, k_dim: Symbol, dim: Symbol)

Represents a TopK computation.

Parameters:
  • arg – Source tensor Register value to find top-k from.

  • k_dim – Dimension symbol representing the K size (number of top elements to select).

  • dim – Which symbolic dim to perform topk on. This dimension will be replaced by k_dim in the output.

arg: Node
dim: Symbol
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property init
k_dim: Symbol
property reduction_dim: Symbol
tkw_op_name: str = 'topk'
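
The effect of replacing `dim` with `k_dim` can be seen in a small list-based sketch: each row of length 4 shrinks to length k. Returning indices alongside values is an assumption, not something the source states, and `topk_model` is a hypothetical helper.

```python
def topk_model(rows, k):
    """Return (values, indices) of the k largest entries in each row."""
    values, indices = [], []
    for row in rows:
        # Sort positions by value, largest first; ties keep the lower index
        # because Python's sort is stable.
        order = sorted(range(len(row)), key=lambda i: -row[i])[:k]
        indices.append(order)
        values.append([row[i] for i in order])
    return values, indices


vals, idxs = topk_model([[3, 9, 1, 7], [5, 2, 8, 8]], k=2)
```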
class wave_lang.kernel.ops.wave_ops.Truediv(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'truediv'
class wave_lang.kernel.ops.wave_ops.UnaryPyOp(arg: Node)

Represents a unary python operator.

arg: Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property py_operator: str
final class wave_lang.kernel.ops.wave_ops.Unknown(args: Sequence[Any], kwargs: dict[Any, Any])

Represents an fx.Node that has no corresponding CustomNode class.

args: Sequence[Any]
custom_string(value_map: dict[str, str]) -> str
classmethod from_fx_node(node: Node) -> Unknown
kwargs: dict[Any, Any]
class wave_lang.kernel.ops.wave_ops.WorkgroupBarrier

Represents a synchronization of all threads in a workgroup. Threads wait on a WorkgroupBarrier until every thread in the workgroup has called a WorkgroupBarrier (the calls do not have to be at the same location).

property has_side_effects: bool
tkw_op_name: str = 'workgroup_barrier'
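
The waiting behavior described above has a close CPU analogue in Python's `threading.Barrier`. The sketch below is only an analogy for the semantics (all participants must arrive before any proceeds); it says nothing about how Wave lowers the op.

```python
import threading

NUM_THREADS = 4
barrier = threading.Barrier(NUM_THREADS)
shared = [0] * NUM_THREADS
totals = [0] * NUM_THREADS


def worker(tid: int) -> None:
    shared[tid] = tid + 1      # phase 1: each thread writes its own slot
    barrier.wait()             # nobody proceeds until all threads arrive
    totals[tid] = sum(shared)  # phase 2: every slot is now safe to read


threads = [threading.Thread(target=worker, args=(t,)) for t in range(NUM_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Without the barrier, a fast thread could sum `shared` before slower threads had written their slots.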
class wave_lang.kernel.ops.wave_ops.Write(register_: 'fx.Proxy', memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, flags: 'MemoryAccessFlags' = MemoryAccessFlags.NONE, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None)
bounds: dict[Symbol, Expr] | None = None
elements_per_thread: Any | None = None
flags: MemoryAccessFlags = 0
get_derived_indices() -> list[tuple[dict[Symbol, IndexSequence], Node]]
has_identity_mapping() -> bool

Check if mapping between input register and output memory is identity.

property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

is_contiguous_vec(constraints, target: str) -> bool

Check if the op can be lowered to contiguous vector ops.

If False, it will have to be lowered to a gather.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
memory: Proxy
property memory_type: Memory
register_: Proxy
property register_index: dict[Symbol, IndexSequence]
property register_type: Register
source: tuple[Expr] | None = None
target: tuple[Expr] | None = None
tkw_op_name: str = 'write'
transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) -> dict[Symbol, IndexSequence]

Propagate index backwards.

Dynamic values potentially can have non-identity mapping, so we need to update index when walking from the node to dyn val arguments.

E.g., if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting arg index will be $idx // ELEMS_PER_THREAD.
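
The arithmetic in this example can be checked with concrete numbers. The sketch below uses plain integers and a lambda in place of symbolic expressions; the dictionary layout and the helper `transform_index_backwards_model` are hypothetical stand-ins, not the method's real data structures.

```python
ELEMS_PER_THREAD = 4

# Hypothetical stand-in for dynamic_val_mappings={N: j // ELEMS_PER_THREAD}.
dynamic_val_mappings = {"N": lambda j: j // ELEMS_PER_THREAD}


def transform_index_backwards_model(index: dict) -> dict:
    """Apply each dimension's mapping to the node index to get the arg index."""
    return {dim: fn(index[dim]) for dim, fn in dynamic_val_mappings.items()}


# A node index of 13 along N maps back to 13 // 4 for the dynamic value.
arg_index = transform_index_backwards_model({"N": 13})
```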

wave_lang.kernel.ops.wave_ops.abs(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.allocate(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.apply_expr(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atan2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atomic_add(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atomic_min(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.bitcast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.bounds_check(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.broadcast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cbrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.conditional(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cos(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cumsum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.debug_log(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.define_interface_op(op_name: str) -> Callable[[T], T]

Generate a new subclass for op handling, deriving from the base interface class. The generated subclass can be used to emit the op from the compiler/Python side by calling the generated subclass name.

The generated subclass name is the PascalCase form of the TKW op name. For example:
  • “tkw.op_name” -> “OpName”

  • “tkw.exp2” -> “Exp2”
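
The naming rule can be sketched in a few lines. `op_name_to_class_name` is a hypothetical helper written to match the examples given here and the class names on this page; the library may derive names differently.

```python
def op_name_to_class_name(op_name: str) -> str:
    """Convert a snake_case op name to PascalCase (hypothetical helper)."""
    # Capitalize each underscore-separated part and join them; empty parts
    # (from a trailing underscore, as in "and_") contribute nothing.
    return "".join(part.capitalize() for part in op_name.split("_"))


class_name = op_name_to_class_name("tanh_approx")
```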

wave_lang.kernel.ops.wave_ops.define_op(op_name: str) -> Callable[[T], T]
wave_lang.kernel.ops.wave_ops.define_py_op(py_op: Callable) -> Callable[[T], T]

Register Python's built-in operators as custom ops. This overloads operator-specific methods of fx.Proxy, such as __add__, with a handler in order to control the tracing of the operator and map it to a dynamically created subclass of UnaryPyOp or BinaryPyOp.
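
The general mechanism of overloading an operator method to intercept tracing can be illustrated with a toy proxy. Everything below (`ToyProxy`, `define_py_op_model`, the trace format) is hypothetical and far simpler than the real decorator, which also creates an op subclass.

```python
import operator
from typing import Callable


class ToyProxy:
    """Hypothetical stand-in for fx.Proxy that records traced operations."""

    def __init__(self, name: str, trace: list):
        self.name = name
        self.trace = trace


def define_py_op_model(py_op: Callable) -> None:
    """Install a dunder handler (e.g. __add__) on ToyProxy for `py_op`."""
    op_name = py_op.__name__  # "add" for operator.add, "mul" for operator.mul

    def handler(self, other):
        rhs = getattr(other, "name", other)
        # Record the operation instead of computing a value.
        self.trace.append((op_name, self.name, rhs))
        return ToyProxy(f"{op_name}({self.name}, {rhs})", self.trace)

    setattr(ToyProxy, f"__{op_name}__", handler)


define_py_op_model(operator.add)
define_py_op_model(operator.mul)

trace = []
a, b = ToyProxy("a", trace), ToyProxy("b", trace)
c = a + b * 2
```

Because Python looks up operator methods on the class, installing the handler with `setattr` is enough for `a + b * 2` to be routed through it and recorded in evaluation order.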

wave_lang.kernel.ops.wave_ops.eq(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.exp(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.exp2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.extract(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.extract_slice(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.gather_to_lds(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.ge(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.get_custom(node: Node) -> CustomOp

Get the corresponding CustomOp for a given fx.Node.

wave_lang.kernel.ops.wave_ops.get_result(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.get_shape_from_bindings(constraints: list[Constraint], target: tuple[Expr]) -> list[Expr]
wave_lang.kernel.ops.wave_ops.gt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.has_same_custom_type(lhs_type: Memory, rhs_type: Memory) -> bool
wave_lang.kernel.ops.wave_ops.iterate(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.le(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.log10(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.log2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.lt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.max(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.maximum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.memory_counter_wait(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.memory_counter_wait_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.min(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.minimum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.mma(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.mod(lhs: Register, rhs: Register) -> Register
wave_lang.kernel.ops.wave_ops.ne(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.permute(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.powf(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.read(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.read_meets_hw_transpose_requirements(read: Read, constraints: list[Constraint], target: str) -> bool
wave_lang.kernel.ops.wave_ops.reciprocal(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.register(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.remf(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.reshape(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.round(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.roundeven(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.rsqrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scalar(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scaled_mma(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scatter_add(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scheduling_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scheduling_group_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.select(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.self_index(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.set_symbol(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.set_wave_prio(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier_signal(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier_wait(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shuffle(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sin(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sinh(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.softsign(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sqrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tag(expr: Proxy, tag_name: str) -> Proxy

Assign a tag to the result of a Python expression (e.g., arithmetic operators).

This function allows tagging operations that don’t natively support the tag= keyword, such as Python arithmetic operators (*, -, +, /).

Usage:

    # Instead of:
    result = a * b  # Cannot tag this directly

    # Use:
    result = tkw.tag(a * b, "multiply_ab")

    # The tag can then be used in wave_schedule:
    multiply_ops = tkw.get_node_by_tag("multiply_ab")

Parameters:
  • expr – The fx.Proxy result of an expression (e.g., a * b, x - y)

  • tag_name – The tag string to assign to the operation

Returns:

The same fx.Proxy, allowing chained expressions

wave_lang.kernel.ops.wave_ops.tanh(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tanh_approx(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tensor_counter_wait(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tensor_load_to_lds(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.topk(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.workgroup_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.write(*args: Any, **kwargs: dict[str, Any])