diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 34fd649a5d13..11fa7ae12e67 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2581,7 +2581,7 @@ def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: Examples -------- - Before set_dtype, in TensorIR, the IR is: + Before unsafe_set_dtype, in TensorIR, the IR is: .. code-block:: python @@ -2600,12 +2600,12 @@ def before_set_dtype( vi, vj = T.axis.remap("SS", [i, j] C[vi, vj] = B[vi, vj] + 1.0 - Create the schedule and do set_dtype: + Create the schedule and do unsafe_set_dtype: .. code-block:: python sch = tir.Schedule(before_set_dtype) - sch.set_dtype("B", buffer_index=0, dtype="float16") + sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16") print(sch.mod["main"].script()) After applying set_dtype, the IR becomes: @@ -2629,7 +2629,8 @@ def after_set_dtype( Note ---- - `set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + `unsafe_set_dtype` requires the buffer to be an intermediate buffer defined via + `alloc_buffer`. """ block = self._normalize_block_arg(block) _ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint: disable=no-member