Skip to content

Commit

Permalink
[TIR][Schedule] Improve blockize to support blockizing multiple blocks (
Browse files Browse the repository at this point in the history
apache#14766)

* Improve blockize to support blockize multiple blocks

* Adjust unit test to match simplified blockize result.

* Update doc

* Preserve unit iters in expr and revert test case change

* Apply review suggestion

---------

Co-authored-by: Min Chen <[email protected]>
  • Loading branch information
multiverstack-intellif and Min Chen authored May 16, 2023
1 parent f6bbe94 commit 2f863dd
Show file tree
Hide file tree
Showing 12 changed files with 373 additions and 43 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,13 @@ class ScheduleNode : public runtime::Object {
* \return the new block
*/
virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
/*!
* \brief Convert specified blocks into a nested block.
* \param blocks the specified block to construct the new block
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return the new block
*/
virtual BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters = true) = 0;
/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrin.
* \param loop_rv The loop to be tensorized
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2691,13 +2691,15 @@ def after_set_dtype(
########## Schedule: Blockize & Tensorize ##########

@type_checked
def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> BlockRV:
"""Convert the subtree rooted at a specific loop into a block.
def blockize(
self, target: Union[LoopRV, List[BlockRV]], preserve_unit_iters: bool = True
) -> BlockRV:
"""Convert multiple blocks or the subtree rooted at a specific loop into a block.
Parameters
----------
loop : LoopRV
The root of the subtree.
target : LoopRV or List[BlockRV]
The root of the subtree or the specified blocks.
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings
Expand Down Expand Up @@ -2764,7 +2766,7 @@ def after_blockize(
block are divisible by the subspace represented by the loops starting at the given loop.
"""

return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters) # type: ignore # pylint: disable=no-member

@type_checked
def tensorize(
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,15 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
return CreateRV<BlockRV>(result);
}

void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
bool preserve_unit_iters) {
TVM_TIR_SCHEDULE_BEGIN();
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters) override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override;
/******** Schedule: Annotation ********/
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,16 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, in
*/
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters);

/*!
* \brief Convert specific blocks into a nested block.
* \param self The state of the schedule
* \param blocks The target blocks to construct the new block
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new block
*/
TVM_DLL StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks,
bool preserve_unit_iters);

/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
* \param self The state of the schedule
Expand Down
Loading

0 comments on commit 2f863dd

Please sign in to comment.