wave_lang.runtime

class wave_lang.runtime.Device(uri: str | None = None, *, device_state: DeviceState | None = None)

Represents a low-level device (HalDriver/HalDevice) and scheduling data.

This is the type that users interact with as a ‘Device’. Devices can be handled loose-leaf or bound to a thread with a context manager.

clear()

Clears the current device without a context manager.

compile_target_flags: tuple[str, ...]
create_hal_module() → VmModule
static current() → Device
driver_id: str
dump_device_info() → str
export_torch_tensor: Callable[[HalBufferView, Tensor], Tensor]
finalize_iree_action(external_timepoint: HalExternalTimepoint)
get_type_key_hash(*, hasher: Callable[[str], str] = <function <lambda>>)
property hal_device: HalDevice
import_torch_tensor: Callable[[Tensor], HalBufferView]
instance_cache_key: str
set() → Device

Sets this device as the current device without a context manager.
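The relationship between set(), clear(), current() and the context-manager form can be illustrated with a small stand-in class. This is only a sketch of the pattern described above: `SketchDevice` and its thread-local stack are hypothetical, it does not import wave_lang, and it makes no claim about how the real Device stores its per-thread binding.

```python
import threading

# Hypothetical per-thread storage; an assumption for illustration only.
_current = threading.local()


class SketchDevice:
    """Toy stand-in (NOT wave_lang.runtime.Device) showing thread binding."""

    def __init__(self, uri: str):
        self.uri = uri

    @staticmethod
    def current() -> "SketchDevice":
        # Return the innermost device bound on the calling thread.
        stack = getattr(_current, "stack", [])
        if not stack:
            raise RuntimeError("no current device on this thread")
        return stack[-1]

    def set(self) -> "SketchDevice":
        # Bind this device as current without a context manager.
        if not hasattr(_current, "stack"):
            _current.stack = []
        _current.stack.append(self)
        return self

    def clear(self) -> None:
        # Unbind; in this sketch the device must be the innermost binding.
        stack = getattr(_current, "stack", [])
        if not stack or stack[-1] is not self:
            raise RuntimeError("device is not the current device")
        stack.pop()

    def __enter__(self) -> "SketchDevice":
        return self.set()

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.clear()
        return False  # do not swallow exceptions
```

In this sketch, `with SketchDevice("local-task") as d:` is equivalent to calling `d.set()` on entry and `d.clear()` on exit, and a binding made on one thread is not visible from another.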

setup_iree_action()
property sync: bool
type_cache_key: str
property vm_instance: VmInstance
class wave_lang.runtime.DeviceState(*, driver: str | HalDriver, device: HalDevice | None = None, vm_instance: VmInstance | None = None, enumerated_info: dict | None = None, torch_device: device | None = None, torch_stream: int | None = None, dlpack_device_type_code: int = 0)

State for an instantiated HAL device.

Note that the IREE runtime internally manages a global cache of drivers for standard named-access (not custom-constructed) drivers.

device
dlpack_device_type_code
driver
property enumerated_device_id: int
enumerated_info
property enumerated_name: str
property enumerated_path: str
static from_uri(uri: str) → DeviceState
instance
torch_device
torch_stream
class wave_lang.runtime.Launchable(loader: Callable[[Device], Tuple[str, VmModule]] | None, parameter_providers: Sequence[ParameterProvider] = (), is_async: bool = True)

Facilities for launching a compiled program (VMFB) on an attached device.

Like the eager custom-op executor, this follows the usual PyTorch rules whereby the device that input tensors reside on dictates where the launch happens. Unlike that flow, this does not include any notion of jitting or caching. It also has APIs for using parameters, etc.

You must manage all compilation/target settings yourself; you merely assert that a given binary is appropriate for launch on a device type. This has various limitations.

static from_file_cache_only(file_cache_dir: str | Path, *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async') → Launchable

Loads vmfbs only from the provided file_cache_dir. Raises an error if a vmfb is not found.

static from_vm_module(vm_module_callback: Callable[[Device], VmModule], *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async')
static jit_compile(source: Any, *, parameter_providers: Sequence[ParameterProvider] = (), entry_point: str = 'main$async', file_cache_dir: str | Path | None = None) → Launchable

Generates a launchable from a program source (e.g., mlir string). Set a file_cache_dir to enable storing/retrieving artifacts between sessions.
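The idea of storing and retrieving artifacts between sessions, and the cache-only behavior of from_file_cache_only, can be sketched with a plain file cache. Everything here is hypothetical: the helper names, the SHA-256 cache key, and the `.vmfb` file naming are assumptions made for illustration, and the real cache layout (which may also key on device type) is not described in this reference.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable


def cache_path(cache_dir: Path, source: str) -> Path:
    # Assumed key scheme: hash of the program source text.
    digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
    return cache_dir / f"{digest}.vmfb"


def load_or_compile(
    source: str, cache_dir: Path, compile_fn: Callable[[str], bytes]
) -> bytes:
    """Return the cached artifact, compiling and storing it on a miss."""
    path = cache_path(cache_dir, source)
    if path.exists():
        return path.read_bytes()
    artifact = compile_fn(source)
    cache_dir.mkdir(parents=True, exist_ok=True)
    path.write_bytes(artifact)
    return artifact


def load_cached_only(source: str, cache_dir: Path) -> bytes:
    """Cache-only lookup: never compiles, raises if the artifact is absent."""
    path = cache_path(cache_dir, source)
    if not path.exists():
        raise FileNotFoundError(f"no cached artifact at {path}")
    return path.read_bytes()


def demo() -> int:
    """Compile once, then hit the cache; returns how many compiles ran."""
    compiles = []

    def fake_compile(src: str) -> bytes:  # hypothetical compiler stand-in
        compiles.append(src)
        return b"binary for " + src.encode("utf-8")

    with tempfile.TemporaryDirectory() as tmp:
        cache_dir = Path(tmp) / "vmfb_cache"
        first = load_or_compile("module {}", cache_dir, fake_compile)
        second = load_or_compile("module {}", cache_dir, fake_compile)
        assert first == second == load_cached_only("module {}", cache_dir)
    return len(compiles)
```

The second `load_or_compile` call in `demo()` is served from disk, so the stand-in compiler runs only once.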

preload(device: device)

Pre-loads (or JIT compiles) for the given torch.device.

wave_lang.runtime.get_vm_instance() → VmInstance
wave_lang.runtime.invoke_vm_function(device: Device, is_async: bool, vm_context: VmContext, vm_function: VmFunction, arg_list: VmVariantList, ret_list: VmVariantList, *, timer: Callable[[], float] = <function <lambda>>)

Invokes a vm function on a device, adding async fences to the arg_list if is_async.

No checks are made to ensure compatibility between the provided device and vm_function. A timer function (returning a float) may be provided; this function returns the invocation time.
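The injectable-timer convention can be shown in isolation. The sketch below is not invoke_vm_function: `timed_invoke` is a hypothetical helper, the `time.perf_counter` default is an assumption (the real default is an unnamed lambda), and the tuple return shape is chosen only for illustration.

```python
import time
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def timed_invoke(
    fn: Callable[[], T],
    *,
    timer: Callable[[], float] = time.perf_counter,
) -> Tuple[T, float]:
    """Call fn() and report elapsed time as measured by the supplied timer."""
    start = timer()
    result = fn()
    elapsed = timer() - start
    return result, elapsed
```

Passing the timer as a zero-argument callable returning a float lets a caller substitute a different clock, or a deterministic fake in tests, without changing the invocation code.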

op_reg