wave_lang.support
- wave_lang.support.default_trace_tensor_callback(key: str, tensor: Tensor)
- wave_lang.support.get_location_capture_config()
  Get LocationCaptureConfig based on the current debug flags.
- wave_lang.support.trace_tensor_callback(key: str, tensor: Tensor)
- wave_lang.support.trace_tensor_to_npy(key: str, tensor: Tensor)
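The three tracing hooks above share a `(key: str, tensor: Tensor)` signature, but their behavior is not described here. Going only by its name, `trace_tensor_to_npy` presumably writes the tensor to a NumPy `.npy` file. The sketch below shows one way a callback with that signature could do so, assuming NumPy is installed. `save_tensor_trace` and its extra `trace_dir` parameter are hypothetical and are not part of `wave_lang`.

```python
import os

import numpy as np


def save_tensor_trace(key: str, tensor, trace_dir: str) -> str:
    """Hypothetical trace callback: save `tensor` to `<trace_dir>/<key>.npy`.

    This is not the wave_lang implementation, only a sketch with the same
    (key, tensor) shape plus an assumed output directory.
    """
    # Accept torch tensors (detach from autograd, move to CPU) or anything
    # np.asarray understands, such as NumPy arrays and nested lists.
    if hasattr(tensor, "detach"):
        tensor = tensor.detach().cpu().numpy()
    array = np.asarray(tensor)

    # Keys may contain path separators; flatten them into a single file name.
    safe_key = key.replace(os.sep, "_").replace("/", "_")
    os.makedirs(trace_dir, exist_ok=True)
    path = os.path.join(trace_dir, f"{safe_key}.npy")
    np.save(path, array)
    return path
```

NumPy has no native bfloat16 dtype, so a real callback would need to convert such tensors (for example to float32) before calling `.numpy()`.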
conversions
- wave_lang.support.conversions.dtype_to_element_type(dtype) → HalElementType
- wave_lang.support.conversions.torch_dtype_to_numpy(torch_dtype: dtype) → Any
- wave_lang.support.conversions.torch_dtyped_shape_to_iree_format(shape_or_tensor: Sequence[int] | Tensor, /, dtype: dtype | None = None) → str
  Example: with shape = [1, 2, 3] and dtype = torch.bfloat16, returns "1x2x3xbf16".
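The documented example implies the output format: each dimension, then the element type, all joined with `x`. Below is a stdlib-only sketch of that string construction. The dtype-name table is an assumption that covers a few common types and takes plain strings rather than `torch.dtype` objects. It is not the library's own mapping, and `shape_to_iree_format` is a hypothetical name.

```python
from typing import Sequence

# Assumed mapping from dtype names to IREE element-type suffixes; only a
# handful of common types are listed.
_IREE_ELEMENT_TYPES = {
    "bfloat16": "bf16",
    "float16": "f16",
    "float32": "f32",
    "float64": "f64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    "bool": "i1",
}


def shape_to_iree_format(shape: Sequence[int], dtype_name: str) -> str:
    """Hypothetical helper: render ([1, 2, 3], "bfloat16") as "1x2x3xbf16"."""
    try:
        element_type = _IREE_ELEMENT_TYPES[dtype_name]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype_name!r}") from None
    # Each dimension becomes a decimal string; the element type is the last field.
    return "x".join([*(str(d) for d in shape), element_type])
```

With these assumptions, `shape_to_iree_format([1, 2, 3], "bfloat16")` returns `"1x2x3xbf16"`, matching the documented example.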
debugging
- class wave_lang.support.debugging.DebugFlags(log_level: int = 30, asserts: bool = False, runtime_trace_dir: str | None = None, location_level: wave_lang.support.location_config.LocationCaptureLevel | None = None)
  - asserts: bool = False
  - location_level: LocationCaptureLevel | None = None
  - log_level: int = 30
  - static parse(settings: str) → DebugFlags
  - static parse_from_env() → DebugFlags
  - runtime_trace_dir: str | None = None
  - set(part: str)
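`DebugFlags` exposes `parse(settings)` and a per-part `set(part)`, which suggests that a settings string is split into parts applied one at a time. The settings syntax is not documented here. The sketch below therefore invents one: comma-separated parts such as `asserts` or `log_level=debug`. `SketchDebugFlags` and that syntax are hypothetical. The sketch also assumes `log_level` uses Python `logging` levels, where the default of 30 equals `logging.WARNING`.

```python
import logging
from dataclasses import dataclass
from typing import Optional

# Level names accepted by the invented "log_level=<name>" syntax.
_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}


@dataclass
class SketchDebugFlags:
    """Hypothetical stand-in mirroring some DebugFlags fields (not wave_lang's class)."""

    log_level: int = logging.WARNING  # 30, the same default DebugFlags shows
    asserts: bool = False
    runtime_trace_dir: Optional[str] = None

    def set(self, part: str) -> None:
        # Invented syntax: a bare name enables a boolean, name=value assigns.
        name, sep, value = part.partition("=")
        name, value = name.strip(), value.strip()
        if name == "asserts" and not sep:
            self.asserts = True
        elif name == "log_level" and sep:
            # Accept a numeric level ("10") or a level name ("debug").
            if value.isdigit():
                self.log_level = int(value)
            elif value.lower() in _LEVELS:
                self.log_level = _LEVELS[value.lower()]
            else:
                raise ValueError(f"unknown log level: {value!r}")
        elif name == "runtime_trace_dir" and sep:
            self.runtime_trace_dir = value
        else:
            raise ValueError(f"unrecognized debug flag: {part!r}")

    @staticmethod
    def parse(settings: str) -> "SketchDebugFlags":
        # Start from defaults, then apply each non-empty part in order.
        flags = SketchDebugFlags()
        for part in settings.split(","):
            if part.strip():
                flags.set(part)
        return flags
```

`parse_from_env` is left out because this reference does not name the environment variable it reads.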
logging
- class wave_lang.support.logging.DefaultFormatter
- wave_lang.support.logging.get_logger(name: str)
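Only the names `DefaultFormatter` and `get_logger(name)` are listed here. A common pattern for such a pair is a `logging.Formatter` subclass plus a factory that attaches a handler to a logger only once. The sketch below implements that pattern with the standard `logging` module. The format string, the choice of a stream handler, and the `Sketch`-prefixed names are assumptions, not `wave_lang`'s code.

```python
import logging


class SketchFormatter(logging.Formatter):
    """Hypothetical formatter: level name, logger name, then the message."""

    def __init__(self) -> None:
        super().__init__(fmt="%(levelname)s %(name)s: %(message)s")


def sketch_get_logger(name: str) -> logging.Logger:
    """Return the named logger, adding a stream handler on first use only."""
    logger = logging.getLogger(name)
    # logging.getLogger caches loggers by name, so guard against stacking a
    # new handler (and duplicating every message) on repeated calls.
    if not any(isinstance(h.formatter, SketchFormatter) for h in logger.handlers):
        handler = logging.StreamHandler()
        handler.setFormatter(SketchFormatter())
        logger.addHandler(handler)
    return logger
```

Because `logging.getLogger` returns the same object for the same name, the handler check is what keeps repeated `sketch_get_logger("wave")` calls from printing each message several times.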