TKW (turbine.kernel.wave)

class wave_lang.kernel.wave.wave.LaunchableWave(constraints: list[Constraint] | None, name: str, eager_function: Callable[[Any], Any])
aot_execute(args, kwargs)
build_initial_pass_pipeline(trace: CapturedTrace, options: WaveCompileOptions, debug_arg_info: list[DebugArgInfo], debug_handlers: list[Any], print_ir_before: Sequence[str] = [], print_ir_after: Sequence[str] = [])
compile_to_mlir(trace: CapturedTrace, context: Context, module_op: Module | None = None, options: WaveCompileOptions = None)
create_induction_vars(trace: CapturedTrace) → None

Creates induction variables for all the reductions in the graph and associates the tiling constraints of all reduction dimensions with the appropriate induction variables.
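The following is a rough, self-contained illustration of the idea and is not the library's code. It gives each reduction its own induction variable and attaches that variable to every tiling constraint on the reduction's axis. The `Reduction` and `TilingConstraint` dataclasses and the `$ARG_<axis>` naming scheme are hypothetical stand-ins; the real classes carry more state.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TilingConstraint:
    # Hypothetical stand-in: the tiled dimension, its tile size, and the
    # induction variable that will index it (filled in below).
    dim: str
    tile_size: int
    induction_var: Optional[str] = None


@dataclass
class Reduction:
    # Hypothetical stand-in for a reduction loop over a single axis.
    axis: str


def create_induction_vars(reductions, tiling_constraints):
    """Create one induction variable per reduction and attach it to every
    tiling constraint that tiles the same axis."""
    induction_vars = {}
    for reduction in reductions:
        # Hypothetical naming scheme for the loop variable.
        var = f"$ARG_{reduction.axis}"
        induction_vars[reduction.axis] = var
        for constraint in tiling_constraints:
            if constraint.dim == reduction.axis:
                constraint.induction_var = var
    return induction_vars


tiling = [TilingConstraint("K", 32), TilingConstraint("N", 64)]
ivars = create_induction_vars([Reduction("K")], tiling)
print(ivars)                    # {'K': '$ARG_K'}
print(tiling[0].induction_var)  # $ARG_K
print(tiling[1].induction_var)  # None
```

The constraint on N is not a reduction axis, so it keeps no induction variable.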

property device_constraints: list[DeviceConstraint]
eager_execute(args, kwargs)
get_workgroup_dims() → list[int]

Returns the workgroup dimensions that are not aliased.
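The filtering can be sketched as follows, assuming that each workgroup constraint records the grid axis it is distributed along and that aliased dimensions are known up front. The `WorkgroupConstraint` stand-in, the `aliased_dims` argument, and the alias `N2` are hypothetical; the real method reads this information from the kernel's constraints.

```python
from dataclasses import dataclass


@dataclass
class WorkgroupConstraint:
    # Hypothetical stand-in: symbolic dimension, tile size, grid axis (0=x, 1=y, 2=z).
    dim: str
    tile_size: int
    workgroup_dim: int


def get_workgroup_dims(workgroup_constraints, aliased_dims):
    """Return the sorted grid axes claimed by constraints on non-aliased dimensions."""
    return sorted(
        {c.workgroup_dim for c in workgroup_constraints if c.dim not in aliased_dims}
    )


constraints = [
    WorkgroupConstraint("M", 64, 0),
    WorkgroupConstraint("N", 64, 1),
    # Hypothetical alias of N whose grid axis has not been resolved yet.
    WorkgroupConstraint("N2", 128, 2),
]
print(get_workgroup_dims(constraints, aliased_dims={"N2"}))  # [0, 1]
print(get_workgroup_dims(constraints, aliased_dims=set()))   # [0, 1, 2]
```

Skipping aliased constraints keeps a not-yet-resolved grid axis from being counted as a real one.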

property hardware_constraints: list[HardwareConstraint]
infer_device_layout(idxc: IndexingContext)
infer_grid_shape(idxc: IndexingContext)
initialize_reductions(trace: CapturedTrace) → None

For each reduction, initializes the reduction count by looking at the tiling constraints associated with the reduction.
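The following sketch assumes that the reduction count is the number of tiles needed to cover the reduction axis, that is, ceil(size / tile_size). The real method may keep the count symbolic or require the size to divide evenly. `TilingConstraint` and `reduction_count` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass
class TilingConstraint:
    # Hypothetical stand-in: the tiled dimension and its tile size.
    dim: str
    tile_size: int


def reduction_count(axis, dim_sizes, tiling_constraints):
    """Iterations needed to cover `axis`: ceil(size / tile_size)."""
    for constraint in tiling_constraints:
        if constraint.dim == axis:
            size = dim_sizes[axis]
            # Integer ceiling division avoids floating-point rounding.
            return (size + constraint.tile_size - 1) // constraint.tile_size
    raise ValueError(f"no tiling constraint for reduction axis {axis!r}")


tiling = [TilingConstraint("K", 32)]
print(reduction_count("K", {"K": 1024}, tiling))  # 32
print(reduction_count("K", {"K": 1000}, tiling))  # 32 (last tile is partial)
```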

initialize_symbolic_constraints() → None

For each symbolic constraint, creates new constraints for the related symbolic values with the appropriate substitutions.
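This is a minimal sketch of "new constraints with appropriate substitutions". It assumes a symbolic constraint defines a target dimension as a function of a source dimension, for example N2 = 2 * N. Under that assumption, each constraint on the source is cloned onto the target with its tile size mapped through the same function. `SymbolicAlias`, its fields, and the `WorkgroupConstraint` stand-in are hypothetical.

```python
from dataclasses import dataclass, replace
from typing import Callable


@dataclass(frozen=True)
class WorkgroupConstraint:
    # Hypothetical stand-in: dimension, tile size, grid axis.
    dim: str
    tile_size: int
    workgroup_dim: int


@dataclass(frozen=True)
class SymbolicAlias:
    # Hypothetical: `target` is defined as a function of `source`.
    source: str
    target: str
    source_to_target: Callable[[int], int]


def initialize_symbolic_constraints(constraints, aliases):
    """For every alias, clone each constraint on its source dimension onto the
    target dimension, substituting the mapped tile size."""
    derived = []
    for alias in aliases:
        for c in constraints:
            if c.dim == alias.source:
                derived.append(
                    replace(c, dim=alias.target, tile_size=alias.source_to_target(c.tile_size))
                )
    return constraints + derived


base = [WorkgroupConstraint("M", 64, 0), WorkgroupConstraint("N", 64, 1)]
alias = SymbolicAlias("N", "N2", lambda n: 2 * n)
for c in initialize_symbolic_constraints(base, [alias]):
    print(c)
# The last line printed is WorkgroupConstraint(dim='N2', tile_size=128, workgroup_dim=1)
```

The derived constraint inherits the source's grid axis, which is why aliased constraints share a workgroup dimension with the constraint they were derived from.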

initialize_wave_constraints() → None

For each wave constraint, determines the appropriate wave id by looking for workgroup constraints along the same dimension and using information from the hardware constraints.
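The lookup can be sketched under an assumed thread layout. Along x, a wave spans threads_per_wave consecutive lanes. Along y and z, each thread index is its own wave. The grid axis for a dimension comes from the workgroup constraint on that dimension. `HardwareConstraint`, `wave_id`, and the layout itself are assumptions for illustration, not the library's definitions.

```python
from dataclasses import dataclass


@dataclass
class HardwareConstraint:
    # Hypothetical stand-in: lanes per wave and waves per workgroup along x, y, z.
    threads_per_wave: int
    waves_per_block: tuple


def wave_id(wg_dim, thread_id, hw):
    """Wave index along grid axis `wg_dim` for a thread with local id `thread_id`.

    Assumed layout: along x a wave spans `threads_per_wave` consecutive lanes,
    while along y and z each thread index is its own wave.
    """
    block = (
        hw.waves_per_block[0] * hw.threads_per_wave,
        hw.waves_per_block[1],
        hw.waves_per_block[2],
    )
    if not all(0 <= t < b for t, b in zip(thread_id, block)):
        raise ValueError(f"thread id {thread_id} outside block {block}")
    if wg_dim == 0:
        return thread_id[0] // hw.threads_per_wave
    return thread_id[wg_dim]


def initialize_wave_constraints(wave_dims, workgroup_dim_of, thread_id, hw):
    # Look up the grid axis from the workgroup constraint on the same
    # dimension, then derive the wave id along that axis.
    return {d: wave_id(workgroup_dim_of[d], thread_id, hw) for d in wave_dims}


hw = HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
axes = {"M": 0, "N": 1}  # from workgroup constraints: M on x, N on y
print(initialize_wave_constraints(["M", "N"], axes, (70, 1, 0), hw))  # {'M': 1, 'N': 1}
print(initialize_wave_constraints(["M", "N"], axes, (5, 0, 0), hw))   # {'M': 0, 'N': 0}
```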

initialize_workgroup_constraints() → None

For kernels that distribute more than three dimensions among workgroups, we need to update the workgroup constraints for dimensions >= 2 with the appropriate workgroup index.
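A GPU launch grid has only three axes (x, y, z). One way to distribute more dimensions is to fold dimensions 2, 3, ... onto z and recover each dimension's workgroup index by delinearizing the z workgroup id. The sketch below shows that arithmetic. The function name and the "last dimension varies fastest" ordering are assumptions and may differ from what the library emits.

```python
def delinearize_workgroup_z(wg_z, counts):
    """Split a linear z workgroup id into one index per folded dimension.

    counts[i] is the number of workgroups along folded dimension i (grid
    dimensions 2, 3, ...); the last dimension is assumed to vary fastest.
    """
    indices = []
    for count in reversed(counts):
        indices.append(wg_z % count)
        wg_z //= count
    return indices[::-1]


# Dimensions 2 and 3 have 3 and 4 workgroups, so the z axis is launched with 12.
for wg_z in (0, 7, 11):
    print(wg_z, delinearize_workgroup_z(wg_z, [3, 4]))
# 0 [0, 0]
# 7 [1, 3]
# 11 [2, 3]
```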

property reordering_constraints: list[ReorderingConstraint]
run_manual_schedule(trace: CapturedTrace, constraints: list[Constraint], schedule: WaveSchedule, use_scheduling_barriers: bool = False)

Runs the manual schedule provided by the user.

property symbolic_constraints: list[HardwareConstraint]
property tiling_constraints: list[TilingConstraint]
update_aliased_workgroup_constraints(workgroup_dims: dict[int, int]) → None

Updates the wg_dim of the aliased workgroup constraints.
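This reference does not describe the exact meaning of the workgroup_dims argument. As a rough illustration only, the sketch below resolves each aliased constraint's wg_dim from the constraint on its source dimension. The `WorkgroupConstraint` stand-in and its `alias_of` field are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkgroupConstraint:
    # Hypothetical stand-in; `alias_of` names the source dimension of an alias.
    dim: str
    tile_size: int
    wg_dim: int
    alias_of: Optional[str] = None


def update_aliased_workgroup_constraints(constraints):
    """Give each aliased constraint the wg_dim of the constraint on its source dimension."""
    primary = {c.dim: c.wg_dim for c in constraints if c.alias_of is None}
    for c in constraints:
        if c.alias_of is not None:
            c.wg_dim = primary[c.alias_of]
    return constraints


constraints = [
    WorkgroupConstraint("M", 64, 0),
    WorkgroupConstraint("N", 64, 1),
    WorkgroupConstraint("N2", 128, 2, alias_of="N"),  # stale wg_dim
]
update_aliased_workgroup_constraints(constraints)
print([(c.dim, c.wg_dim) for c in constraints])  # [('M', 0), ('N', 1), ('N2', 1)]
```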

property wave_constraints: list[WaveConstraint]
property workgroup_constraints: list[WorkgroupConstraint]