Wave Codegen

class wave_lang.kernel.compiler.wave_codegen.WaveEmitter(root_sig: BoundKernelSignature, trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions, grid: list[Expr], kernel_name: str)

Emits a warp function as a func whose signature is derived from the graph module (gm).

OP_HANDLERS: ClassVar[dict[str, Callable[[WaveEmitter, Node], None]]] = {'abs': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'add': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'allocate': <function handle_allocate>, 'and_': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'apply_expr': <function handle_apply_expr>, 'atan2': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'atomic_add': <function handle_atomic_op.<locals>.decorator.<locals>.handle_generic_atomic>, 'atomic_min': <function handle_atomic_op.<locals>.decorator.<locals>.handle_generic_atomic>, 'bitcast': <function handle_bitcast>, 'bounds_check': <function handle_bounds_check>, 'broadcast': <function handle_broadcast>, 'cast': <function handle_cast>, 'cbrt': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'conditional': <function handle_conditional>, 'cos': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'eq': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'exp': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'exp2': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'extract': <function handle_extract>, 'extract_slice': <function handle_extract_slice>, 'floordiv': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'gather_to_lds': <function handle_gather_to_lds>, 'ge': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'get_result': <function handle_get_result>, 'gt': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'invert': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'iterate': <function handle_iterate>, 'le': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'log10': <function 
handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'log2': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'lt': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'maximum': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'memory_counter_wait': <function handle_memory_counter_wait>, 'memory_counter_wait_barrier': <function handle_memory_counter_wait_barrier>, 'minimum': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'mma': <function handle_mma>, 'mod': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'mul': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'ne': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'neg': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'or_': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'permute': <function handle_permute>, 'powf': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'read': <function handle_read>, 'reciprocal': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'register': <function handle_register>, 'remf': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'reshape': <function handle_reshape>, 'round': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'roundeven': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'rsqrt': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'scalar': <function handle_scalar>, 'scaled_mma': <function handle_scaled_mma>, 'scatter_add': <function handle_scatter_add>, 'scheduling_barrier': <function handle_scheduling_barrier>, 'scheduling_group_barrier': <function handle_scheduling_group_barrier>, 'select': <function handle_select>, 'self_index': 
<function handle_self_index>, 'set_symbol': <function handle_set_symbol>, 'set_wave_prio': <function handle_set_wave_prio>, 'shared_memory_barrier': <function handle_shared_memory_barrier>, 'shared_memory_barrier_signal': <function handle_shared_memory_barrier_signal>, 'shared_memory_barrier_wait': <function handle_shared_memory_barrier_wait>, 'shuffle': <function handle_shuffle>, 'sin': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'sinh': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'softsign': <function handle_softsign>, 'sqrt': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'sub': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'tanh': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'tanh_approx': <function handle_unary_op.<locals>.decorator.<locals>.handle_generic_unary>, 'tensor_counter_wait': <function handle_tensor_counter_wait>, 'tensor_load_to_lds': <function handle_tensor_load_to_lds>, 'truediv': <function handle_binary_op.<locals>.decorator.<locals>.handle_generic_binary>, 'workgroup_barrier': <function handle_workgroup_barrier>, 'write': <function handle_write>}
bind_node_proxies(node: Node, proxies: List[IRProxyValue])
bind_node_proxy(node: Node, proxy: IRProxyValue)

Binds a node’s result to a Python/IR proxy object.

constraints: list[Constraint]
emit(graph: Graph | None = None) → Operation
emit_func() → Operation
emit_host_func(kernel_func: Operation) → Operation
emit_program_invariants()
get_induction_vars_and_syms() → tuple[list[OpResult], list[Expr]]
grid: list[Expr]
property hardware_constraint: HardwareConstraint
kernel_name: str
lookup_node_values(node: Node) → List[Value]
options: WaveCompileOptions
root_sig: BoundKernelSignature
trace: CapturedTrace