wave_lang.support

wave_lang.support.default_trace_tensor_callback(key: str, tensor: Tensor)
wave_lang.support.get_location_capture_config()

Get LocationCaptureConfig based on the current debug flags.

wave_lang.support.trace_tensor_callback(key: str, tensor: Tensor)
wave_lang.support.trace_tensor_to_npy(key: str, tensor: Tensor)

conversions

wave_lang.support.conversions.dtype_to_element_type(dtype) → HalElementType
wave_lang.support.conversions.torch_dtype_to_numpy(torch_dtype: dtype) → Any
wave_lang.support.conversions.torch_dtyped_shape_to_iree_format(shape_or_tensor: Sequence[int] | Tensor, /, dtype: dtype | None = None) → str

Example: with shape = [1, 2, 3] and dtype = torch.bfloat16, returns "1x2x3xbf16".
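A minimal, torch-free sketch of the string convention in the example above: the dimensions are joined with "x" and the element-type suffix is appended last. The function iree_shape_string and its dtype-name table are hypothetical illustrations, not part of wave_lang; the real function takes a sequence or a Tensor plus a torch dtype, and its full dtype coverage is not listed here.

```python
from typing import Sequence

# Hypothetical table from torch dtype names to MLIR/IREE element-type suffixes.
# The real function works on torch.dtype objects; plain strings keep this self-contained.
_ELEMENT_SUFFIX = {
    "float32": "f32",
    "float16": "f16",
    "bfloat16": "bf16",
    "int8": "i8",
    "int32": "i32",
    "int64": "i64",
    "bool": "i1",
}


def iree_shape_string(shape: Sequence[int], dtype_name: str) -> str:
    """Join the dims with 'x' and append the element type, e.g. '1x2x3xbf16'."""
    dims = [str(d) for d in shape]
    return "x".join(dims + [_ELEMENT_SUFFIX[dtype_name]])


print(iree_shape_string([1, 2, 3], "bfloat16"))  # 1x2x3xbf16
```

An unknown dtype name raises KeyError in this sketch; how the real function handles unsupported dtypes is not documented here.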

debugging

class wave_lang.support.debugging.DebugFlags(log_level: int = 30, asserts: bool = False, runtime_trace_dir: str | None = None, location_level: wave_lang.support.location_config.LocationCaptureLevel | None = None)
asserts: bool = False
location_level: LocationCaptureLevel | None = None
log_level: int = 30
static parse(settings: str) → DebugFlags
static parse_from_env() → DebugFlags
runtime_trace_dir: str | None = None
set(part: str)
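The parse/set pair suggests that a settings string is split into parts, each applied through set. The sketch below illustrates that pattern under stated assumptions: the comma separator, the optional +/- prefix, the name=value syntax, and the level names are all guesses, and SketchDebugFlags is a hypothetical stand-in, not wave_lang's DebugFlags.

```python
import logging
import re
from dataclasses import dataclass

# Assumed part syntax: optional +/- prefix, a flag name, optional "=value".
_PART = re.compile(r"^([+-])?([^=]+)(?:=(.+))?$")
_LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}


@dataclass
class SketchDebugFlags:
    log_level: int = logging.WARNING  # 30, the documented default
    asserts: bool = False

    def set(self, part: str) -> None:
        """Apply a single part such as 'log_level=debug' or '-asserts'."""
        m = _PART.match(part)
        if not m:
            raise ValueError(f"bad flag: {part!r}")
        sign, name, value = m.groups()
        if name == "log_level":
            if not value or value.upper() not in _LEVELS:
                raise ValueError(f"bad log level in {part!r}")
            self.log_level = _LEVELS[value.upper()]
        elif name == "asserts":
            if value:  # an explicit value wins over the prefix
                self.asserts = value.upper() not in ("FALSE", "OFF", "0")
            else:
                self.asserts = sign != "-"
        else:
            raise ValueError(f"unknown flag: {name!r}")

    @staticmethod
    def parse(settings: str) -> "SketchDebugFlags":
        """Start from defaults and apply each comma-separated part in order."""
        flags = SketchDebugFlags()
        for part in settings.split(","):
            part = part.strip()
            if part:
                flags.set(part)
        return flags


flags = SketchDebugFlags.parse("log_level=debug, +asserts")
print(flags.log_level, flags.asserts)  # 10 True
```

In this sketch, a parse_from_env counterpart would read a settings string from an environment variable and hand it to parse; the variable name is not given in this reference, so it is left out.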

logging

class wave_lang.support.logging.DefaultFormatter
wave_lang.support.logging.get_logger(name: str)
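A minimal sketch of how a get_logger helper and a default formatter are commonly combined with the standard logging module. SketchFormatter, sketch_get_logger, the format string, and the level parameter are hypothetical; whether the real get_logger takes its level from DebugFlags.log_level is an assumption, not something this reference states.

```python
import logging
import sys


class SketchFormatter(logging.Formatter):
    """Hypothetical stand-in for DefaultFormatter with a fixed one-line layout."""

    def __init__(self) -> None:
        super().__init__("%(levelname)s [%(name)s] %(message)s")


def sketch_get_logger(name: str, level: int = logging.WARNING) -> logging.Logger:
    """Return the named logger, attaching a stderr handler only on the first call."""
    logger = logging.getLogger(name)
    # Assumption: the real helper may derive this level from DebugFlags.log_level.
    logger.setLevel(level)
    if not logger.handlers:  # avoid duplicate output on repeated calls
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(SketchFormatter())
        logger.addHandler(handler)
        logger.propagate = False
    return logger


log = sketch_get_logger("wave_lang.example")
log.warning("compiled %d kernels", 3)  # stderr: WARNING [wave_lang.example] compiled 3 kernels
```

Guarding on logger.handlers keeps the helper idempotent, which matters when many modules call it with the same name.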

tools