Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ class ScheduleNode : public runtime::Object {
* \param block The block to be inlined to its producer
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/*!
* \brief Fuse an epilogue block into a reduction block
* \param reduction_block The reduction block (e.g., matmul)
* \param epilogue_block The epilogue block to be fused (e.g., bias add)
* \note The Python binding `fuse_reduction_epilogue` documents these requirements:
*       1) the reduction block is a complete reduction block;
*       2) the epilogue block only reads from the reduction block's output;
*       3) the epilogue performs a simple addition: output = reduction_result + bias.
*/
virtual void FuseReductionEpilogue(const BlockRV& reduction_block,
const BlockRV& epilogue_block) = 0;
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,33 @@ def after_inline(a: T.handle, c: T.handle) -> None:
# pylint: disable-next=no-member
_ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore

@type_checked
def fuse_reduction_epilogue(
    self,
    reduction_block: Union[BlockRV, str],
    epilogue_block: Union[BlockRV, str],
) -> None:
    """Merge an epilogue block into the reduction block it consumes.

    The primitive expects all of the following to hold:

    1) ``reduction_block`` is a complete reduction block.
    2) ``epilogue_block`` reads only from the output of ``reduction_block``.
    3) The epilogue is a plain addition, i.e.
       ``output = reduction_result + bias``.

    Parameters
    ----------
    reduction_block : Union[BlockRV, str]
        The reduction block (e.g., matmul). Passed through
        ``_normalize_block_arg`` before being handed to the FFI.
    epilogue_block : Union[BlockRV, str]
        The epilogue block to be fused (e.g., bias add). Passed through
        ``_normalize_block_arg`` before being handed to the FFI.
    """
    # Normalize both arguments first so the FFI call receives the normalized values.
    reduction_rv = self._normalize_block_arg(reduction_block)
    epilogue_rv = self._normalize_block_arg(epilogue_block)
    # pylint: disable-next=no-member
    _ffi_api.ScheduleFuseReductionEpilogue(self, reduction_rv, epilogue_rv)  # type: ignore

########## Schedule: Reduction ##########

@type_checked
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
this->state_->DebugVerify();
}

void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv,
const BlockRV& epilogue_block_rv) {
TVM_TIR_SCHEDULE_BEGIN();
tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv),
this->GetSRef(epilogue_block_rv));
TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Block Annotation ********/

void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
int index = -1) override;
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
// Fuse an epilogue block into a reduction block; see ScheduleNode::FuseReductionEpilogue.
void FuseReductionEpilogue(const BlockRV& reduction_block,
const BlockRV& epilogue_block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
* \param block_sref The sref to the block to be inlined to its producer
*/
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/*!
* \brief Fuse an epilogue block into a reduction block
* \param self The state of the schedule
* \param reduction_block_sref The sref to the reduction block
* \param epilogue_block_sref The sref to the epilogue block to be fused
* \note The Python binding `fuse_reduction_epilogue` documents these requirements:
*       1) the reduction block is a complete reduction block;
*       2) the epilogue block only reads from the reduction block's output;
*       3) the epilogue performs a simple addition: output = reduction_result + bias.
*/
TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref,
const StmtSRef& epilogue_block_sref);
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
Expand Down
Loading