Wave Ops¶
- class wave_lang.kernel.ops.wave_ops.Abs(arg: Node)¶
- tkw_op_name: str = 'abs'¶
- class wave_lang.kernel.ops.wave_ops.Add(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'add'¶
- class wave_lang.kernel.ops.wave_ops.Allocate(shape: tuple[Expr], distributed_shape: tuple[Expr], dtype: DataType, address_space: AddressSpace = $SHARED_ADDRESS_SPACE, padding: int = 0, parent: Node | None = None, offset: Expr | None = None, tail_padding: int = 0)¶
Represents an allocation in an address space (such as shared memory).
- address_space: AddressSpace = $SHARED_ADDRESS_SPACE¶
- property allocation_size: Expr¶
Returns the full size of the allocation in bytes including all padding.
- distributed_shape: tuple[Expr]¶
- dtype: DataType¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- offset: Expr | None = None¶
- padding: int = 0¶
- parent: Node | None = None¶
- shape: tuple[Expr]¶
- tail_padding: int = 0¶
- tkw_op_name: str = 'allocate'¶
- property type: Memory¶
- property unpadded_dims: dict[Symbol, Expr]¶
- property unpadded_shape: tuple[Expr]¶
- class wave_lang.kernel.ops.wave_ops.And(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'and_'¶
- class wave_lang.kernel.ops.wave_ops.ApplyExpr(register_: 'fx.Proxy | Sequence[fx.Proxy]', expr: 'Callable')¶
- expr: Callable¶
- property indexing_dims: list[Symbol]¶
- register_: Proxy | Sequence[Proxy]¶
- tkw_op_name: str = 'apply_expr'¶
- property type: Register¶
- class wave_lang.kernel.ops.wave_ops.Atan2(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'atan2'¶
- class wave_lang.kernel.ops.wave_ops.AtomicAddOp(lhs: 'Any', rhs: 'Any', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = ())¶
- tkw_op_name: str = 'atomic_add'¶
- class wave_lang.kernel.ops.wave_ops.AtomicOp(lhs: Any, rhs: Any, elements_per_thread: Any | None = None, mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = ())¶
Represents an atomic operation in the graph. Takes a Register and a Memory as inputs and writes the modified value back to the buffer. The mapping attribute maps the index from the wave kernel to the shared-memory index that the wavegroup operates on.
- elements_per_thread: Any | None = None¶
- property has_side_effects: bool¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- mapping: IndexMapping | None = None¶
- mapping_dynamic_vals: tuple[Node, ...] = ()¶
- property memory_type: Memory¶
- tkw_op_name: str = 'atomic_min'¶
- class wave_lang.kernel.ops.wave_ops.BinaryOpBase(lhs: Any, rhs: Any)¶
Represents an elementwise binary Python operator.
DTYPE requirement: lhs and rhs need to have the same dtype. Shape requirement: lhs and rhs either have the same shape or have shapes that are broadcastable to one another.
- property indexing_dims: list[Symbol]¶
- infer_shape() → Any¶
- lhs: Any¶
- property py_operator: str¶
- rhs: Any¶
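The shape requirement for BinaryOpBase can be illustrated with a small sketch. It models symbolic shapes as tuples of dimension names and assumes one possible reading of "broadcastable": the operands either have the same shape, or one operand's dimension set contains the other's, and the larger shape is the result. The helper `infer_broadcast_shape` is hypothetical and is not part of wave_lang.

```python
def infer_broadcast_shape(lhs: tuple[str, ...], rhs: tuple[str, ...]) -> tuple[str, ...]:
    """Hypothetical sketch of the broadcast rule for symbolic shapes."""
    if lhs == rhs:
        return lhs
    lhs_dims, rhs_dims = set(lhs), set(rhs)
    # Assumption: one operand's dimensions must be a subset of the other's.
    if lhs_dims >= rhs_dims:
        return lhs
    if rhs_dims >= lhs_dims:
        return rhs
    raise ValueError(f"shapes {lhs} and {rhs} are not broadcastable")


print(infer_broadcast_shape(("M", "N"), ("N",)))  # ('M', 'N')
```

Under this reading, disjoint shapes such as ("M",) and ("N",) are rejected rather than producing an outer product.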
- class wave_lang.kernel.ops.wave_ops.BinaryPyOp(lhs: 'Any', rhs: 'Any')¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- class wave_lang.kernel.ops.wave_ops.BitcastOp(arg: Node, dtype: DataType)¶
Represents a bitcast operation.
- arg: Node¶
- dtype: DataType¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- property scale_factor¶
- tkw_op_name: str = 'bitcast'¶
- class wave_lang.kernel.ops.wave_ops.BoundsCheck(index_exprs: dict[Expr, IndexSequence], bounds: dict[Symbol, Expr], mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = (), mask_bounds: dict[Symbol, Expr] | None = None)¶
Checks if a dimension is within a bound.
- bounds: dict[Symbol, Expr]¶
- property has_side_effects: bool¶
- index_exprs: dict[Expr, IndexSequence]¶
- mapping: IndexMapping | None = None¶
- mapping_dynamic_vals: tuple[Node, ...] = ()¶
- mask_bounds: dict[Symbol, Expr] | None = None¶
- tkw_op_name: str = 'bounds_check'¶
- class wave_lang.kernel.ops.wave_ops.Broadcast(arg: Node, target_shape: Sequence[Symbol] = None)¶
Represents a Broadcast operation.
arg: Source tensor/value to broadcast.
target_shape: Symbolic target broadcast shape.
- arg: Node¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- target_shape: Sequence[Symbol] = None¶
- tkw_op_name: str = 'broadcast'¶
- class wave_lang.kernel.ops.wave_ops.CastOp(arg: Node, dtype: DataType)¶
Represents a cast operation.
- arg: Node¶
- dtype: DataType¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- tkw_op_name: str = 'cast'¶
- class wave_lang.kernel.ops.wave_ops.Cbrt(arg: Node)¶
- tkw_op_name: str = 'cbrt'¶
- class wave_lang.kernel.ops.wave_ops.ComparisonPyOp(lhs: 'Any', rhs: 'Any')¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- class wave_lang.kernel.ops.wave_ops.Conditional(condition: Proxy | Expr, subgraph_name: str, implicit_captures: Sequence[Proxy], else_return: Sequence[Register] | None = None)¶
The optional else_return argument must match the type of any return statements in the conditional body. The else_return values are the default return values for the else branch.
- condition: Proxy | Expr¶
- else_return: Sequence[Register] | None = None¶
- implicit_captures: Sequence[Proxy]¶
- property indexing_dims: list[Symbol] | list[list[Symbol]]¶
- property init_args¶
Alias for else_return. It matches the Iterate interface so that shared processing code can handle both ops.
- subgraph_name: str¶
- tkw_op_name: str = 'conditional'¶
- class wave_lang.kernel.ops.wave_ops.Cos(arg: Node)¶
- tkw_op_name: str = 'cos'¶
- class wave_lang.kernel.ops.wave_ops.Cumsum(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None, block: bool | None = False)¶
- tkw_op_name: str = 'cumsum'¶
- class wave_lang.kernel.ops.wave_ops.CustomOp¶
Base class for all custom fx nodes.
Fields with compare=False are infrastructure or scheduling artifacts and do not participate in semantic equality used for trace equivalence.
- add_to_graph(region_graph: RegionGraph, type: Any = None, loc: FileLineColInfo | StackTraceInfo | None = None, tag: str | None = None) → Node¶
- copy(new_name: str | None = None, new_graph: Graph | None = None, arg_transform: Callable[[Any], Any] | None = <function CustomOp.<lambda>>, anchor: Node | None = None) → Self¶
Returns a duplicate of this node.
- copy_core_attributes(new_node: Node)¶
Copy core attributes from the current node to the new node.
- classmethod create(graph: Graph, *args, type: Any = None, extra_attrs: dict[str, Any] | None = None, **kwargs) → CustomOpT¶
Create a CustomOp instance and its FX node together.
This classmethod properly instantiates the CustomOp with its semantic fields, creates the FX node, and links them together. Useful when building FX graphs programmatically (e.g., from MLIR).
- Parameters:
graph – FX graph to add the node to.
*args – Positional arguments for the dataclass constructor.
type – Optional type to set on the node.
extra_attrs – Additional attributes to set on the fx_node after creation (e.g., index, vector_shapes). These are not passed to the dataclass constructor.
**kwargs – Keyword arguments forwarded to the dataclass constructor.
- Returns:
The created CustomOp instance with fx_node and graph fields populated.
Example:
    register = NewRegister.create(
        graph, dims, dtype, init_value,
        type=Register[(M, N, f32)],
        extra_attrs={"vector_shapes": {M: 16, N: 16}},
    )
- custom_string(value_map: dict[str, str]) → str¶
- erase()¶
Erase the current node from the graph where it exists.
- property expanded_dims: dict[Symbol, int]¶
During expansion each node is expanded along its indexing dimensions. The expanded_dims property stores the dimensions along which the node has been expanded as well as the scaling along that dimension.
For example, a node with indexing dimensions [M, N] and dimensional scaling {M: 2, N: 2} will be expanded into 4 nodes. The expanded nodes map to the following expanded_dims: {M: 0, N: 0}, {M: 0, N: 1}, {M: 1, N: 0}, {M: 1, N: 1}.
- classmethod from_fx_node(node: Node) → CustomOpT¶
- fx_node: Node | None = None¶
- graph: Graph | None = None¶
- classmethod handle(graph: RegionGraph, *args, **kwargs) → Node¶
- property has_side_effects: bool¶
- property index: dict[Symbol, IndexSequence] | None¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- property location: FileLineColInfo | StackTraceInfo | None¶
- property name: str¶
- property node_args: dict[int, Any]¶
Returns the args to this custom op using subclasses of CustomOp if possible.
- property pre_expansion_id: int¶
- replace_all_uses_with(new_node: CustomOp | Node)¶
Replace all uses of the current node with the new node.
- replace_all_uses_with_except(new_node: CustomOp | Node, except_nodes: list[CustomOp])¶
Replace all uses of the current node with the new node except for the nodes in except_nodes.
- replace_uses_with(new_node: CustomOp | Node, *, graph: Graph | None = None, propagate_location: bool = True) → None¶
Replace uses of this node, optionally restricted to one graph.
When graph is None, this replaces every use of the current node. When graph is provided, only uses in that specific fx.Graph are replaced.
This matters for nested regions: the same FX node may be referenced from multiple graphs at the same time, for example from a nested subgraph and from sibling or outer graphs. A graph-scoped replacement therefore does not imply that the current node becomes globally dead. After the call, the node may still have uses in other graphs that were intentionally left untouched.
- replacement_location_propagate(new_node: CustomOp | Node | list[Node])¶
Set the new_node location if it doesn’t have one.
- property rrt¶
- property scheduling_parameters¶
- property tag: set[str] | None¶
- tkw_op_name: str = 'unknown'¶
- transform_index(index: dict[Symbol, IndexSequence]) → dict[Symbol, IndexSequence]¶
Transform the index of the node based on the provided mapping.
- transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) → dict[Symbol, IndexSequence]¶
Transform the index of the node when propagating the index backwards, i.e. from a node to its arguments.
- property type: Any¶
- property unroll_iteration: int | None¶
Returns the unroll iteration index for this node, or None if not unrolled.
- update_arg(idx_or_name: int | str | Node, value: CustomOp | Node)¶
Update an operand or named field while keeping the underlying fx.Node consistent for both positional arguments and keyword arguments.
- property users: list[Any]¶
Returns the users of this custom op using subclasses of CustomOp if possible.
- property vector_shapes: dict[Symbol, int]¶
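The expanded_dims example for CustomOp above can be reproduced with a short enumeration. This sketch uses plain strings instead of sympy symbols. The helper `enumerate_expanded_dims` is hypothetical and only reproduces the listed combinations; it does not model wave_lang's actual expansion pass.

```python
from itertools import product


def enumerate_expanded_dims(scaling: dict[str, int]) -> list[dict[str, int]]:
    """List the expanded_dims of each copy produced when expanding a node."""
    dims = list(scaling)
    # One copy per combination of per-dimension offsets, in row-major order.
    return [
        dict(zip(dims, offsets))
        for offsets in product(*(range(scaling[d]) for d in dims))
    ]


print(enumerate_expanded_dims({"M": 2, "N": 2}))
# [{'M': 0, 'N': 0}, {'M': 0, 'N': 1}, {'M': 1, 'N': 0}, {'M': 1, 'N': 1}]
```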
- class wave_lang.kernel.ops.wave_ops.DebugLog(register_: Proxy, label: str | None = None, extra_iteration_dimensions: list[tuple[Symbol, Symbol, int]] | None = None, mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = (), printer: Callable[[str, Any], Any] | None = <built-in function print>, handler: Callable[[dict[str, Any]], Any] | None = None)¶
An op for debugging. Represents a write to an implicit global memory location. The kernel implicitly gets an extra memory input, which is injected by the Python kernel launcher. You can access this memory by passing an extra keyword argument, debug_logs, with an empty dictionary to the kernel invocation. The launcher mutates the dictionary, adding the labels as keys that map to nested dictionaries. Each nested dictionary has a value field holding the log tensor, plus other metadata keys (e.g. symbolic_shape).
That is, the debug_logs dictionary looks like: {LABEL: {"value": LOG_TENSOR, (other metadata keys) ...}}
Note that the logs collected in the debug_logs field, or handled by printer or handler, represent a global view of the log after all writes. They are not limited to any one wave or loop iteration.
The optional printer argument should be a function that accepts a string (the log's label) and the value of the log itself (a Torch tensor). The default is print, which will probably print an abbreviated view of the global tensor.
The optional handler argument should be a function that accepts the whole debug_logs object (i.e. all logs, not just one). The handler lets you specify something like a viewer for all logs inline, alongside the printer functions, rather than separately.
The optional extra_iteration_dimensions argument lets you add extra dimensions that capture values from multiple iterations of a loop. It takes a list of tuples, where each tuple is (dimension_name, iteration_axis, max_iterations). The dimension_name must be a unique symbol, and it becomes the name of the dimension in the symbolic shape of the output. The iteration_axis must be the axis of an Iterate operation, i.e. the dimension being reduced in the iteration. The max_iterations argument is a positive integer giving the maximum number of iterations to store. If it is too small, later iterations currently overwrite the final iteration slot.
The API and semantics of this operation are not yet stable. Because it is only a debugging tool, remove any debug logging from your kernel before shipping it.
- extra_iteration_dimensions: list[tuple[Symbol, Symbol, int]] | None = None¶
- handler: Callable[[dict[str, Any]], Any] | None = None¶
- property has_side_effects: bool¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- label: str | None = None¶
- mapping: IndexMapping | None = None¶
- mapping_dynamic_vals: tuple[Node, ...] = ()¶
- property memory: Proxy | None¶
- printer: Callable[[str, Any], Any] | None = <built-in function print>¶
- register_: Proxy¶
- tkw_op_name: str = 'debug_log'¶
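The DebugLog printer and handler conventions can be sketched with callables that match the signatures described above. These callables run against a hand-built dictionary in the documented {LABEL: {"value": ..., ...}} layout, not against a real kernel's output, where the value would be a Torch tensor. `shape_printer`, `metadata_handler`, and `iteration_slot` are hypothetical names. `iteration_slot` only models the documented max_iterations overflow behavior (later iterations land in the final slot) and is not wave_lang code.

```python
from typing import Any


def shape_printer(label: str, value: Any) -> None:
    """A printer: called with the log's label and the log value."""
    print(f"{label}: shape={getattr(value, 'shape', None)}")


def metadata_handler(debug_logs: dict[str, dict[str, Any]]) -> list[str]:
    """A handler: called once with the whole debug_logs dictionary."""
    for label, entry in sorted(debug_logs.items()):
        extra = sorted(key for key in entry if key != "value")
        print(f"{label}: metadata keys {extra}")
    return sorted(debug_logs)


def iteration_slot(iteration: int, max_iterations: int) -> int:
    """Slot for a loop iteration; iterations past the limit reuse the last slot."""
    return min(iteration, max_iterations - 1)


# Stand-in for the dictionary that a kernel invocation would fill in.
logs = {"acc": {"value": [[0.0]], "symbolic_shape": ("M", "N")}}
metadata_handler(logs)
shape_printer("acc", logs["acc"]["value"])
```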
- class wave_lang.kernel.ops.wave_ops.Eq(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'eq'¶
- class wave_lang.kernel.ops.wave_ops.Exp(arg: Node)¶
- tkw_op_name: str = 'exp'¶
- class wave_lang.kernel.ops.wave_ops.Exp2(arg: Node)¶
- tkw_op_name: str = 'exp2'¶
- class wave_lang.kernel.ops.wave_ops.Extract(register_: Proxy, offset: Expr | int)¶
Op Rationale:
Extract represents extracting a scalar from TKW's 1-D vector at the specified index.
This can also be viewed as indexing/slicing on the fastest dimension. Hence, the op is semantically modeled as a reduction over the indexed (fastest) dimension.
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- offset: Expr | int¶
- register_: Proxy¶
- tkw_op_name: str = 'extract'¶
- class wave_lang.kernel.ops.wave_ops.ExtractSlice(register_: 'fx.Proxy', offset: 'tuple[IndexExpr]', size: 'tuple[IndexExpr]', stride: 'tuple[IndexExpr]')¶
- offset: tuple[Expr]¶
- property rank: int¶
- register_: Proxy¶
- size: tuple[Expr]¶
- stride: tuple[Expr]¶
- tkw_op_name: str = 'extract_slice'¶
- property type: Register¶
- class wave_lang.kernel.ops.wave_ops.Floordiv(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'floordiv'¶
- class wave_lang.kernel.ops.wave_ops.GatherToLDS(src: Memory, dst: Memory, src_index: dict[Symbol, IndexSequence], dst_index: dict[Symbol, IndexSequence], dtype: DataType, elements_per_thread: Expr | int | None, src_mapping: IndexMapping | None, dst_mapping: IndexMapping | None, src_bounds: dict[Symbol, Expr] | None, src_mapping_dynamic_vals: tuple[Node, ...] = (), dst_mapping_dynamic_vals: tuple[Node, ...] = ())¶
Represents an instruction that performs a direct load from global memory to LDS. The source node points to the global memory to load from, and the destination node points to shared memory.
- dst: Memory¶
- dst_index: dict[Symbol, IndexSequence]¶
- dst_mapping: IndexMapping | None¶
- dst_mapping_dynamic_vals: tuple[Node, ...] = ()¶
- dtype: DataType¶
- elements_per_thread: Expr | int | None¶
- property has_side_effects: bool¶
- src: Memory¶
- src_bounds: dict[Symbol, Expr] | None¶
- src_index: dict[Symbol, IndexSequence]¶
- src_mapping: IndexMapping | None¶
- src_mapping_dynamic_vals: tuple[Node, ...] = ()¶
- tkw_op_name: str = 'gather_to_lds'¶
- class wave_lang.kernel.ops.wave_ops.Ge(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'ge'¶
- class wave_lang.kernel.ops.wave_ops.GetResult(value: 'fx.Node', res_idx: 'int')¶
- property distributed_shape¶
- property index: dict[Symbol, IndexSequence]¶
- property indexing_dims: list[Expr]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- res_idx: int¶
- tkw_op_name: str = 'get_result'¶
- value: Node¶
- class wave_lang.kernel.ops.wave_ops.Gt(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'gt'¶
- class wave_lang.kernel.ops.wave_ops.Invert(arg: Node)¶
- tkw_op_name: str = 'invert'¶
- class wave_lang.kernel.ops.wave_ops.IterArg(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)¶
Represents a specific placeholder node in the graph that is an iter arg of a reduction node. IterArgs can be of type Register, or of type Memory with a shared memory address space.
- property distributed_shape¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- property iter_idx¶
- parent_op()¶
- class wave_lang.kernel.ops.wave_ops.Iterate(axis: 'IndexSymbol', init_args: 'Sequence[Any]', subgraph_name: 'str', implicit_captures: 'Sequence[fx.Proxy]', step: 'int' = 1, start: 'Optional[IndexExpr]' = None, condition: 'Optional[IndexExpr]' = None)¶
- axis: Symbol¶
- condition: Expr | None = None¶
- property count: int | None¶
- implicit_captures: Sequence[Proxy]¶
- property index: list[dict[Symbol, IndexSequence] | None]¶
Collect indices from the subgraph’s output return values.
Always returns a list with one entry per return value. Uses getattr with a None default so this is safe to call on partially-populated graphs where indices have not been propagated yet.
- property indexing_dims: list[Symbol] | list[list[Symbol]]¶
- init_args: Sequence[Any]¶
- start: Expr | None = None¶
- step: int = 1¶
- subgraph_name: str¶
- tkw_op_name: str = 'iterate'¶
- class wave_lang.kernel.ops.wave_ops.Le(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'le'¶
- class wave_lang.kernel.ops.wave_ops.Log10(arg: Node)¶
- tkw_op_name: str = 'log10'¶
- class wave_lang.kernel.ops.wave_ops.Log2(arg: Node)¶
- tkw_op_name: str = 'log2'¶
- class wave_lang.kernel.ops.wave_ops.Lt(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'lt'¶
- class wave_lang.kernel.ops.wave_ops.MMA(lhs: 'fx.Node', rhs: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['MMAType'] | 'GenericDot'" = None)¶
- acc: fx.Node¶
- property acc_index: dict[Symbol, IndexSequence]¶
- property acc_type: Memory¶
- custom_string(value_map: dict[str, str]) → str¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- lhs: fx.Node¶
- property lhs_index: dict[Symbol, IndexSequence]¶
- property lhs_type: Memory¶
- mma_type: 'MMAType' | None | 'GenericDot' = None¶
- operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) → dict[Symbol, IndexSequence]¶
- property reduction_dim: Symbol¶
- rhs: fx.Node¶
- property rhs_index: dict[Symbol, IndexSequence]¶
- property rhs_type: Memory¶
- tkw_op_name: str = 'mma'¶
- class wave_lang.kernel.ops.wave_ops.MMABase¶
- class wave_lang.kernel.ops.wave_ops.Max(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)¶
- tkw_op_name: str = 'max'¶
- class wave_lang.kernel.ops.wave_ops.Maximum(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'maximum'¶
- class wave_lang.kernel.ops.wave_ops.MemoryAccessFlags(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)¶
Flags for memory access operations (read/write). They map to LLVM load/store attributes.
Flags can be combined, e.g. flags = MemoryAccessFlags.VOLATILE | MemoryAccessFlags.NONTEMPORAL
- NONE = 0¶
- NONTEMPORAL = 2¶
- VOLATILE = 1¶
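Combining MemoryAccessFlags members with `|` works as in any Python flag enum. The sketch below defines a local stand-in with the documented member values so that it runs without wave_lang installed; in real code you would import the class from wave_lang.kernel.ops.wave_ops instead. Building the stand-in on enum.IntFlag is an assumption. The documented constructor signature only shows that the real class comes from the enum machinery.

```python
from enum import IntFlag


class MemoryAccessFlags(IntFlag):
    """Local stand-in mirroring the documented members (not the wave_lang class)."""
    NONE = 0
    VOLATILE = 1
    NONTEMPORAL = 2


flags = MemoryAccessFlags.VOLATILE | MemoryAccessFlags.NONTEMPORAL
print(int(flags))                              # 3
print(MemoryAccessFlags.NONTEMPORAL in flags)  # True
print(bool(MemoryAccessFlags.NONE))            # False
```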
- class wave_lang.kernel.ops.wave_ops.MemoryCounterWait(load: int | None = None, store: int | None = None, ds: int | None = None, exp: int | None = None)¶
Wait for the specified counters to be less than or equal to the provided values before continuing.
Emits: amdgpu.memory_counter_wait with specified counters
- ds: int | None = None¶
- exp: int | None = None¶
- property has_side_effects: bool¶
- load: int | None = None¶
- store: int | None = None¶
- tkw_op_name: str = 'memory_counter_wait'¶
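The MemoryCounterWait condition can be modeled as a predicate over counter values. This sketch treats each counter as a number of outstanding operations of that kind and assumes that a threshold of None means the counter is not waited on. `wait_satisfied` is a hypothetical helper. The real op emits amdgpu.memory_counter_wait and the hardware performs the wait; nothing is polled in Python.

```python
from typing import Optional


def wait_satisfied(
    counters: dict[str, int],
    load: Optional[int] = None,
    store: Optional[int] = None,
    ds: Optional[int] = None,
    exp: Optional[int] = None,
) -> bool:
    """Hypothetical model: True once every requested counter is <= its threshold."""
    thresholds = {"load": load, "store": store, "ds": ds, "exp": exp}
    # Assumption: a threshold of None leaves that counter unconstrained.
    return all(
        counters[name] <= limit
        for name, limit in thresholds.items()
        if limit is not None
    )


outstanding = {"load": 2, "store": 0, "ds": 5, "exp": 0}
print(wait_satisfied(outstanding, load=2))        # True: 2 <= 2
print(wait_satisfied(outstanding, load=2, ds=0))  # False: 5 LDS ops still pending
```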
- class wave_lang.kernel.ops.wave_ops.MemoryCounterWaitBarrier(load: int | None = None, store: int | None = None, ds: int | None = None, exp: int | None = None)¶
Wait for the specified counters to be less than or equal to the provided values before continuing, then perform a workgroup barrier.
Emits: amdgpu.memory_counter_wait with the specified counters, followed by rocdl.s.barrier for workgroup synchronization.
- ds: int | None = None¶
- exp: int | None = None¶
- property has_side_effects: bool¶
- load: int | None = None¶
- store: int | None = None¶
- tkw_op_name: str = 'memory_counter_wait_barrier'¶
- class wave_lang.kernel.ops.wave_ops.Min(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)¶
- tkw_op_name: str = 'min'¶
- class wave_lang.kernel.ops.wave_ops.Minimum(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'minimum'¶
- class wave_lang.kernel.ops.wave_ops.Mod(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'mod'¶
- class wave_lang.kernel.ops.wave_ops.Mul(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'mul'¶
- class wave_lang.kernel.ops.wave_ops.Ne(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'ne'¶
- class wave_lang.kernel.ops.wave_ops.Neg(arg: Node)¶
- tkw_op_name: str = 'neg'¶
- class wave_lang.kernel.ops.wave_ops.NestedRegionOp¶
- static capture_source(node: Node | CustomOp) → Node¶
Return the defining outer value for a local Placeholder node.
- captured_vars(graph: Graph) → list[Node]¶
Return local Placeholder nodes that represent captured outer values.
- erase()¶
Erase the current node from the graph where it exists.
- get_capture_bindings(graph: Graph, lookup: tuple[dict[Node, Node], list[Node]] | None = None) → list[tuple[Node, Node]]¶
Return (outer_source, local_region_value) pairs in signature order.
- get_captured_fx_node(graph: Graph, outer_node: Node, lookup: tuple[dict[Node, Node], list[Node]] | None = None) → Node | None¶
Return the local representative for outer_node in graph if it exists.
- get_root_graph()¶
Return the root (outermost) graph of the computation. The method iteratively follows parent_graph from the current graph until it finds the root graph, which is the one that has a "subgraph" attribute.
- classmethod handle(graph: RegionGraph, *args, **kwargs)¶
Base handle method for nested region operations. Extracts tag from kwargs and sets it on the node and underlying fx.Node.
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- iter_args(graph: Graph | None = None) → list[Node]¶
- classmethod materialize_capture_placeholder(graph: Graph, outer_node: Node | CustomOp, location: FileLineColInfo | StackTraceInfo | None = None) → Node¶
Return a lifted placeholder that represents outer_node in graph.
If one already exists in the leading placeholder prefix, it is returned directly. Otherwise a new one is created.
- outputs(graph: Graph | None = None) → list[Node]¶
- refresh_captures(graph: Graph, lookup: tuple[dict[Node, Node], list[Node]] | None = None) → None¶
Refresh the capture signature from the current graph contents.
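The parent-chain walk described for NestedRegionOp.get_root_graph can be sketched with a minimal stand-in graph class. `FakeGraph` and the module-level `get_root_graph` are hypothetical. They use the parent_graph link and the "subgraph" attribute named in the description, not real torch.fx graphs, and they add a guard (an assumption, not documented behavior) for chains that end without a root.

```python
from typing import Optional


class FakeGraph:
    """Hypothetical stand-in for an fx.Graph with the attributes named in the text."""

    def __init__(self, name: str, parent_graph: Optional["FakeGraph"] = None, is_root: bool = False):
        self.name = name
        self.parent_graph = parent_graph
        if is_root:
            self.subgraph = {}  # only the root graph carries this attribute


def get_root_graph(graph: FakeGraph) -> FakeGraph:
    current: Optional[FakeGraph] = graph
    # Follow parent_graph links until reaching the graph that has "subgraph".
    while current is not None and not hasattr(current, "subgraph"):
        current = current.parent_graph
    if current is None:
        raise ValueError("no root graph found")
    return current


root = FakeGraph("root", is_root=True)
loop_body = FakeGraph("loop_body", parent_graph=root)
cond_body = FakeGraph("cond_body", parent_graph=loop_body)
print(get_root_graph(cond_body).name)  # root
```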
- class wave_lang.kernel.ops.wave_ops.NewRegister(shape: 'tuple[IndexExpr, ...]', dtype: 'DataType', value: 'float')¶
- dtype: DataType¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- shape: tuple[Expr, ...]¶
- tkw_op_name: str = 'register'¶
- value: float¶
- class wave_lang.kernel.ops.wave_ops.NewScalar(value: 'float | IndexExpr', dtype: 'DataType')¶
- dtype: DataType¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- tkw_op_name: str = 'scalar'¶
- value: float | Expr¶
- class wave_lang.kernel.ops.wave_ops.Or(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'or_'¶
- class wave_lang.kernel.ops.wave_ops.Output(return_vals: Sequence[Any])¶
Represents an output node in the graph, i.e. the return value of a traced function.
- add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) → Node¶
- classmethod from_fx_node(node: Node) → CustomOpT¶
- property has_side_effects: bool¶
- return_vals: Sequence[Any]¶
- tkw_op_name: str = 'output'¶
- property yielded_values: list[Any]¶
Yielded values as a list.
- class wave_lang.kernel.ops.wave_ops.Permute(arg: Node, target_shape: Sequence[Expr])¶
Represents a permute operation that permutes arg into the target shape.
- arg: Node¶
- property indexing_dims: list[Expr]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- target_shape: Sequence[Expr]¶
- tkw_op_name: str = 'permute'¶
- transform_index(index: dict[Symbol, IndexSequence]) → dict[Symbol, IndexSequence]¶
The permute operation swaps the strides of the permuted indices. For example, if a permute reorders [B, M, N] into [M, N, B], the strides of those dimensions are swapped accordingly.
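The stride swap in Permute.transform_index can be illustrated with a simplified model. This sketch uses one plausible reading of the description: the dimension at position i of the source shape takes the stride of the dimension at position i of the target shape, and start and size are unchanged. `Seq` is a hypothetical stand-in for IndexSequence and `permute_index` is not wave_lang's implementation. Treat this as an illustration of the idea, not the exact rule.

```python
from typing import NamedTuple


class Seq(NamedTuple):
    """Hypothetical stand-in for IndexSequence."""
    start: str
    size: int
    stride: int


def permute_index(index: dict[str, Seq], src_shape: list[str], target_shape: list[str]) -> dict[str, Seq]:
    # Assumption: position i of src_shape takes the stride of position i of target_shape.
    src_to_target = dict(zip(src_shape, target_shape))
    return {
        dim: seq._replace(stride=index[src_to_target[dim]].stride)
        for dim, seq in index.items()
        if dim in src_to_target
    }


index = {"B": Seq("b", 1, 1024), "M": Seq("m", 1, 32), "N": Seq("n", 4, 1)}
for dim, seq in permute_index(index, ["B", "M", "N"], ["M", "N", "B"]).items():
    print(dim, seq.stride)  # B 32, M 1, N 1024
```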
- class wave_lang.kernel.ops.wave_ops.Placeholder(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)¶
Represents a placeholder node in the graph, i.e. an input to a function.
- add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) → Node¶
- custom_string(value_map: dict[str, str]) → str¶
- erase()¶
Erase the current node from the graph where it exists.
- classmethod from_fx_node(node: Node) → PlaceholderT¶
- get_captured_fx_node() → Node | None¶
- property index: list[dict[Symbol, IndexSequence]]¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- tkw_op_name: str = 'placeholder'¶
- class wave_lang.kernel.ops.wave_ops.Powf(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'powf'¶
- class wave_lang.kernel.ops.wave_ops.Read(memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, flags: 'MemoryAccessFlags' = MemoryAccessFlags.NONE, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None, _write_dependency: 'Optional[list[fx.Node]]' = None)¶
- bounds: dict[Symbol, Expr] | None = None¶
- property dtype: DataType¶
- elements_per_thread: Any | None = None¶
- flags: MemoryAccessFlags = 0¶
- get_derived_indices() → list[tuple[dict[Symbol, IndexSequence], Node]]¶
- has_identity_mapping() → bool¶
Check if mapping between input memory and output register is identity.
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- is_contiguous_vec(constraints, target: str) → bool¶
Check whether the op can be lowered to contiguous vector ops.
If not, it has to be lowered to a gather.
- mapping: IndexMapping | None = None¶
- mapping_dynamic_vals: tuple[Node, ...] = ()¶
- memory: Proxy¶
- property memory_type: Memory¶
- source: tuple[Expr] | None = None¶
- target: tuple[Expr] | None = None¶
- tkw_op_name: str = 'read'¶
- transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) → dict[Symbol, IndexSequence]¶
Propagate the index backwards.
Dynamic values can have a non-identity mapping, so the index must be updated when walking from the node to its dynamic-value arguments.
For example, if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting argument index is $idx // ELEMS_PER_THREAD.
- property write_dependency: Node¶
- class wave_lang.kernel.ops.wave_ops.Reciprocal(arg: Node)¶
- tkw_op_name: str = 'reciprocal'¶
- class wave_lang.kernel.ops.wave_ops.ReduceOp(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)¶
Represents a Reduce computation.
arg: Source tensor/value to reduce.
init: Init value/accumulator for the reduction.
dim: Symbolic dimension to reduce.
block: When set to True, reduce across the block; otherwise, reduce across the warp.
- arg: Node | list[Node]¶
- block: bool | None = False¶
- dim: Any | None = None¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- init: Node = None¶
- property reduction_dim: Symbol¶
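A sequential reference for what a ReduceOp such as Max computes can help when checking kernel results. The sketch reduces a 2-D nested list of shape [M, N] over its last dimension and assumes that init acts as an accumulator, combined elementwise with the reduced values. It does not model the warp-level or block-level (block=True) distribution on the GPU. `reduce_rows` is a hypothetical helper, not a wave_lang API.

```python
from collections.abc import Callable
from functools import reduce
from typing import Optional


def reduce_rows(
    values: list[list[float]],
    combine: Callable[[float, float], float],
    init: Optional[list[float]] = None,
) -> list[float]:
    """Reduce each row over the last dimension, then fold in init if given."""
    reduced = [reduce(combine, row) for row in values]
    if init is None:
        return reduced
    # Assumption: init is combined elementwise with the reduced values.
    return [combine(acc, r) for acc, r in zip(init, reduced)]


data = [[1.0, 5.0, 2.0], [7.0, 3.0, 4.0]]  # shape [M=2, N=3], reducing N
print(reduce_rows(data, max))                   # [5.0, 7.0]
print(reduce_rows(data, max, init=[6.0, 0.0]))  # [6.0, 7.0]
```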
- class wave_lang.kernel.ops.wave_ops.Remf(lhs: Any, rhs: Any)¶
- tkw_op_name: str = 'remf'¶
- class wave_lang.kernel.ops.wave_ops.Reshape(args: Node | Sequence[Node], target_vector_shape: dict[Symbol, int], logical_slice: int = 0, num_slices: int = 1)¶
Represents a reshape operation that reshapes vectors along the same dimension.
Conceptually, this either concatenates multiple vectors into a single vector or extracts slices from the vector. Since this operation appears after graph expansion, it never actually has multiple results: each expanded instance of this operation extracts a single slice.
- args: Node | Sequence[Node]¶
- property indexing_dims: list[Expr]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- logical_slice: int = 0¶
- num_slices: int = 1¶
- target_vector_shape: dict[Symbol, int]¶
- tkw_op_name: str = 'reshape'¶
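One expanded Reshape instance can be modeled with plain lists. The sketch concatenates the input vectors, splits the result into num_slices equal-width contiguous chunks, and returns the chunk selected by logical_slice. Equal-width contiguous slicing is an assumption, and target_vector_shape is not modeled. `reshape_slice` is a hypothetical helper.

```python
def reshape_slice(args: list[list[float]], num_slices: int, logical_slice: int) -> list[float]:
    """Hypothetical model of a single expanded Reshape instance."""
    # Concatenate the input vectors along the shared dimension.
    combined = [x for vec in args for x in vec]
    if len(combined) % num_slices != 0:
        raise ValueError("vector length must divide evenly into num_slices")
    width = len(combined) // num_slices
    # Each expanded instance returns exactly one slice.
    return combined[logical_slice * width:(logical_slice + 1) * width]


print(reshape_slice([[0, 1, 2, 3, 4, 5, 6, 7]], num_slices=2, logical_slice=1))  # [4, 5, 6, 7]
print(reshape_slice([[0, 1], [2, 3]], num_slices=1, logical_slice=0))            # [0, 1, 2, 3]
```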
- class wave_lang.kernel.ops.wave_ops.Round(arg: Node)¶
- tkw_op_name: str = 'round'¶
- class wave_lang.kernel.ops.wave_ops.Roundeven(arg: Node)¶
- tkw_op_name: str = 'roundeven'¶
- class wave_lang.kernel.ops.wave_ops.Rsqrt(arg: Node)¶
- tkw_op_name: str = 'rsqrt'¶
- class wave_lang.kernel.ops.wave_ops.ScaledMMA(lhs: 'fx.Node', lhs_scale: 'fx.Node', rhs: 'fx.Node', rhs_scale: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['ScaledMMAType']" = None)¶
- acc: fx.Node¶
- property acc_index: dict[Symbol, IndexSequence]¶
- property acc_type: Memory¶
- align_index(constraints: list[Constraint]) None¶
- custom_string(value_map: dict[str, str]) str¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- lhs: fx.Node¶
- property lhs_index: dict[Symbol, IndexSequence]¶
- lhs_scale: fx.Node¶
- property lhs_scale_index: dict[Symbol, IndexSequence]¶
- property lhs_scale_type: Memory¶
- property lhs_type: Memory¶
- mma_type: 'ScaledMMAType' | None = None¶
- operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) dict[Symbol, IndexSequence]¶
- property reduction_dim: Symbol¶
- rhs: fx.Node¶
- property rhs_index: dict[Symbol, IndexSequence]¶
- rhs_scale: fx.Node¶
- property rhs_scale_index: dict[Symbol, IndexSequence]¶
- property rhs_scale_type: Memory¶
- property rhs_type: Memory¶
- tkw_op_name: str = 'scaled_mma'¶
- class wave_lang.kernel.ops.wave_ops.ScanOp(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None, block: bool | None = False)¶
Base class for all scan-style operations (e.g., cumsum).
arg: Source tensor/value to scan.
init: Optional initial value.
dim: Symbolic dimension along which to scan.
block: When set to True, scan across the block; otherwise, scan across the warp.
- arg: Node | list[Node]¶
- block: bool | None = False¶
- dim: Symbol | None = None¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- init: Node | None = None¶
- property scan_dim: Symbol¶
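The sketch below shows one way a cumsum-style scan can be computed across the lanes of a warp. It assumes cumsum is an inclusive scan. It uses the Hillis-Steele formulation, which is one common technique for parallel prefix sums. It is shown for illustration only and is not presented as Wave's lowering. Lanes are simulated with a Python list, and `init` is ignored.

```python
from itertools import accumulate


def warp_inclusive_scan(lane_values):
    """Hillis-Steele inclusive prefix sum over simulated warp lanes.

    At step k every lane i >= 2**k adds the value held by lane i - 2**k;
    after ceil(log2(n)) steps lane i holds the sum of lanes 0..i.
    """
    values = list(lane_values)
    offset = 1
    while offset < len(values):
        # All lanes read their partner before anyone updates (lockstep execution).
        shifted = [values[i - offset] if i >= offset else 0
                   for i in range(len(values))]
        values = [v + s for v, s in zip(values, shifted)]
        offset *= 2
    return values


lanes = [3, 1, 4, 1, 5, 9, 2, 6]
print(warp_inclusive_scan(lanes))   # [3, 4, 8, 9, 14, 23, 25, 31]
assert warp_inclusive_scan(lanes) == list(accumulate(lanes))
```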
- class wave_lang.kernel.ops.wave_ops.ScatterAdd(register_src: Node, register_idx: Node, dim: Expr, memory: Node, mapping: IndexMapping, elements_per_thread: int | None = 1, bounds: dict[Symbol, Expr] | None = None)¶
ScatterAdd performs element-wise accumulation from a source register into shared memory (LDS), at locations determined by the index register along a specified dimension.
Limitations:
- Only intra-workgroup scattering is supported (i.e., within shared memory / LDS), assuming a single wave.
- Multi-wave execution is not guaranteed to be safe: synchronization issues may occur when threads write to the same index. Further investigation is needed.
- The operation supports multiple elements per thread, provided the non-scatter dimension is large enough (i.e., > elements_per_thread).
- bounds: dict[Symbol, Expr] | None = None¶
- dim: Expr¶
- elements_per_thread: int | None = 1¶
- property has_side_effects: bool¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- mapping: IndexMapping¶
- memory: Node¶
- property memory_type: Memory¶
- register_idx: Node¶
- property register_index: dict[Symbol, IndexSequence]¶
- register_src: Node¶
- property register_type: Register¶
- tkw_op_name: str = 'scatter_add'¶
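The element-wise accumulation can be stated as a sequential reference in plain Python. The sketch covers only the scatter dimension (1-D), and `scatter_add_reference` is a hypothetical name. A second snippet shows why concurrent writers to the same index are a problem without atomics: a non-atomic read-modify-write can lose an update.

```python
def scatter_add_reference(memory, register_src, register_idx):
    """Sequential reference semantics: memory[register_idx[i]] += register_src[i].

    Repeated indices accumulate correctly here because updates run one at a time.
    """
    out = list(memory)
    for value, idx in zip(register_src, register_idx):
        out[idx] += value
    return out


lds = [0, 0, 0, 0]
print(scatter_add_reference(lds, [1, 2, 3, 4], [0, 2, 2, 3]))   # [1, 0, 5, 4]

# Two threads both targeting index 2 with a non-atomic read-modify-write:
# both read the old value before either writes back, so one update is lost.
old_a = old_b = lds[2]
lds_racy = list(lds)
lds_racy[2] = old_a + 2
lds_racy[2] = old_b + 3      # overwrites thread A's update
print(lds_racy[2])           # 3, not 5
```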
- class wave_lang.kernel.ops.wave_ops.SchedulingBarrier(operations: list[Operation])¶
Represents a scheduling barrier in the graph. Takes in a list of operations that are allowed to cross the barrier.
- operations: list[Operation]¶
- tkw_op_name: str = 'scheduling_barrier'¶
- class wave_lang.kernel.ops.wave_ops.SchedulingGroupBarrier(instructions: dict[Operation, int], sync_id: int)¶
Represents a scheduling group barrier in the graph. The scheduling group barrier defines scheduling groups. Each scheduling group contains different instructions in a specific order. The sync_id identifies scheduling groups that need to be aware of each other.
- instructions: dict[Operation, int]¶
- sync_id: int¶
- tkw_op_name: str = 'scheduling_group_barrier'¶
- class wave_lang.kernel.ops.wave_ops.SelectOp(cond: 'fx.Node', if_true: 'fx.Node', if_false: 'fx.Node')¶
- cond: Node¶
- if_false: Node¶
- if_true: Node¶
- property indexing_dims: list[Symbol]¶
- infer_broadcast_shape(cond_type: Register, t_type: Register, f_type: Register)¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- tkw_op_name: str = 'select'¶
- class wave_lang.kernel.ops.wave_ops.SelfIndex(dim: 'IndexSymbol', dtype: 'DataType', elements_per_thread: 'Optional[IndexExpr | int]' = None)¶
- dim: Symbol¶
- dtype: DataType¶
- elements_per_thread: Expr | int | None = None¶
- property indexing_dims: list[Symbol]¶
- tkw_op_name: str = 'self_index'¶
- property type: Register¶
- class wave_lang.kernel.ops.wave_ops.SetSymbol(symbol: 'IndexExpr', register_: 'fx.Proxy')¶
- property has_side_effects: bool¶
- property indexing_dims: list[Symbol]¶
- register_: Proxy¶
- symbol: Expr¶
- tkw_op_name: str = 'set_symbol'¶
- property type: Register¶
- class wave_lang.kernel.ops.wave_ops.SetWavePrio(priority: int)¶
An op that tells the hardware what priority level to assign to certain instructions or a region. This is useful for ping-pong scheduling, or more generally whenever two waves share the same SIMD and we want the SIMD to prioritize one wave over the other.
- property has_side_effects: bool¶
- priority: int¶
- tkw_op_name: str = 'set_wave_prio'¶
Represents a shared memory barrier in the graph.
Represents a shared memory barrier signal in the graph (gfx12). The argument specifies which barrier to signal:
[1:16]: named barriers
0: NOOP
-1: works as s_barrier
-2: trap barrier
-3: cluster barrier
- Parameters:
barId – The barrier ID to signal
tensor_wait – If True, emit s_wait_tensorcnt(0) before signaling
ds_wait – If True, emit s_wait_dscnt(0) before signaling (for non-cluster barriers)
Waits for all waves in a workgroup (WG) to signal the barrier before proceeding, synchronizing the waves within the WG (gfx12). The argument specifies which barrier to wait on:
[1:16]: named barriers
0: NOOP
-1: works as s_barrier
-2: trap barrier
-3: cluster barrier
- class wave_lang.kernel.ops.wave_ops.ShuffleOp(arg: fx.Node, offset: int, width: int, mode: ShuffleMode)¶
Represents a shuffle.xor op.
arg: Value/vector to shuffle.
offset: XOR offset.
width: XOR width.
- arg: fx.Node¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- mode: ShuffleMode¶
- offset: int¶
- tkw_op_name: str = 'shuffle'¶
- width: int¶
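The XOR shuffle can be simulated over a list of lanes: lane i receives the value held by lane i ^ offset. The sketch below models only the XOR behaviour described above and ignores `mode`. It assumes `width` is a power of two and `offset < width`, so every partner lane stays within the same width-sized segment. A butterfly of XOR shuffles with halving offsets is a common way to build a warp-level all-reduce. `butterfly_sum` shows that pattern as an illustration, not as Wave's lowering.

```python
def shuffle_xor(lane_values, offset, width):
    """Each lane i receives the value held by lane i ^ offset (simulated lanes)."""
    assert width & (width - 1) == 0 and 0 < offset < width
    assert len(lane_values) % width == 0
    return [lane_values[i ^ offset] for i in range(len(lane_values))]


def butterfly_sum(lane_values, width):
    """All-reduce: after log2(width) XOR steps every lane holds its segment's sum."""
    values = list(lane_values)
    offset = width // 2
    while offset >= 1:
        partner = shuffle_xor(values, offset, width)
        values = [v + p for v, p in zip(values, partner)]
        offset //= 2
    return values


print(shuffle_xor(list(range(8)), 1, 8))   # [1, 0, 3, 2, 5, 4, 7, 6]
print(butterfly_sum(list(range(8)), 8))    # [28, 28, 28, 28, 28, 28, 28, 28]
print(butterfly_sum(list(range(8)), 4))    # [6, 6, 6, 6, 22, 22, 22, 22]
```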
- class wave_lang.kernel.ops.wave_ops.Sin(arg: Node)¶
NewSubclass(arg: ‘fx.Node’)
- tkw_op_name: str = 'sin'¶
- class wave_lang.kernel.ops.wave_ops.Sinh(arg: Node)¶
NewSubclass(arg: ‘fx.Node’)
- tkw_op_name: str = 'sinh'¶
- class wave_lang.kernel.ops.wave_ops.Softsign(arg: Node, logit_cap: float = 30.0, apply_scaling: bool = False, head_dim: int = None)¶
NewSubclass(arg: ‘fx.Node’, logit_cap: ‘float’ = 30.0, apply_scaling: ‘bool’ = False, head_dim: ‘int’ = None)
- tkw_op_name: str = 'softsign'¶
- class wave_lang.kernel.ops.wave_ops.SoftsignOp(arg: 'fx.Node', logit_cap: 'float' = 30.0, apply_scaling: 'bool' = False, head_dim: 'int' = None)¶
- apply_scaling: bool = False¶
- arg: Node¶
- head_dim: int = None¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- logit_cap: float = 30.0¶
- class wave_lang.kernel.ops.wave_ops.Sqrt(arg: Node)¶
NewSubclass(arg: ‘fx.Node’)
- tkw_op_name: str = 'sqrt'¶
- class wave_lang.kernel.ops.wave_ops.Sub(lhs: Any, rhs: Any)¶
NewSubclass(lhs: ‘Any’, rhs: ‘Any’)
- tkw_op_name: str = 'sub'¶
- class wave_lang.kernel.ops.wave_ops.Sum(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)¶
NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)
- tkw_op_name: str = 'sum'¶
- class wave_lang.kernel.ops.wave_ops.Tanh(arg: Node)¶
NewSubclass(arg: ‘fx.Node’)
- tkw_op_name: str = 'tanh'¶
- class wave_lang.kernel.ops.wave_ops.TanhApprox(arg: Node)¶
NewSubclass(arg: ‘fx.Node’)
- tkw_op_name: str = 'tanh_approx'¶
- class wave_lang.kernel.ops.wave_ops.TensorCounterWait(count: int = 0)¶
Wait for the tensor counter to reach the specified value. Generates the rocdl.s.wait.tensorcnt instruction.
NOTE: This operation is only supported on gfx1250 targets.
- count: int = 0¶
- property has_side_effects: bool¶
- tkw_op_name: str = 'tensor_counter_wait'¶
- class wave_lang.kernel.ops.wave_ops.TensorLoadToLDS(src: 'list[Memory]', dst: 'list[Memory]', element_type: 'DataType', distributed_shape: 'dict[IndexSymbol, IndexExpr]', shared_tile_index: 'dict[IndexSymbol, IndexSequence]', global_tile_index: 'dict[IndexSymbol, IndexSequence]', bounds: 'dict[IndexSymbol, IndexExpr]', multicast_mask: 'Optional[IndexExpr]' = None, input_selector: 'IndexSymbol | int' = 0)¶
- bounds: dict[Symbol, Expr]¶
- distributed_shape: dict[Symbol, Expr]¶
- dst: list[Memory]¶
- element_type: DataType¶
- global_tile_index: dict[Symbol, IndexSequence]¶
- property has_side_effects: bool¶
- input_selector: Symbol | int = 0¶
- multicast_mask: Expr | None = None¶
- src: list[Memory]¶
- tkw_op_name: str = 'tensor_load_to_lds'¶
- class wave_lang.kernel.ops.wave_ops.TopkOp(arg: Node, k_dim: Symbol, dim: Symbol)¶
Represents a TopK computation.
arg: Source tensor (Register value) to find the top-k elements of.
k_dim: Dimension symbol representing K (the number of top elements to select).
dim: Symbolic dimension to perform top-k on. This dimension is replaced by k_dim in the output.
- arg: Node¶
- dim: Symbol¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- property init¶
- k_dim: Symbol¶
- property reduction_dim: Symbol¶
- tkw_op_name: str = 'topk'¶
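The shape change described above can be shown with a row-wise reference in Python: the axis `dim` of length N is replaced by an axis of length K. The sketch assumes values are returned largest-first. Whether the op also produces indices is not stated here, so `topk_rows` (a hypothetical helper) returns values only.

```python
import heapq


def topk_rows(rows, k):
    """For each row, keep the k largest values, largest first.

    The scanned axis (dim, length N) becomes an axis of length k (k_dim).
    """
    return [heapq.nlargest(k, row) for row in rows]


scores = [[3, 1, 4, 1, 5],
          [9, 2, 6, 5, 3]]
print(topk_rows(scores, 2))   # [[5, 4], [9, 6]]
```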
- class wave_lang.kernel.ops.wave_ops.Truediv(lhs: Any, rhs: Any)¶
NewSubclass(lhs: ‘Any’, rhs: ‘Any’)
- tkw_op_name: str = 'truediv'¶
- class wave_lang.kernel.ops.wave_ops.UnaryPyOp(arg: Node)¶
Represents a unary python operator.
- arg: Node¶
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- property py_operator: str¶
- final class wave_lang.kernel.ops.wave_ops.Unknown(args: Sequence[Any], kwargs: dict[Any, Any])¶
Represents an fx.Node that has no corresponding CustomNode class.
- args: Sequence[Any]¶
- custom_string(value_map: dict[str, str]) str¶
- kwargs: dict[Any, Any]¶
- class wave_lang.kernel.ops.wave_ops.WorkgroupBarrier¶
Represents a synchronization of all threads in a workgroup. Threads wait on a WorkgroupBarrier until all the threads in the workgroup have called a WorkgroupBarrier (which does not have to be at the same location in the code).
- property has_side_effects: bool¶
- tkw_op_name: str = 'workgroup_barrier'¶
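Python's `threading.Barrier` offers a CPU-side analogue of these semantics: no thread passes the barrier until every participant has reached it. The sketch does not use Wave. The `shared` list merely stands in for workgroup shared memory, and `run_workgroup` is a hypothetical name. It shows the typical write, barrier, read-a-neighbour pattern that the barrier makes safe.

```python
import threading


def run_workgroup(num_threads):
    shared = [None] * num_threads                  # stands in for shared memory
    out = [None] * num_threads
    barrier = threading.Barrier(num_threads)       # stands in for the workgroup barrier

    def thread_body(tid):
        shared[tid] = tid * 10                     # phase 1: each thread writes its slot
        barrier.wait()                             # nobody proceeds until all have written
        out[tid] = shared[(tid + 1) % num_threads] # phase 2: safe to read a neighbour's slot

    threads = [threading.Thread(target=thread_body, args=(t,)) for t in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out


print(run_workgroup(4))   # [10, 20, 30, 0]
```

Without the barrier, a thread could read its neighbour's slot before it is written and observe `None`.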
- class wave_lang.kernel.ops.wave_ops.Write(register_: 'fx.Proxy', memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, flags: 'MemoryAccessFlags' = MemoryAccessFlags.NONE, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None)¶
- bounds: dict[Symbol, Expr] | None = None¶
- elements_per_thread: Any | None = None¶
- flags: MemoryAccessFlags = 0¶
- get_derived_indices() list[tuple[dict[Symbol, IndexSequence], Node]]¶
- has_identity_mapping() bool¶
Check if mapping between input register and output memory is identity.
- property indexing_dims: list[Symbol]¶
- infer_type(*args)¶
Infer the type of this operator using the types of its arguments.
- is_contiguous_vec(constraints, target: str) bool¶
Check whether the op can be lowered to contiguous vector ops.
If False, it has to be lowered to non-contiguous (scatter-style) stores instead.
- mapping: IndexMapping | None = None¶
- mapping_dynamic_vals: tuple[Node, ...] = ()¶
- memory: Proxy¶
- property memory_type: Memory¶
- register_: Proxy¶
- property register_index: dict[Symbol, IndexSequence]¶
- property register_type: Register¶
- source: tuple[Expr] | None = None¶
- target: tuple[Expr] | None = None¶
- tkw_op_name: str = 'write'¶
- transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]¶
Propagate index backwards.
Dynamic values can have a non-identity mapping, so the index must be updated when walking from the node to its dynamic-value arguments.
For example, if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting arg index will be $idx // ELEMS_PER_THREAD.
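The example mapping can be checked numerically by evaluating it on concrete values of $idx. The real method rewrites symbolic index sequences. This sketch evaluates a plain Python function instead, using the hypothetical helper `map_index_to_dyn_val` and an assumed ELEMS_PER_THREAD of 4.

```python
ELEMS_PER_THREAD = 4   # assumed value, for illustration only


def dyn_val_mapping(j):
    """The mapping {N: j // ELEMS_PER_THREAD} from the example above."""
    return j // ELEMS_PER_THREAD


def map_index_to_dyn_val(index_values, mapping):
    """Translate each index of the op into the index of its dynamic-value argument."""
    return [mapping(j) for j in index_values]


op_index = list(range(8))                                   # $idx takes the values 0..7
print(map_index_to_dyn_val(op_index, dyn_val_mapping))      # [0, 0, 0, 0, 1, 1, 1, 1]
```

Every group of ELEMS_PER_THREAD consecutive indices maps to a single dynamic value, which is why the index seen by the argument is divided rather than copied unchanged.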
- wave_lang.kernel.ops.wave_ops.abs(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.allocate(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.apply_expr(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.atan2(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.atomic_add(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.atomic_min(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.bitcast(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.bounds_check(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.broadcast(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.cast(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.cbrt(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.conditional(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.cos(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.cumsum(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.debug_log(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.define_interface_op(op_name: str) Callable[[T], T]¶
Generate a new subclass for op handling, derived from the base interface class. The generated subclass can be used to emit the op from the compiler/Python side by calling the generated subclass.
The generated subclass name is the PascalCase form of the TKW op name. For example: "tkw.op_name" -> "OpName", "tkw.exp2" -> "Exp2".
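The naming rule can be expressed as a short string transformation that reproduces the two examples above. It is also consistent with class names on this page, such as TanhApprox for 'tanh_approx'. `pascal_case` is a hypothetical helper, not the library's internal function.

```python
def pascal_case(op_name):
    """Convert a TKW op name such as "tanh_approx" into its class name."""
    name = op_name.removeprefix("tkw.")
    # Upper-case only the first character of each underscore-separated part;
    # the rest is kept as-is, so digits survive: "exp2" -> "Exp2".
    return "".join(part[:1].upper() + part[1:] for part in name.split("_"))


for name in ["tkw.op_name", "tkw.exp2", "tanh_approx", "roundeven"]:
    print(name, "->", pascal_case(name))
# tkw.op_name -> OpName
# tkw.exp2 -> Exp2
# tanh_approx -> TanhApprox
# roundeven -> Roundeven
```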
- wave_lang.kernel.ops.wave_ops.define_op(op_name: str) Callable[[T], T]¶
- wave_lang.kernel.ops.wave_ops.define_py_op(py_op: Callable) Callable[[T], T]¶
Register Python built-in operators as custom ops. This overloads operator-specific methods of fx.Proxy, such as __add__, with a handler that controls how the operator is traced and maps it to a dynamically created subclass of UnaryPyOp or BinaryPyOp.
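The overloading mechanism can be illustrated with a toy tracer. The code is self-contained and does not use torch.fx. `Proxy`, `BinaryPyOp` and `toy_define_py_op` here are simplified hypothetical stand-ins. Only binary operators are handled. The sketch shows the two ideas described above: a dunder method is patched onto the proxy class, and each operator maps to a dynamically created subclass.

```python
import operator


class BinaryPyOp:
    """Stand-in base class for traced binary operator nodes."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs


class Proxy:
    """Toy proxy: every traced operation is appended to a shared graph list."""

    def __init__(self, name, graph):
        self.name, self.graph = name, graph


def toy_define_py_op(py_op):
    """Overload Proxy.__<op>__ so that applying `py_op` records a node."""
    op_name = py_op.__name__                          # e.g. "add", "mul"
    # Dynamically create a subclass such as `Add(BinaryPyOp)`.
    node_cls = type(op_name.capitalize(), (BinaryPyOp,), {})

    def handler(self, other):
        self.graph.append(node_cls(self, other))
        return Proxy(f"%{len(self.graph) - 1}", self.graph)

    setattr(Proxy, f"__{op_name}__", handler)
    return node_cls


toy_define_py_op(operator.add)
toy_define_py_op(operator.mul)

graph = []
a, b = Proxy("a", graph), Proxy("b", graph)
c = a * b + a                                # traced instead of evaluated
print([type(n).__name__ for n in graph])     # ['Mul', 'Add']
print(c.name)                                # %1
```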
- wave_lang.kernel.ops.wave_ops.eq(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.exp(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.exp2(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.extract(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.extract_slice(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.gather_to_lds(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.ge(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.get_custom(node: Node) CustomOp¶
Get the corresponding CustomOp for a given fx.Node.
- wave_lang.kernel.ops.wave_ops.get_result(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.get_shape_from_bindings(constraints: list[Constraint], target: tuple[Expr]) list[Expr]¶
- wave_lang.kernel.ops.wave_ops.gt(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.has_same_custom_type(lhs_type: Memory, rhs_type: Memory) bool¶
- wave_lang.kernel.ops.wave_ops.iterate(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.le(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.log10(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.log2(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.lt(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.max(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.maximum(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.memory_counter_wait(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.memory_counter_wait_barrier(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.min(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.minimum(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.mma(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.mod(lhs: Register, rhs: Register) Register¶
- wave_lang.kernel.ops.wave_ops.ne(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.permute(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.powf(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.read(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.read_meets_hw_transpose_requirements(read: Read, constraints: list[Constraint], target: str) bool¶
- wave_lang.kernel.ops.wave_ops.reciprocal(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.register(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.remf(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.reshape(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.round(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.roundeven(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.rsqrt(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.scalar(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.scaled_mma(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.scatter_add(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.scheduling_barrier(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.scheduling_group_barrier(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.select(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.self_index(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.set_symbol(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.set_wave_prio(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.shuffle(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.sin(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.sinh(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.softsign(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.sqrt(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.sum(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.tag(expr: Proxy, tag_name: str) Proxy¶
Assign a tag to the result of a Python expression (e.g., arithmetic operators).
This function allows tagging operations that don’t natively support the tag= keyword, such as Python arithmetic operators (*, -, +, /).
- Usage:
# Instead of:
result = a * b  # Cannot tag this directly
# Use:
result = tkw.tag(a * b, "multiply_ab")
# The tag can then be used in wave_schedule:
multiply_ops = tkw.get_node_by_tag("multiply_ab")
- Parameters:
expr – The fx.Proxy result of an expression (e.g., a * b, x - y)
tag_name – The tag string to assign to the operation
- Returns:
The same fx.Proxy, allowing chained expressions
- wave_lang.kernel.ops.wave_ops.tanh(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.tanh_approx(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.tensor_counter_wait(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.tensor_load_to_lds(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.topk(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.workgroup_barrier(*args: Any, **kwargs: dict[str, Any])¶
- wave_lang.kernel.ops.wave_ops.write(*args: Any, **kwargs: dict[str, Any])¶