Wave Ops

class wave_lang.kernel.ops.wave_ops.Abs(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'abs'
class wave_lang.kernel.ops.wave_ops.Add(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'add'
class wave_lang.kernel.ops.wave_ops.Allocate(shape: tuple[~sympy.core.expr.Expr], distributed_shape: tuple[~sympy.core.expr.Expr], dtype: ~wave_lang.kernel._support.dtype.DataType, address_space: ~wave_lang.kernel.lang.kernel_buffer.AddressSpace = $SHARED_ADDRESS_SPACE, padding: int = 0, parent: ~torch.fx.node.Node | None = None, offset: ~sympy.core.expr.Expr | None = None, tail_padding: int = 0)

Represents an allocation in an address space (such as shared memory).

address_space: AddressSpace = $SHARED_ADDRESS_SPACE
property allocation_size: Expr

Returns the full size of the allocation in bytes including all padding.

distributed_shape: tuple[Expr]
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

offset: Expr | None = None
padding: int = 0
parent: Node | None = None
shape: tuple[Expr]
tail_padding: int = 0
tkw_op_name: str = 'allocate'
property type: Memory
property unpadded_dims: dict[Symbol, Expr]
property unpadded_shape: tuple[Expr]
class wave_lang.kernel.ops.wave_ops.And_(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'and_'
class wave_lang.kernel.ops.wave_ops.ApplyExpr(register_: 'fx.Proxy | Sequence[fx.Proxy]', expr: 'Callable')
expr: Callable
property indexing_dims: list[Symbol]
register_: Proxy | Sequence[Proxy]
tkw_op_name: str = 'apply_expr'
property type: Register
class wave_lang.kernel.ops.wave_ops.Atan2(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'atan2'
class wave_lang.kernel.ops.wave_ops.AtomicAddOp(lhs: 'Any', rhs: 'Any', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = ())
tkw_op_name: str = 'atomic_add'
class wave_lang.kernel.ops.wave_ops.AtomicOp(lhs: Any, rhs: Any, elements_per_thread: Any | None = None, mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = ())

Represents an atomic operation in the graph. Takes in Register and Memory as inputs and writes the modified value back onto the buffer. The mapping attribute maps the index from the wave kernel to the shared memory index the wavegroup operates on.

elements_per_thread: Any | None = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
property memory_type: Memory
tkw_op_name: str = 'atomic_min'
class wave_lang.kernel.ops.wave_ops.BinaryOpBase(lhs: Any, rhs: Any)

Represents an elementwise binary python operator.

DTYPE requirement: lhs and rhs need to have the same dtype. Shape requirement: lhs and rhs either have the same shape or their shapes must be broadcastable to one another.

property indexing_dims: list[Symbol]
infer_shape() Any
lhs: Any
property py_operator: str
rhs: Any
class wave_lang.kernel.ops.wave_ops.BinaryPyOp(lhs: 'Any', rhs: 'Any')
infer_type(*args)

Infer the type of this operator using the types of its arguments.

class wave_lang.kernel.ops.wave_ops.BitcastOp(arg: Node, dtype: DataType)

Represents a bitcast operation.

arg: Node
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property scale_factor
tkw_op_name: str = 'bitcast'
class wave_lang.kernel.ops.wave_ops.BoundsCheck(index_exprs: dict[Expr, IndexSequence], bounds: dict[Symbol, Expr], mapping: IndexMapping | None = None, mapping_dynamic_vals: tuple[Node, ...] = (), mask_bounds: dict[Symbol, Expr] | None = None)

Checks if a dimension is within a bound.

bounds: dict[Symbol, Expr]
property has_side_effects: bool
index_exprs: dict[Expr, IndexSequence]
mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
mask_bounds: dict[Symbol, Expr] | None = None
tkw_op_name: str = 'bounds_check'
class wave_lang.kernel.ops.wave_ops.Broadcast(arg: Node, target_shape: Sequence[Symbol] = None)

Represents a Broadcast operation.

arg: Source tensor/value to broadcast. target_shape: Symbolic target broadcast shape.

arg: Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

target_shape: Sequence[Symbol] = None
tkw_op_name: str = 'broadcast'
class wave_lang.kernel.ops.wave_ops.CastOp(arg: Node, dtype: DataType)

Represents a cast operation.

arg: Node
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'cast'
class wave_lang.kernel.ops.wave_ops.Cbrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'cbrt'
class wave_lang.kernel.ops.wave_ops.ComparisonPyOp(lhs: 'Any', rhs: 'Any')
infer_type(*args)

Infer the type of this operator using the types of its arguments.

class wave_lang.kernel.ops.wave_ops.Conditional(condition: 'fx.Proxy | IndexExpr', subgraph_name: 'str', implicit_captures: 'Sequence[fx.Proxy]')
condition: Proxy | Expr
implicit_captures: Sequence[Proxy]
property indexing_dims: list[Symbol]
subgraph_name: str
tkw_op_name: str = 'conditional'
class wave_lang.kernel.ops.wave_ops.Cos(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'cos'
class wave_lang.kernel.ops.wave_ops.Cumsum(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘Optional[fx.Node]’ = None, dim: ‘Optional[IndexSymbol]’ = None)

tkw_op_name: str = 'cumsum'
class wave_lang.kernel.ops.wave_ops.CustomOp

Base class for all custom fx nodes.

add_to_graph(region_graph: RegionGraph, type: Any = None, loc: FileLineColInfo | StackTraceInfo | None = None, tag: str | None = None) Node
copy(new_name: str | None = None, new_graph: ~torch.fx.graph.Graph | None = None, arg_transform: ~typing.Callable[[~typing.Any], ~typing.Any] | None = <function CustomOp.<lambda>>, anchor: ~torch.fx.node.Node | None = None) Self

Returns a duplicate of this node.

copy_core_attributes(new_node: Node)

Copy core attributes from the current node to the new node.

custom_string(value_map: dict[str, str]) str
erase()

Erase the current node from the graph where it exists.

property expanded_dims: dict[Symbol, int]

During expansion each node is expanded along its indexing dimensions. The expanded_dims property stores the dimensions along which the node has been expanded as well as the scaling along that dimension.

For example, a node with indexing dimensions [M, N] with dimensional scaling {M: 2, N: 2}, will be expanded to 4 nodes, with each expanded node mapping to the following expanded_dims {M: 0, N: 0}, {M: 0, N: 1}, {M: 1, N: 0}, {M: 1, N: 1}.

classmethod from_fx_node(node: Node) CustomOpT
fx_node: Node | None = None
get_node_arg_index(arg: CustomOp) CustomOp | list[CustomOp] | None
graph: Graph | None = None
classmethod handle(graph: RegionGraph, *args, **kwargs) Node
property has_side_effects: bool
property index: dict[Symbol, IndexSequence] | None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property location: FileLineColInfo | StackTraceInfo | None
property name: str
property node_args: dict[int, Any]

Returns the args to this custom op using subclasses of CustomOp if possible.

property pre_expansion_id: int
replace_all_uses_with(new_node: CustomOp | Node)

Replace all uses of the current node with the new node.

replace_all_uses_with_except(new_node: CustomOp | Node, except_nodes: list[CustomOp])

Replace all uses of the current node with the new node except for the nodes in except_nodes.

replacement_location_propagate(new_node: CustomOp | Node | list[Node])

Set the new_node location if it doesn’t have one.

property rrt
property scheduling_parameters
property tag: str | None
tkw_op_name: str = 'unknown'
transform_index(index: dict[Symbol, IndexSequence]) dict[Symbol, IndexSequence]

Transform the index of the node based on the provided mapping.

transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]

Transform the index of the node when propagating index backwards, i.e. from node to its arguments.

property type: Any
update_arg(idx_or_name: int | str | Node, value: CustomOp | Node)

Update the value of an argument in the node while keeping the underlying fx.Node consistent.

property users: list[Any]

Returns the users of this custom op using subclasses of CustomOp if possible.

property vector_shapes: dict[Symbol, int]
class wave_lang.kernel.ops.wave_ops.DebugLog(register_: ~torch.fx.proxy.Proxy, label: str | None = None, extra_iteration_dimensions: list[tuple[~sympy.core.symbol.Symbol, ~sympy.core.symbol.Symbol, int]] | None = None, mapping: ~wave_lang.kernel.lang.wave_types.IndexMapping | None = None, mapping_dynamic_vals: tuple[~torch.fx.node.Node, ...] = (), printer: ~typing.Callable[[str, ~typing.Any], ~typing.Any] | None = <built-in function print>, handler: ~typing.Callable[[dict[str, ~typing.Any]], ~typing.Any] | None = None)

An op for debugging. Represents a write to an implicit global memory location. The kernel will implicitly have an extra memory input added that will be injected by the Python kernel launcher. The memory can be accessed by passing an extra keyword argument, “debug_logs”, with an empty dictionary to the kernel invocation. The dictionary will be mutated, adding the labels as keys that map to nested dictionaries. The nested dictionaries have a `value` field with the log tensor, and other keys (e.g. symbolic_shape).

I.e. the debug_logs dictionary will look like: {LABEL: {“value”: LOG_TENSOR, (other metadata keys) …}}

Note that the logs collected in the debug_logs field, or handled by printer or handler represent a global view of the log after all writes, not limited to any one wave or loop iteration.

The optional printer argument should be a function that accepts a string (the log’s label) and the value of the log itself (a Torch tensor). The default value for this is print, though note that it will probably print an abbreviated view of the global tensor.

The optional handler argument should be a function that accepts the whole debug_logs object (i.e. all logs, not just one). The handler function gives a way to specify something like a viewer for all logs, but specify it inline among print functions rather than separately.

The optional extra_iteration_dimensions argument allows you to add extra dimensions to capture values from multiple iterations of a loop. It takes a list of tuples, where each tuple contains (dimension_name, iteration_axis, max_iterations). The dimension_name must be a unique symbol, and will be the name of the dimension in the symbolic shape of the output. The iteration_axis must be the axis of an Iterate operation, i.e. the dimension being reduced in the iteration. The max_iterations argument is a positive integer that represents the maximum number of iterations to store. Note that if you give too few, later iterations will currently overwrite the final iteration slot.

The API and semantics of this operation are not yet stable, but since it is just a debugging tool, you want to take any debug logging out of your kernel before shipping it anyway.

extra_iteration_dimensions: list[tuple[Symbol, Symbol, int]] | None = None
handler: Callable[[dict[str, Any]], Any] | None = None
property has_side_effects: bool
infer_type(*args)

Infer the type of this operator using the types of its arguments.

label: str | None = None
mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
property memory: Proxy | None
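printer: Callable[[str, Any], Any] | None = <built-in function print>

The default printer is the built-in print, which prints the values to a stream, or to sys.stdout by default, and accepts the following keyword arguments: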

sep

string inserted between values, default a space.

end

string appended after the last value, default a newline.

file

a file-like object (stream); defaults to the current sys.stdout.

flush

whether to forcibly flush the stream.

register_: Proxy
tkw_op_name: str = 'debug_log'
class wave_lang.kernel.ops.wave_ops.Eq(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'eq'
class wave_lang.kernel.ops.wave_ops.Exp(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'exp'
class wave_lang.kernel.ops.wave_ops.Exp2(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'exp2'
class wave_lang.kernel.ops.wave_ops.Extract(register_: Proxy, offset: Expr | int)

Op Rationale:

Extract is an op used to represent extracting a scalar from TKW’s 1-D vector at the specified index.

This can also be viewed as indexing/slicing on the fastest dimension. Hence, this op is designed to be treated as a reduction on the indexed/fastest dimension.

infer_type(*args)

Infer the type of this operator using the types of its arguments.

offset: Expr | int
register_: Proxy
tkw_op_name: str = 'extract'
class wave_lang.kernel.ops.wave_ops.ExtractSlice(register_: 'fx.Proxy', offset: 'tuple[IndexExpr]', size: 'tuple[IndexExpr]', stride: 'tuple[IndexExpr]')
offset: tuple[Expr]
property rank: int
register_: Proxy
size: tuple[Expr]
stride: tuple[Expr]
tkw_op_name: str = 'extract_slice'
property type: Register
class wave_lang.kernel.ops.wave_ops.GatherToLDS(src: Memory, dst: Memory, src_index: dict[Symbol, IndexSequence], dst_index: dict[Symbol, IndexSequence], dtype: DataType, elements_per_thread: Expr | int | None, src_mapping: IndexMapping | None, dst_mapping: IndexMapping | None, src_bounds: dict[Symbol, Expr] | None, src_mapping_dynamic_vals: tuple[Node, ...] = (), dst_mapping_dynamic_vals: tuple[Node, ...] = ())

Represents an instruction that performs direct load from global to lds. Source node points to the global memory to load from and the destination node points to shared memory.

dst: Memory
dst_index: dict[Symbol, IndexSequence]
dst_mapping: IndexMapping | None
dst_mapping_dynamic_vals: tuple[Node, ...] = ()
dtype: DataType
elements_per_thread: Expr | int | None
src: Memory
src_bounds: dict[Symbol, Expr] | None
src_index: dict[Symbol, IndexSequence]
src_mapping: IndexMapping | None
src_mapping_dynamic_vals: tuple[Node, ...] = ()
tkw_op_name: str = 'gather_to_lds'
class wave_lang.kernel.ops.wave_ops.Ge(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'ge'
class wave_lang.kernel.ops.wave_ops.GetResult(value: 'fx.Node', res_idx: 'int')
property distributed_shape
property index: dict[Symbol, IndexSequence]
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

res_idx: int
tkw_op_name: str = 'get_result'
value: Node
class wave_lang.kernel.ops.wave_ops.Getitem(value: Node, res_idx: int)

NewSubclass(value: ‘fx.Node’, res_idx: ‘int’)

tkw_op_name: str = 'getitem'
class wave_lang.kernel.ops.wave_ops.Gt(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'gt'
class wave_lang.kernel.ops.wave_ops.Invert(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'invert'
class wave_lang.kernel.ops.wave_ops.IterArg(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)

Represents a specific placeholder node in the graph that is an iter arg of a reduction node. IterArgs can be of type Register or Memory with a Shared memory address space.

property distributed_shape
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property iter_idx
parent_op()
class wave_lang.kernel.ops.wave_ops.Iterate(axis: 'IndexSymbol', init_args: 'Sequence[Any]', subgraph_name: 'str', implicit_captures: 'Sequence[fx.Proxy]', step: 'int' = 1, start: 'Optional[IndexExpr]' = None, condition: 'Optional[IndexExpr]' = None)
axis: Symbol
condition: Expr | None = None
property count: int | None
implicit_captures: Sequence[Proxy]
property index: list[dict[Symbol, IndexSequence]]
property indexing_dims: list[Symbol] | list[list[Symbol]]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

init_args: Sequence[Any]
iter_args(graph: Graph | None = None) list[Node]
outputs(graph: Graph | None = None) list[Node]
start: Expr | None = None
step: int = 1
subgraph_name: str
tkw_op_name: str = 'iterate'
class wave_lang.kernel.ops.wave_ops.Le(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'le'
class wave_lang.kernel.ops.wave_ops.Log10(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'log10'
class wave_lang.kernel.ops.wave_ops.Log2(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'log2'
class wave_lang.kernel.ops.wave_ops.Lt(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'lt'
class wave_lang.kernel.ops.wave_ops.MMA(lhs: 'fx.Node', rhs: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['MMAType'] | 'GenericDot'" = None)
acc: fx.Node
property acc_index: dict[Symbol, IndexSequence]
property acc_type: Memory
custom_string(value_map: dict[str, str]) str
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

lhs: fx.Node
property lhs_index: dict[Symbol, IndexSequence]
property lhs_type: Memory
mma_type: 'MMAType' | None | 'GenericDot' = None
operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) dict[Symbol, IndexSequence]
property reduction_dim: Symbol
rhs: fx.Node
property rhs_index: dict[Symbol, IndexSequence]
property rhs_type: Memory
tkw_op_name: str = 'mma'
class wave_lang.kernel.ops.wave_ops.MMABase
class wave_lang.kernel.ops.wave_ops.Max(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'max'
class wave_lang.kernel.ops.wave_ops.Maximum(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'maximum'
class wave_lang.kernel.ops.wave_ops.MemoryCounterWait(load: int | None = None, store: int | None = None, ds: int | None = None, exp: int | None = None)

Wait for the specified counters to be less than or equal to the provided values before continuing.

ds: int | None = None
exp: int | None = None
property has_side_effects: bool
load: int | None = None
store: int | None = None
tkw_op_name: str = 'memory_counter_wait'
class wave_lang.kernel.ops.wave_ops.Min(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'min'
class wave_lang.kernel.ops.wave_ops.Minimum(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'minimum'
class wave_lang.kernel.ops.wave_ops.Mod(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'mod'
class wave_lang.kernel.ops.wave_ops.Mul(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'mul'
class wave_lang.kernel.ops.wave_ops.Ne(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'ne'
class wave_lang.kernel.ops.wave_ops.Neg(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'neg'
class wave_lang.kernel.ops.wave_ops.NestedRegionOp
captured_vars(graph: Graph) list[Node]

Nodes that are placeholders and are not iter args are captured vars.

erase()

Erase the current node from the graph where it exists.

get_captured_fx_node(graph: Graph, outer_node: Node) Node | None
get_outer_node(outer_node: Node) Node
get_root_graph()

Return the “root”/outermost layer of our computation graph. This is done by iteratively accessing the parent_graph of the current graph until we find the “root” graph, which will have a “subgraph” attribute.

classmethod handle(graph: RegionGraph, *args, **kwargs)

Base handle method for nested region operations. Extracts tag from kwargs and sets it on the node and underlying fx.Node.

class wave_lang.kernel.ops.wave_ops.NewRegister(shape: 'tuple[IndexExpr, ...]', dtype: 'DataType', value: 'float')
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

shape: tuple[Expr, ...]
tkw_op_name: str = 'register'
value: float
class wave_lang.kernel.ops.wave_ops.NewScalar(value: 'float | IndexExpr', dtype: 'DataType')
dtype: DataType
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'scalar'
value: float | Expr
class wave_lang.kernel.ops.wave_ops.Or_(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'or_'
class wave_lang.kernel.ops.wave_ops.Output(return_vals: Sequence[Any])

Represents an output node in the graph, representing the return value of a traced function.

add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) Node
classmethod from_fx_node(node: Node) CustomOpT
property has_side_effects: bool
return_vals: Sequence[Any]
tkw_op_name: str = 'output'
class wave_lang.kernel.ops.wave_ops.Permute(arg: Node, target_shape: Sequence[Expr])

Represents a permute operation that permutes arg into the target shape.

arg: Node
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

target_shape: Sequence[Expr]
tkw_op_name: str = 'permute'
transform_index(index: dict[Symbol, IndexSequence]) dict[Symbol, IndexSequence]

The permute operation swaps the strides of the permuted indices. For example, if a permute operation swaps [B, M, N] to [M, N, B], then the strides of those dimensions are swapped accordingly.

class wave_lang.kernel.ops.wave_ops.Placeholder(_name: str, _type: Type[DataType] | Type[KernelBuffer] | None = None)

Represents a placeholder node in the graph, i.e. an input to a function.

add_to_graph(region_graph: RegionGraph, loc: FileLineColInfo | StackTraceInfo | None = None) Node
custom_string(value_map: dict[str, str]) str
erase()

Erase the current node from the graph where it exists.

classmethod from_fx_node(node: Node) PlaceholderT
get_captured_fx_node() Node | None
property index: list[dict[Symbol, IndexSequence]]
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'placeholder'
class wave_lang.kernel.ops.wave_ops.Powf(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'powf'
class wave_lang.kernel.ops.wave_ops.Read(memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None, _write_dependency: 'Optional[list[fx.Node]]' = None)
bounds: dict[Symbol, Expr] | None = None
property dtype: DataType
elements_per_thread: Any | None = None
get_derived_indices() list[tuple[dict[Symbol, IndexSequence], Node]]
has_identity_mapping() bool

Check if mapping between input memory and output register is identity.

property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

is_contiguous_vec(constraints, target: str) bool

Check if op can be lowered to contiguous vector ops.

If False, we will have to lower it to a gather.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
memory: Proxy
property memory_type: Memory
source: tuple[Expr] | None = None
target: tuple[Expr] | None = None
tkw_op_name: str = 'read'
transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]

Propagate index backwards.

Dynamic values can potentially have a non-identity mapping, so we need to update the index when walking from the node to the dynamic value arguments.

E.g. if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting arg index will be $idx // ELEMS_PER_THREAD.

property write_dependency: Node
class wave_lang.kernel.ops.wave_ops.Reciprocal(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'reciprocal'
class wave_lang.kernel.ops.wave_ops.ReduceOp(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

Represents a Reduce computation.

arg: Source tensor/value to reduce. init: Init/accumulator for the reduction. dim: Which symbolic dim to reduce. block: When set to true, reduce across the block; otherwise, reduce across the warp.

arg: Node | list[Node]
block: bool | None = False
dim: Any | None = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

init: Node = None
property reduction_dim: Symbol
class wave_lang.kernel.ops.wave_ops.Remf(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'remf'
class wave_lang.kernel.ops.wave_ops.Reshape(args: Node | Sequence[Node], target_vector_shape: dict[Symbol, int])

Represents a reshape operation that reshapes vectors along the same dimension.

args: Node | Sequence[Node]
property indexing_dims: list[Expr]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

target_vector_shape: dict[Symbol, int]
tkw_op_name: str = 'reshape'
class wave_lang.kernel.ops.wave_ops.Round(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'round'
class wave_lang.kernel.ops.wave_ops.Roundeven(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'roundeven'
class wave_lang.kernel.ops.wave_ops.Rsqrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'rsqrt'
class wave_lang.kernel.ops.wave_ops.ScaledMMA(lhs: 'fx.Node', lhs_scale: 'fx.Node', rhs: 'fx.Node', rhs_scale: 'fx.Node', acc: 'fx.Node', mma_type: "Optional['ScaledMMAType']" = None)
acc: fx.Node
property acc_index: dict[Symbol, IndexSequence]
property acc_type: Memory
align_index(constraints: list[Constraint]) None
custom_string(value_map: dict[str, str]) str
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

lhs: fx.Node
property lhs_index: dict[Symbol, IndexSequence]
lhs_scale: fx.Node
property lhs_scale_index: dict[Symbol, IndexSequence]
property lhs_scale_type: Memory
property lhs_type: Memory
mma_type: 'ScaledMMAType' | None = None
operand_index(operand_map: dict[Symbol, int], shape: list[Expr]) dict[Symbol, IndexSequence]
property reduction_dim: Symbol
rhs: fx.Node
property rhs_index: dict[Symbol, IndexSequence]
rhs_scale: fx.Node
property rhs_scale_index: dict[Symbol, IndexSequence]
property rhs_scale_type: Memory
property rhs_type: Memory
tkw_op_name: str = 'scaled_mma'
class wave_lang.kernel.ops.wave_ops.ScanOp(arg: Node | list[Node], init: Node | None = None, dim: Symbol | None = None)

Base class for all scan-style operations (e.g., cumsum).

arg: Source tensor/value to scan. init: Optional initial value. dim: Symbolic dimension along which to scan.

arg: Node | list[Node]
dim: Symbol | None = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

init: Node | None = None
property scan_dim: Symbol
class wave_lang.kernel.ops.wave_ops.ScatterAdd(register_src: Node, register_idx: Node, dim: Expr, memory: Node, mapping: IndexMapping, elements_per_thread: int | None = 1, bounds: dict[Symbol, Expr] | None = None)

ScatterAdd performs element-wise accumulation from a source register into shared memory (LDS), at locations determined by the index register along a specified dimension.

Limitations:

- Only intra-workgroup scattering is supported (i.e., within shared memory / LDS), assuming a single wave.
- Multi-wave execution is not guaranteed to be safe: synchronization issues may occur when threads write to the same index. Further investigation is needed.
- The operation supports multiple elements per thread, assuming the non-scatter dimension is large enough (i.e., > elements_per_thread).

bounds: dict[Symbol, Expr] | None = None
dim: Expr
elements_per_thread: int | None = 1
property has_side_effects: bool
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mapping: IndexMapping
memory: Node
property memory_type: Memory
register_idx: Node
property register_index: dict[Symbol, IndexSequence]
register_src: Node
property register_type: Register
tkw_op_name: str = 'scatter_add'
class wave_lang.kernel.ops.wave_ops.SchedulingBarrier(operations: list[Operation])

Represents a scheduling barrier in the graph. Takes in a list of operations that are allowed to cross the barrier.

operations: list[Operation]
tkw_op_name: str = 'scheduling_barrier'
class wave_lang.kernel.ops.wave_ops.SchedulingGroupBarrier(instructions: dict[Operation, int], sync_id: int)

Represents a scheduling group barrier in the graph. The scheduling group barrier defines scheduling groups. Each scheduling group contains different instructions in a specific order. The sync_id identifies scheduling groups that need to be aware of each other.

instructions: dict[Operation, int]
sync_id: int
tkw_op_name: str = 'scheduling_group_barrier'
class wave_lang.kernel.ops.wave_ops.SelectOp(cond: 'fx.Node', if_true: 'fx.Node', if_false: 'fx.Node')
cond: Node
if_false: Node
if_true: Node
property indexing_dims: list[Symbol]
infer_broadcast_shape(cond_type: Register, t_type: Register, f_type: Register)
infer_type(*args)

Infer the type of this operator using the types of its arguments.

tkw_op_name: str = 'select'
class wave_lang.kernel.ops.wave_ops.SelfIndex(dim: 'IndexExpr', dtype: 'DataType', elements_per_thread: 'Optional[IndexExpr | int]' = None)
dim: Expr
dtype: DataType
elements_per_thread: Expr | int | None = None
property indexing_dims: list[Symbol]
tkw_op_name: str = 'self_index'
property type: Register
class wave_lang.kernel.ops.wave_ops.SetSymbol(symbol: 'IndexExpr', register_: 'fx.Proxy')
property has_side_effects: bool
property indexing_dims: list[Symbol]
register_: Proxy
symbol: Expr
tkw_op_name: str = 'set_symbol'
property type: Register
class wave_lang.kernel.ops.wave_ops.SetWavePrio(priority: int)

An op that sets/tells the hardware what priority level certain instructions/regions have. This is useful for ping-pong, or in general when two waves share the same SIMD but we want to tell the SIMD to prioritize one wave or the other.

property has_side_effects: bool
priority: int
tkw_op_name: str = 'set_wave_prio'
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrier(wait_async_ops: bool = False, tensor_wait: bool = False)

Represents a shared memory barrier in the graph.

property has_side_effects: bool
tensor_wait: bool = False
tkw_op_name: str = 'shared_memory_barrier'
wait_async_ops: bool = False
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrierSignal(barId: int = 0, tensor_wait: bool = False)

Represents a shared memory barrier signal in the graph (gfx12). The argument specifies which barrier to signal:

- [1:16]: named barriers
- 0: NOOP
- -1: works as s_barrier
- -2: trap barrier
- -3: cluster barrier

barId: int = 0
property has_side_effects: bool
tensor_wait: bool = False
tkw_op_name: str = 'shared_memory_barrier_signal'
class wave_lang.kernel.ops.wave_ops.SharedMemoryBarrierWait(barId: int = 0)

Wait for all waves in a WG to signal the barrier before proceeding (gfx12). This synchronizes waves within a WG. The argument specifies which barrier to wait on:

- [1:16]: named barriers
- 0: NOOP
- -1: works as s_barrier
- -2: trap barrier
- -3: cluster barrier

barId: int = 0
property has_side_effects: bool
tkw_op_name: str = 'shared_memory_barrier_wait'
class wave_lang.kernel.ops.wave_ops.ShuffleOp(arg: fx.Node, offset: int, width: int, mode: ShuffleMode)

Represents a shuffle.xor op.

arg: value/vector to shuffle. offset: xor offset. width: xor width.

arg: fx.Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

mode: ShuffleMode
offset: int
tkw_op_name: str = 'shuffle'
width: int
class wave_lang.kernel.ops.wave_ops.Sin(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sin'
class wave_lang.kernel.ops.wave_ops.Sinh(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sinh'
class wave_lang.kernel.ops.wave_ops.Softsign(arg: Node, logit_cap: float = 30.0, apply_scaling: bool = False, head_dim: int = None)

NewSubclass(arg: ‘fx.Node’, logit_cap: ‘float’ = 30.0, apply_scaling: ‘bool’ = False, head_dim: ‘int’ = None)

tkw_op_name: str = 'softsign'
class wave_lang.kernel.ops.wave_ops.SoftsignOp(arg: 'fx.Node', logit_cap: 'float' = 30.0, apply_scaling: 'bool' = False, head_dim: 'int' = None)
apply_scaling: bool = False
arg: Node
head_dim: int = None
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

logit_cap: float = 30.0
class wave_lang.kernel.ops.wave_ops.Sqrt(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'sqrt'
class wave_lang.kernel.ops.wave_ops.Sub(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'sub'
class wave_lang.kernel.ops.wave_ops.Sum(arg: Node | list[Node], init: Node = None, dim: Any | None = None, block: bool | None = False)

NewSubclass(arg: ‘fx.Node | list[fx.Node]’, init: ‘fx.Node’ = None, dim: ‘Optional[Any]’ = None, block: ‘Optional[bool]’ = False)

tkw_op_name: str = 'sum'
class wave_lang.kernel.ops.wave_ops.Tanh(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'tanh'
class wave_lang.kernel.ops.wave_ops.TanhApprox(arg: Node)

NewSubclass(arg: ‘fx.Node’)

tkw_op_name: str = 'tanh_approx'
class wave_lang.kernel.ops.wave_ops.TensorLoadToLDS(src: 'Memory', dst: 'Memory', element_type: 'DataType', distributed_shape: 'list[IndexExpr]', shared_tile_index: 'dict[IndexSymbol, IndexSequence]', global_tile_index: 'dict[IndexSymbol, IndexSequence]', bounds: 'dict[IndexSymbol, IndexExpr]', multicast_mask: 'Optional[IndexExpr]' = None)
bounds: dict[Symbol, Expr]
distributed_shape: list[Expr]
dst: Memory
element_type: DataType
global_tile_index: dict[Symbol, IndexSequence]
multicast_mask: Expr | None = None
shared_tile_index: dict[Symbol, IndexSequence]
src: Memory
tkw_op_name: str = 'tensor_load_to_lds'
class wave_lang.kernel.ops.wave_ops.TopkOp(arg: Node, k_dim: Symbol, dim: Symbol)

Represents a TopK computation.

arg: source tensor Register value to find the top-k elements from. k_dim: dimension symbol representing the K size (the number of top elements to select). dim: the symbolic dimension to perform top-k on; this dimension is replaced by k_dim in the output.

arg: Node
dim: Symbol
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property init
k_dim: Symbol
property reduction_dim: Symbol
tkw_op_name: str = 'topk'
class wave_lang.kernel.ops.wave_ops.Truediv(lhs: Any, rhs: Any)

NewSubclass(lhs: ‘Any’, rhs: ‘Any’)

tkw_op_name: str = 'truediv'
class wave_lang.kernel.ops.wave_ops.UnaryPyOp(arg: Node)

Represents a unary python operator.

arg: Node
property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

property py_operator: str
final class wave_lang.kernel.ops.wave_ops.Unknown(args: Sequence[Any], kwargs: dict[Any, Any])

Represents an fx.Node that has no corresponding CustomNode class.

args: Sequence[Any]
custom_string(value_map: dict[str, str]) str
classmethod from_fx_node(node: Node) Unknown
kwargs: dict[Any, Any]
class wave_lang.kernel.ops.wave_ops.WorkgroupBarrier

Represents a synchronization of all threads in a workgroup. Each thread waits at a WorkgroupBarrier until all threads in the workgroup have called a WorkgroupBarrier (not necessarily at the same location).

property has_side_effects: bool
tkw_op_name: str = 'workgroup_barrier'
class wave_lang.kernel.ops.wave_ops.Write(register_: 'fx.Proxy', memory: 'fx.Proxy', elements_per_thread: 'Optional[Any]' = None, mapping: 'Optional[IndexMapping]' = None, mapping_dynamic_vals: 'tuple[fx.Node, ...]' = (), bounds: 'Optional[dict[IndexSymbol, IndexExpr]]' = None, source: 'Optional[tuple[IndexExpr]]' = None, target: 'Optional[tuple[IndexExpr]]' = None)
bounds: dict[Symbol, Expr] | None = None
elements_per_thread: Any | None = None
get_derived_indices() list[tuple[dict[Symbol, IndexSequence], Node]]
has_identity_mapping() bool

Check if mapping between input register and output memory is identity.

property indexing_dims: list[Symbol]
infer_type(*args)

Infer the type of this operator using the types of its arguments.

is_contiguous_vec(constraints, target: str) bool

Check if op can be lowered to contiguous vector ops

If False, it has to be lowered to a gather instead.

mapping: IndexMapping | None = None
mapping_dynamic_vals: tuple[Node, ...] = ()
memory: Proxy
property memory_type: Memory
register_: Proxy
property register_index: dict[Symbol, IndexSequence]
property register_type: Register
source: tuple[Expr] | None = None
target: tuple[Expr] | None = None
tkw_op_name: str = 'write'
transform_index_backwards(index: dict[Symbol, IndexSequence], arg: Node) dict[Symbol, IndexSequence]

Propagate index backwards.

Dynamic values can potentially have a non-identity mapping, so the index must be updated when walking from the node to its dynamic-value arguments.

E.g. if the index is $idx and dynamic_val_mappings={N: j // ELEMS_PER_THREAD}, the resulting arg index will be $idx // ELEMS_PER_THREAD.

wave_lang.kernel.ops.wave_ops.abs(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.allocate(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.apply_expr(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atan2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atomic_add(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.atomic_min(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.bitcast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.bounds_check(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.broadcast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cast(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cbrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.conditional(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cos(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.cumsum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.debug_log(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.define_interface_op(op_name: str) Callable[[T], T]

Generate a new subclass for op handling, deriving from the base interface class. The generated subclass can be used to emit the op from the compiler/Python side by calling the generated subclass name.

The generated subclass name is the PascalCase form of the TKW op name. For example: “tkw.op_name” -> “OpName”, “tkw.exp2” -> “Exp2”.

wave_lang.kernel.ops.wave_ops.define_op(op_name: str) Callable[[T], T]
wave_lang.kernel.ops.wave_ops.define_py_op(py_op: Callable) Callable[[T], T]

Register Python internal operators as custom ops. This overloads Python operator-specific functions, such as __add__ of fx.Proxy, with a handler in order to control the tracing of the operator and map it to a dynamically created subclass of UnaryPyOp or BinaryPyOp.

wave_lang.kernel.ops.wave_ops.eq(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.exp(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.exp2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.extract(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.extract_slice(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.gather_to_lds(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.ge(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.get_custom(node: Node) CustomOp

Get the corresponding CustomOp for a given fx.Node.

wave_lang.kernel.ops.wave_ops.get_result(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.get_shape_from_bindings(constraints: list[Constraint], target: tuple[Expr]) list[Expr]
wave_lang.kernel.ops.wave_ops.gt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.has_same_custom_type(lhs_type: Memory, rhs_type: Memory) bool
wave_lang.kernel.ops.wave_ops.iterate(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.le(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.log10(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.log2(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.lt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.max(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.maximum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.memory_counter_wait(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.min(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.minimum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.mma(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.mod(lhs: Register, rhs: Register) Register
wave_lang.kernel.ops.wave_ops.ne(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.permute(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.powf(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.read(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.read_meets_hw_transpose_requirements(read: Read, constraints: list[Constraint], target: str) bool
wave_lang.kernel.ops.wave_ops.reciprocal(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.register(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.remf(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.reshape(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.round(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.roundeven(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.rsqrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scalar(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scaled_mma(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scatter_add(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scheduling_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.scheduling_group_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.select(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.self_index(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.set_symbol(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.set_wave_prio(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier_signal(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shared_memory_barrier_wait(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.shuffle(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sin(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sinh(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.softsign(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sqrt(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.sum(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tanh(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tanh_approx(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.tensor_load_to_lds(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.topk(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.workgroup_barrier(*args: Any, **kwargs: dict[str, Any])
wave_lang.kernel.ops.wave_ops.write(*args: Any, **kwargs: dict[str, Any])