wave_lang.runtime
- class wave_lang.runtime.Device(uri: str | None = None, *, device_state: DeviceState | None = None)
Represents a low-level device (HalDriver/HalDevice) and scheduling data.
This is the type that users interact with as a ‘Device’. Devices can be used on their own (loose-leaf) or bound to the current thread with a context manager.
- clear()
Clears the current device without using a context manager.
- compile_target_flags: tuple[str, ...]
- create_hal_module() → VmModule
- driver_id: str
- dump_device_info() → str
- export_torch_tensor: Callable[[HalBufferView, Tensor], Tensor]
- finalize_iree_action(external_timepoint: HalExternalTimepoint)
- get_type_key_hash(*, hasher: Callable[[str], str] = <function <lambda>>)
- property hal_device: HalDevice
- import_torch_tensor: Callable[[Tensor], HalBufferView]
- instance_cache_key: str
- setup_iree_action()
- property sync: bool
- type_cache_key: str
- property vm_instance: VmInstance
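The difference between a loose-leaf device and one bound to the current thread can be pictured with a small pure-Python model. `ToyDevice` and `current_device` are hypothetical stand-ins written for illustration, not wave_lang's implementation; only the idea of a per-thread current device that a context manager binds and `clear()` unbinds comes from the description above.

```python
import threading
from typing import Optional

# Hypothetical per-thread slot holding the "current" device (illustration only).
_local = threading.local()


def current_device() -> Optional["ToyDevice"]:
    """Returns the device bound to the calling thread, or None."""
    return getattr(_local, "device", None)


class ToyDevice:
    """Hypothetical stand-in for a device that can be bound to a thread."""

    def __init__(self, uri: str):
        self.uri = uri
        self._previous: list = []  # stack of earlier bindings, for nesting

    def __enter__(self) -> "ToyDevice":
        # Remember whatever was bound before so that exiting restores it.
        self._previous.append(current_device())
        _local.device = self
        return self

    def __exit__(self, *exc_info) -> None:
        _local.device = self._previous.pop()

    def clear(self) -> None:
        # Unbind the thread's current device without leaving a `with` block.
        _local.device = None


# Loose-leaf: constructing a device does not bind it to the thread.
outer, inner = ToyDevice("toy://0"), ToyDevice("toy://1")
with outer:
    with inner:
        seen_nested = current_device()      # inner is current
    seen_after_inner = current_device()     # outer is current again
    outer.clear()
    seen_after_clear = current_device()     # nothing is current
seen_after_outer = current_device()         # binding restored to None
```

Because the binding lives in `threading.local`, a device entered on one thread is not visible as current on another thread in this model.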
- class wave_lang.runtime.DeviceState(*, driver: str | HalDriver, device: HalDevice | None = None, vm_instance: VmInstance | None = None, enumerated_info: dict | None = None, torch_device: device | None = None, torch_stream: int | None = None, dlpack_device_type_code: int = 0)
State for an instantiated HAL device.
Note that the IREE runtime internally manages a global cache of drivers for standard named-access (not custom-constructed) drivers.
- device
- dlpack_device_type_code
- driver
- property enumerated_device_id: int
- enumerated_info
- property enumerated_name: str
- property enumerated_path: str
- static from_uri(uri: str) → DeviceState
- instance
- torch_device
- torch_stream
- class wave_lang.runtime.Launchable(loader: Callable[[Device], Tuple[str, VmModule]] | None, parameter_providers: Sequence[ParameterProvider] = (), is_async: bool = True)
Facilities for launching a compiled program (VMFB) on an attached device.
Like the eager custom-op executor, this follows the usual PyTorch rule that the device on which the input tensors reside dictates where the launch happens. Unlike that flow, this does not include any notion of jitting or caching. It also provides APIs for working with parameters, among other things.
You must manage all compilation and target settings yourself; you merely assert that a given binary is appropriate for launch on a device type. This has various limitations.
- static from_file_cache_only(file_cache_dir: str | Path, *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async') → Launchable
Loads VMFBs only from the provided file_cache_dir and raises an error if none is found.
- static from_vm_module(vm_module_callback: Callable[[Device], VmModule], *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async')
- static jit_compile(source: Any, *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async', file_cache_dir: str | Path | None = None) → Launchable
Generates a Launchable from program source (e.g., an MLIR string). Set file_cache_dir to store compiled artifacts and reuse them across sessions.
- preload(device: device)
Preloads (or JIT-compiles) the program for the given torch.device.
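The contrast between `jit_compile(..., file_cache_dir=...)` (reuse a cached artifact or compile and store one) and `from_file_cache_only` (load only, error if missing) can be modeled as a tiny on-disk cache. This is a hypothetical sketch: the `device_key` naming, the `.vmfb` file layout, and the `compile_fn` callback are assumptions for illustration, not wave_lang's actual cache format.

```python
import tempfile
from pathlib import Path
from typing import Callable


def _artifact_path(cache_dir: Path, device_key: str) -> Path:
    # Assumed layout: one artifact per device key (hypothetical naming scheme).
    return cache_dir / f"{device_key}.vmfb"


def load_or_compile(device_key: str, source: str, cache_dir: Path,
                    compile_fn: Callable[[str], bytes]) -> bytes:
    """jit_compile-style: reuse a cached artifact, otherwise compile and store it."""
    path = _artifact_path(cache_dir, device_key)
    if path.exists():
        return path.read_bytes()
    artifact = compile_fn(source)
    cache_dir.mkdir(parents=True, exist_ok=True)
    path.write_bytes(artifact)
    return artifact


def load_cache_only(device_key: str, cache_dir: Path) -> bytes:
    """from_file_cache_only-style: never compile; fail if the artifact is missing."""
    path = _artifact_path(cache_dir, device_key)
    if not path.exists():
        raise FileNotFoundError(f"no cached artifact at {path}")
    return path.read_bytes()


compiled: list = []


def fake_compile(source: str) -> bytes:
    # Stand-in for a real compiler invocation; records each call.
    compiled.append(source)
    return b"vmfb:" + source.encode("utf-8")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    first = load_or_compile("toy-device", "module {}", cache, fake_compile)
    second = load_or_compile("toy-device", "module {}", cache, fake_compile)  # cache hit
    reloaded = load_cache_only("toy-device", cache)
```

The cache-only path trades flexibility for predictability: it never invokes a compiler, so a missing artifact surfaces immediately as an error.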
- wave_lang.runtime.get_vm_instance() → VmInstance
- wave_lang.runtime.invoke_vm_function(device: Device, is_async: bool, vm_context: VmContext, vm_function: VmFunction, arg_list: VmVariantList, ret_list: VmVariantList, *, timer: Callable[[], float] = <function <lambda>>)
Invokes a VM function on a device, adding async fences to arg_list if is_async is True.
No checks are made to ensure that the provided device and vm_function are compatible. A timer function returning a float may be provided; invoke_vm_function returns the invocation time as measured by that timer.
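The timer contract (a zero-argument callable returning a float, with the elapsed invocation time returned to the caller) can be sketched in plain Python. `timed_call` is a hypothetical helper; it does not reproduce the async fence handling of invoke_vm_function, and using `time.perf_counter` as its default is an assumption.

```python
import time
from typing import Callable


def timed_call(fn: Callable[[], None], *,
               timer: Callable[[], float] = time.perf_counter) -> float:
    """Runs fn once and returns the elapsed time as measured by timer."""
    start = timer()
    fn()
    return timer() - start


# A fake clock makes the measurement deterministic: two readings, 0.25 apart.
ticks = iter([10.0, 10.25])
elapsed = timed_call(lambda: None, timer=lambda: next(ticks))
```

Injecting the clock keeps the timing logic testable and lets callers choose a different time source without changing the invocation code.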
op_reg¶
- class wave_lang.runtime.op_reg.AttrArg(v: object)
- generate_meta() → object
- ir_arity: int = 0
- is_list: bool = False
- maybe_tensor_value: Tensor | None = None
- property mlir_type_asm: str
- property spec_key: str
Generates a key that will be the same for all specializations.
- spec_value: Any | None
- v
- class wave_lang.runtime.op_reg.CustomOp(*, library: Library, dispatch_key: str | Sequence[str] | None, register_meta: bool, register_impl: bool)
Users subclass this to register a Wave custom op.
- eager_execute(*args)
When executing eagerly, allows the CustomOp to provide a direct Python implementation. In AOT/graph modes, this is not called.
If the method returns NotImplemented, a standalone kernel is compiled and executed instead.
This is commonly used for ops that do no meaningful work as a single op execution in the PyTorch runtime (e.g. metadata ops), but it could in principle implement any desired Python equivalent.
- abstract generate(ksel: KernelSelection, kb: KernelBuilder)
Generates a kernel based on the KernelSelection.
This method should generate IR into the given KernelBuilder, consulting any state set on the KernelSelection. Each entry of KernelSelection.arg_descs corresponds to an entry of KernelBuilder.arg_bindings: unless the argument was declared with ir_arity=0, the binding is a Value (or a list of Values); otherwise, it is None. It is recommended to access arguments with KernelBuilder.arg_value(n).
Generation should conclude with a call to KernelBuilder.yield_results.
- static register(op_class: Type[CustomOp] | None = None, *, library: Library = Library(kind=DEF, ns=wave_lang, dispatch_key=)>, dispatch_key: str | Sequence[str] | None = None, register_meta: bool = True, register_impl: bool = True) → Callable
Class decorator for CustomOp implementations.
The decorator instantiates the class and then replaces it with a callable operation that can be used to invoke the kernel.
Typical usage:
```python
@CustomOp.register
class identity(CustomOp):
    …
```
- abstract select(sel: KernelSelection)
Performs kernel selection.
This method has three purposes:
1. Selects which kernel specialization is needed based on the arguments.
2. Returns the meta tensor results of the operation, effectively completing the transfer function from argument types to result types.
3. Sets additional metadata that the generate method can use.
The device="meta" kernel implementation consists entirely of invoking select. For implementation devices, select is called on every invocation, and generate is then called only if the kernel needs to be generated.
- abstract property signature: str
PyTorch function signature.
This is in the normal PyTorch kernel registration form. For example:
`my_op(Tensor t) -> Tensor`
The signature can contain special tokens in the name part:
"@UNIQUE@": replaced with a generated, name-specific numeric value.
- property single_dispatch: bool
Indicates whether the CustomOp should be forced into a single dispatch using a util.func pipeline attribute.
This is recommended only for more complicated ops that would not otherwise be compiled into a single dispatch, e.g. a fused conv + bias-add + relu custom op.
For eager contexts, this applies the pipeline attribute to the main$async function.
For AOT contexts, this currently does nothing, but it could eventually apply a util.inline.never attribute, in addition to the pipeline attribute, to the function called by the InlineKernelBuilder.
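The control flow described for CustomOp (try eager_execute first and fall back to a compiled kernel on NotImplemented; run select on every call; run generate only when the selected specialization has not been built yet) can be modeled in plain Python. Every name here (`ToyOp`, `ShapeOnlyOp`, `ToyRuntime`, the string specialization key, and the "kernel" that is just a Python callable) is hypothetical; the model also simplifies select, which in the real API additionally returns meta results and records metadata.

```python
from typing import Any, Callable


class ToyOp:
    """Hypothetical op: select() picks a specialization, generate() builds a kernel."""

    def __init__(self):
        self.generated: list = []

    def eager_execute(self, *args: Any) -> Any:
        # Returning NotImplemented requests the compiled-kernel path.
        return NotImplemented

    def select(self, *args: Any) -> str:
        # Specialize on each argument's type and length (a stand-in for rank/dtype).
        return ",".join(f"{type(a).__name__}[{len(a)}]" for a in args)

    def generate(self, spec_key: str) -> Callable[..., Any]:
        self.generated.append(spec_key)
        return lambda *args: [x * 2 for x in args[0]]


class ShapeOnlyOp(ToyOp):
    """A metadata-style op answered directly in Python; no kernel is generated."""

    def eager_execute(self, *args: Any) -> Any:
        return len(args[0])


class ToyRuntime:
    def __init__(self):
        self.kernels: dict = {}

    def invoke(self, op: ToyOp, *args: Any) -> Any:
        result = op.eager_execute(*args)
        if result is not NotImplemented:
            return result                 # the Python implementation handled it
        key = op.select(*args)            # runs on every invocation
        if key not in self.kernels:       # generate only on a cache miss
            self.kernels[key] = op.generate(key)
        return self.kernels[key](*args)


op, meta, rt = ToyOp(), ShapeOnlyOp(), ToyRuntime()
first = rt.invoke(op, [1, 2, 3])   # generates the "list[3]" specialization
rt.invoke(op, [4, 5, 6])           # same specialization: reuses the kernel
rt.invoke(op, [1, 2])              # new length: generates "list[2]"
```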
- class wave_lang.runtime.op_reg.FreeFuncKernelBuilder(ksel: KernelSelection, *, module_body: Block, symbol_table: SymbolTable, func_name: str | None = None, is_public: bool = True)
Kernel builder that emits the body of the kernel into a free function.
This is intended for compiling a standalone module that will be invoked directly by the runtime. Further variants exist that generate into a func but also emit a call into another local context.
- static create_module(ksel: KernelSelection, *, context: Context | None = None, func_name: str | None = None, is_public: bool = True) → FreeFuncKernelBuilder
Shortcut to create a new module with a single function in one step.
- yield_results(*results: Value)
Yields results of the kernel computation.
- class wave_lang.runtime.op_reg.IntArg(v: int)
- generate_meta() → int
- ir_arity: int
- is_list: bool = False
- maybe_tensor_value: Tensor | None = None
- property mlir_type_asm: str
- property spec_key: str
Generates a key that will be the same for all specializations.
- spec_value: Any | None
- v
- class wave_lang.runtime.op_reg.KernelBuilder(ksel: KernelSelection, arg_bindings: list[Value | list[Value]], *, ip: InsertionPoint, module_body: Block, symbol_table: SymbolTable)
Support class for building a kernel.
- arg_value(index: int) → list[Value] | Value
Gets the concrete IR Value for the argument at index.
This will assert if the corresponding argument was set as ir_arity=0 during kernel selection.
- constant_index(i: int) → Value
Builds a constant index value.
- abstract yield_results(*results: Value)
Yields results of the kernel computation.
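The relationship between argument descriptors and IR bindings, where arguments declared with ir_arity=0 have no IR value, can be pictured with a small model. `ToyArgDesc`, `bind_args`, `arg_value`, and the string "values" such as `%x` are hypothetical illustrations, not wave_lang types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyArgDesc:
    name: str
    ir_arity: int  # 0: used only while generating IR; 1: materialized as an IR value


def bind_args(descs: list) -> list:
    # One binding per descriptor; ir_arity=0 arguments get None instead of a value.
    return [f"%{d.name}" if d.ir_arity > 0 else None for d in descs]


def arg_value(bindings: list, index: int) -> str:
    value = bindings[index]
    if value is None:
        # Mirrors the documented assertion for ir_arity=0 arguments.
        raise AssertionError(f"argument {index} was declared with ir_arity=0")
    return value


descs = [ToyArgDesc("x", 1), ToyArgDesc("alpha", 0), ToyArgDesc("y", 1)]
bindings = bind_args(descs)
```

Keeping one binding slot per descriptor means argument indices line up between selection and generation, even for arguments that never become IR values.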
- class wave_lang.runtime.op_reg.KernelSelection(op: CustomOp, arg_arity: int)
Represents a selected kernel based on a concrete signature.
CustomOp.select must produce an instance of this, and this happens on every invocation. At that point, the kernel has not yet been generated, but a generation strategy has been selected based on a concrete signature.
KernelSelection implements a strategy pattern for reading concrete values from inputs to match the kernel’s expected signature. This enables the same Python-based MLIR generation logic to work with KernelSelection instances created in different contexts: eager execution (reads directly from PyTorch tensors/values) and AOT compilation (converts MLIR types back to PyTorch tensors first).
This mechanism also services meta registrations, because it implicitly computes everything needed (e.g. shapes).
- arg_descs
- abstract arg_int(arg: int) → IntArg
Declares an argument to be an integer value that can take any value.
Returns the argument descriptor, which can be used to further inspect or constrain the selection.
- abstract arg_optional_tensor(arg: int) → TensorArg | None
Declares an optional tensor argument.
Returns None if the argument was not provided at the call site, and a TensorArg if it was.
- abstract arg_tensor(arg: int, *, inplace_tied: bool = False) → TensorArg
Declares an argument to allow any ranked tensor and to specialize for each rank and dtype.
Returns the argument descriptor, which can be used to further inspect or constrain the selection. It defaults to allowing all dimensions to be dynamic.
If inplace_tied is True, this argument participates in in-place semantics: the kernel must yield the mutated result after all normal results, in the order in which the tied arguments were declared.
- abstract arg_tensor_list(arg: int) → TensorListArg
Declares an argument to accept a list of tensors, which will be specialized for the list size and for each rank/dtype.
Returns the argument descriptor, which can be used to further inspect or constrain the selection. It defaults to allowing all dimensions to be dynamic.
- abstract attr_float(arg: int) → AttrArg
Declares an argument to be a float attribute.
Such arguments are not materialized in the IR as Values but may be used to generate the IR. In AOT contexts, they must be derived from static values.
- abstract attr_int(arg: int) → AttrArg
Declares an argument to be an integer attribute.
Such arguments are not materialized in the IR as Values but may be used to generate the IR. In AOT contexts, they must be derived from static values.
- abstract attr_list_float(arg: int) → AttrArg
Declares an argument to be a list<float> attribute.
Such arguments are not materialized in the IR as Values but may be used to generate the IR. In AOT contexts, they must be derived from static values.
- abstract attr_list_int(arg: int) → AttrArg
Declares an argument to be a list<integer> attribute.
Such arguments are not materialized in the IR as Values but may be used to generate the IR. In AOT contexts, they must be derived from static values.
- abstract attr_str(arg: int) → AttrArg
Declares an argument to be a string attribute.
Such arguments are not materialized in the IR as Values but may be used to generate the IR. In AOT contexts, they must be derived from static values.
- generate_meta_returns() → Any
- get_provided_arg_descs() → list[AttrArg | IntArg | TensorArg | TensorListArg | EmptyOptionalTensorArg | None]
Returns argument descriptors with empty optional arguments filtered out.
- inplace_tied_arg_descs: list[AttrArg | IntArg | TensorArg | TensorListArg | EmptyOptionalTensorArg]
- op
- return_new_tensor(size: list, dtype: dtype) → TensorArg
Constructs a new symbolic tensor and marks the next result as returning it.
This delegates to return_tensor but takes care of some easy-to-get-wrong boilerplate for dynamic shapes.
- abstract return_tensor(t: Tensor) → TensorArg
Marks the next return value as a Tensor.
By default, it will be rank and dtype specialized but have completely dynamic dimensions. Dimensions can be further constrained by modifying the returned descriptor.
- property spec_key: str
- variant: str
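The strategy pattern described for KernelSelection can be sketched as one shared piece of generation logic fed by two interchangeable readers: one that inspects concrete values (the eager case) and one that parses an MLIR-style type string (the AOT case). The classes, the `describe` function, and the simplified `tensor<2x3xf32>` parsing are hypothetical; real selections work with PyTorch tensors and MLIR types, and the AOT path converts to tensors rather than parsing strings.

```python
from abc import ABC, abstractmethod


class ToySelection(ABC):
    """Hypothetical strategy interface: report an argument's shape and dtype."""

    @abstractmethod
    def arg_tensor(self, arg: int) -> tuple:
        ...


class EagerToySelection(ToySelection):
    # Eager: read directly from concrete values, here (shape, dtype) pairs.
    def __init__(self, values: list):
        self.values = values

    def arg_tensor(self, arg: int) -> tuple:
        return self.values[arg]


class AotToySelection(ToySelection):
    # AOT: convert type strings such as "tensor<2x3xf32>" back into shape/dtype.
    def __init__(self, types: list):
        self.types = types

    def arg_tensor(self, arg: int) -> tuple:
        body = self.types[arg].removeprefix("tensor<").removesuffix(">")
        *dims, dtype = body.split("x")
        return tuple(int(d) for d in dims), dtype


def describe(sel: ToySelection, n_args: int) -> str:
    """Shared 'generation' logic that works with either strategy."""
    parts = []
    for i in range(n_args):
        shape, dtype = sel.arg_tensor(i)
        parts.append("x".join(str(d) for d in shape) + f":{dtype}")
    return ", ".join(parts)


eager = EagerToySelection([((2, 3), "f32"), ((4,), "i64")])
aot = AotToySelection(["tensor<2x3xf32>", "tensor<4xi64>"])
```

Because `describe` only talks to the `ToySelection` interface, it produces identical output regardless of which context constructed the selection.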
- class wave_lang.runtime.op_reg.TensorArg(t: Tensor)
- generate_meta() → Tensor
- ir_arity: int = 1
- is_list: bool = False
- maybe_tensor_value: Tensor
- property mlir_type_asm: str
- spec_dims
- property spec_key: str
Generates a key that will be the same for all specializations.
- specialize_all_dims()
Marks all dimensions as specialized.
- specialize_dims(*indices: int)
Specializes individual dimensions.
Indices may be negative, counting back from the last dimension.
- t
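Dimension specialization can be modeled as a per-dimension list in which None means dynamic. `ToyTensorArg` is a hypothetical stand-in for TensorArg: it assumes spec_dims starts fully dynamic (consistent with the defaults described for arg_tensor and return_tensor) and renders an MLIR-style tensor type with `?` for dynamic dimensions; the exact format of wave_lang's mlir_type_asm and spec_key may differ.

```python
from typing import Optional


class ToyTensorArg:
    def __init__(self, shape: tuple, dtype: str):
        self.shape = shape
        self.dtype = dtype
        # None marks a dynamic dimension; every dimension starts dynamic.
        self.spec_dims: list = [None] * len(shape)

    def specialize_all_dims(self) -> None:
        self.spec_dims = list(self.shape)

    def specialize_dims(self, *indices: int) -> None:
        for i in indices:
            # Python-style negative indices count back from the last dimension.
            self.spec_dims[i] = self.shape[i]

    @property
    def mlir_type_asm(self) -> str:
        dims = "x".join("?" if d is None else str(d) for d in self.spec_dims)
        return f"tensor<{dims}x{self.dtype}>" if dims else f"tensor<{self.dtype}>"


t = ToyTensorArg((8, 16, 32), "f32")
t.specialize_dims(0, -1)  # pin the first and last dimensions; the middle stays dynamic
```

Specializing more dimensions yields more distinct kernel variants but lets each one be compiled with static shapes.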
- wave_lang.runtime.op_reg.def_library(ns) → Library
Creates a new ‘DEF’ library that contains custom ops.
Custom op libraries must be created this way because the library is registered with the compiler in a way that lets it operate over all known custom ops.