diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 69f05201177b..187d0a31d05d 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -641,6 +641,13 @@ class ScheduleNode : public runtime::Object { * \return the new block */ virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0; + /*! + * \brief Convert the specified blocks into one nested block that wraps all of them. + * \param blocks The blocks to be wrapped; they must be consecutive and not nested in each other + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \return The new block that encloses the given blocks + */ + virtual BlockRV Blockize(const Array& blocks, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param loop_rv The loop to be tensorized diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7c7af998bef2..8ebc02ccbb20 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2691,13 +2691,15 @@ def after_set_dtype( ########## Schedule: Blockize & Tensorize ########## @type_checked - def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> BlockRV: - """Convert the subtree rooted at a specific loop into a block. + def blockize( + self, target: Union[LoopRV, List[BlockRV]], preserve_unit_iters: bool = True + ) -> BlockRV: + """Convert multiple blocks or the subtree rooted at a specific loop into a block. Parameters ---------- - loop : LoopRV - The root of the subtree. + target : LoopRV or List[BlockRV] + The root of the subtree or the specified blocks. preserve_unit_iters : bool Whether or not to preserve unit iterators in block bindings @@ -2764,7 +2766,7 @@ def after_blockize( block are divisible by the subspace represented by the loops starting at the given loop. 
""" - return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters) # type: ignore # pylint: disable=no-member @type_checked def tensorize( diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 7192a4809994..d48512724214 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -791,6 +791,15 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit return CreateRV(result); } +BlockRV ConcreteScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) { + StmtSRef result{nullptr}; /* assigned between the BEGIN/END macros below */ + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); + return CreateRV(result); +} + void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 16065df3cd93..73a0b314dd84 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -153,6 +153,7 @@ class ConcreteScheduleNode : public ScheduleNode { void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; + BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) override; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override; /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 
78d1cab05ce3..7355d38db1a0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -542,6 +542,16 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, in */ TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters); +/*! + * \brief Convert specific blocks into one nested block placed under their common ancestor. + * \param self The state of the schedule + * \param blocks The target blocks; they must be consecutive and not nested in each other + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \return The new block + */ +TVM_DLL StmtSRef Blockize(ScheduleState self, const Array& blocks, + bool preserve_unit_iters); + /*! * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. * \param self The state of the schedule diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 25694ed6fc49..994a3a95fbad 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -139,20 +139,26 @@ Array> TrivialSubspaceDivision(const Array& iter /*! * \brief Subspace division. The space is divided into two subspaces: + * If loop_sref_as_outer is false: * 1. The subspace represented by the outer loops above `loop_sref` (exclusive). * 2. The subspace represented by the inner loops below `loop_sref` (inclusive). + * Otherwise (loop_sref_as_outer is true): + * 1. The subspace represented by the outer loops above `loop_sref` (inclusive). + * 2. The subspace represented by the inner loops below `loop_sref` (exclusive). * \param realize The inner block * \param block_sref The sref to the inner block * \param loop_sref The loop that is the root of the second subspace. * \param loops The loops that represents the second part of the subspace. * \param analyzer The arithmetic analyzer to use. 
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner */ Array> SubspaceDivide(const BlockRealize& realize, const StmtSRef& block_sref, // const StmtSRef& loop_sref, // std::vector* loops, - arith::Analyzer* analyzer, bool preserve_unit_iters) { + arith::Analyzer* analyzer, bool preserve_unit_iters, + bool loop_sref_as_outer = false) { Array inner_vars; Array outer_vars; Map loop_var_domain; @@ -168,7 +174,7 @@ Array> SubspaceDivide(const BlockRealize& realize, outer_vars.push_back(loop->loop_var); } loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); - if (sref == loop_sref.get()) { + if ((loop_sref_as_outer && sref->parent == loop_sref.get()) || sref == loop_sref.get()) { inner = false; } } @@ -201,12 +207,14 @@ Map DeriveBlockBinding(const Array& iter_vars, Array* outer_iter_vars, // Array* outer_bindings, // Array* inner_iter_vars, // - Array* inner_bindings, bool preserve_unit_iters) { + Array* inner_bindings, // + bool preserve_unit_iters, bool reuse_outer = false) { using arith::IterMapExpr; using arith::IterMapExprNode; using arith::NormalizeIterMapToExpr; Map block_var_subst; ICHECK_EQ(iter_vars.size() + 1, division.size()); + arith::Analyzer ana; for (int i = 0, n = iter_vars.size(); i < n; ++i) { const IterVar& iter_var = iter_vars[i]; arith::IterMark outer_mark = division[i][0]; @@ -219,30 +227,43 @@ Map DeriveBlockBinding(const Array& iter_vars, // The inner block will have binding: iter_inner -> inner_binding // The iter in the original block will be substituted with base + iter_inner where // base == iter_outer * iter_inner_extent - if (is_one(inner_mark->extent)) { // IsOuter - // extract this iter var to outer block directly + // create iter var for the outer block + IterVar outer_iter; + if (reuse_outer) { + outer_iter = outer_iter_vars->operator[](i); + ICHECK(ana.CanProveEqual(outer_iter->dom->extent, 
outer_mark->extent)); + ICHECK( + ana.CanProveEqual(outer_bindings->operator[](i), NormalizeIterMapToExpr(outer_binding))); + } else { + outer_iter = IterVar(/*dom=*/RangeFromExtent(outer_mark->extent), + /*var=*/iter_var->var.copy_with_suffix("_o"), + /*iter_type=*/iter_var->iter_type); outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); - outer_iter_vars->push_back(iter_var); - continue; + outer_iter_vars->push_back(outer_iter); } - // create iter var for the outer block - IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent), - /*var=*/iter_var->var.copy_with_suffix("_o"), - /*iter_type=*/iter_var->iter_type); - outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); - outer_iter_vars->push_back(outer_iter); - // create iter var for the inner block - IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent), - /*var=*/iter_var->var.copy_with_suffix("_i"), - /*iter_type=*/iter_var->iter_type); - inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding)); - inner_iter_vars->push_back(inner_iter); - // substitution PrimExpr sub{nullptr}; - if (is_one(outer_mark->extent)) { - sub = inner_iter->var; + if (is_one(inner_mark->extent)) { + // Skip inner var when extent is 1 + // substitution + if (is_one(outer_mark->extent) && !preserve_unit_iters) { + // Simplify outer if not preserve_unit_iters + sub = make_zero(outer_mark->extent.dtype()); + } else { + sub = outer_iter; + } } else { - sub = outer_iter * inner_mark->extent + inner_iter->var; + // create iter var for the inner block + IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent), + /*var=*/iter_var->var.copy_with_suffix("_i"), + /*iter_type=*/iter_var->iter_type); + inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding)); + inner_iter_vars->push_back(inner_iter); + // substitution + if (is_one(outer_mark->extent)) { + sub = inner_iter->var; + } else { + sub = outer_iter * inner_mark->extent + inner_iter->var; + } } block_var_subst.Set(iter_var->var, 
sub); } @@ -414,6 +435,37 @@ Array EvalSetRegions(const Array& regions, return results; } +/*! + * \brief Get the union of the given regions, merged per buffer and per dimension. + * \param regions The input regions for the union; the same buffer may appear several times. + * \return One region per distinct buffer, in the iteration order of the internal hash map. + */ +Array UnionRegions(const Array& regions) { + typedef std::vector> ranges_t; + std::unordered_map intset_map; + for (const BufferRegion& buffer_region : regions) { + const Buffer& buffer = buffer_region->buffer; + if (intset_map.find(buffer) == intset_map.end()) { + intset_map[buffer] = {buffer->shape.size(), Array()}; + } + ranges_t& dim_sets = intset_map[buffer]; /* bind once, reuse for every dimension */ + for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { + dim_sets[dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim])); + } + } + Array results; + for (const auto& it : intset_map) { + const Buffer& buffer = it.first; + Array union_region; + for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { + const arith::IntSet intset = arith::Union(it.second[dim]); + union_region.push_back({intset.min(), intset.max() + 1}); + } + results.push_back(BufferRegion(buffer, union_region)); + } + return results; +} + /*! * \brief Create the loop nest on top of the given stmt. * \param stmt The stmt to be wrapped. 
@@ -513,6 +565,181 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u return result; } +BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& block_srefs, + const StmtSRef& lca, Map* block_sref_reuse, + bool preserve_unit_iters) { + Array seq_body; + PrimExpr outer_predicate{nullptr}; + Array outer_iter_vars{nullptr}; /* undefined until the first block fills it */ + Array outer_bindings{nullptr}; + Array read_regions; + Array write_regions; + std::string outer_block_name = "outer_"; + Map loop_var_subst; + arith::Analyzer analyzer; + for (const auto& block_sref : block_srefs) { + auto block_realize = GetBlockRealize(self, block_sref); + auto block = block_realize->block; + // Step 1: Derive subspace division, with `lca` counted as part of the outer subspace + std::vector loops; + Array> division = SubspaceDivide(block_realize, block_sref, lca, &loops, + &analyzer, preserve_unit_iters, true); + if (division.empty()) { + throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); + } + outer_predicate = division.back()[0]->extent; /* kept from the last block only */ + PrimExpr inner_predicate = division.back()[1]->extent; + // Step 2. Derive block bindings for both outer and inner block. 
+ Array inner_iter_vars; + Array inner_bindings; + Map block_var_subst = // + DeriveBlockBinding(block->iter_vars, division, // + &outer_iter_vars, &outer_bindings, // + &inner_iter_vars, &inner_bindings, // + preserve_unit_iters, outer_iter_vars.defined()); + // Step 3: Do var substitution to adjust to the new block bindings + for (size_t i = 0; i < outer_iter_vars.size(); ++i) { + if (outer_bindings[i].as()) { + loop_var_subst.Set(Downcast(outer_bindings[i]), outer_iter_vars[i]->var); + } + } + Map inner_iter_dom; + for (const IterVar& iter : inner_iter_vars) { + Range dom = Substitute(iter->dom, loop_var_subst); + inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom)); + analyzer.Bind(iter->var, dom); + } + Block block_subst = + Downcast(Substitute(block, block_var_subst, block_sref_reuse, &analyzer)); + auto reads = EvalSetRegions(block_subst->reads, inner_iter_dom); + auto writes = EvalSetRegions(block_subst->writes, inner_iter_dom); + read_regions.insert(read_regions.end(), reads.begin(), reads.end()); + write_regions.insert(write_regions.end(), writes.begin(), writes.end()); + outer_block_name += block_subst->name_hint + "_"; /* "B", "C" -> "outer_B_C_" */ + // Step 4: Generate the inner block. No reduction iter vars allowed for the outer loops. 
+ bool has_outer_reduction = false; + if (block_subst->init.defined()) { + for (const IterVar& iter_var : outer_iter_vars) { + if (iter_var->iter_type == kCommReduce) { + has_outer_reduction = true; + break; + } + } + } + ICHECK(has_outer_reduction == false) + << "No reduction iter vars allowed for the outer loops when blockize multiple blocks"; + BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction, + /*iter_vars=*/inner_iter_vars, + /*iter_values=*/inner_bindings, + /*predicate=*/inner_predicate, + /*block=*/block_subst); + block_sref_reuse->Set(block, inner_realize->block); + Stmt stmt = inner_realize; + for (const ForNode* loop : loops) { + ObjectPtr new_loop = make_object(*loop); + new_loop->body = std::move(stmt); + new_loop->extent = Substitute(new_loop->extent, loop_var_subst); + stmt = For(new_loop); + } + seq_body.push_back(stmt); + } + // Step 5: Generate the outer block, whose body is the sequence of the rewritten subtrees. + return BlockRealize( + /*iter_values=*/std::move(outer_bindings), + /*predicate=*/std::move(outer_predicate), + /*block=*/ + Block(/*iter_vars=*/std::move(outer_iter_vars), + /*reads=*/UnionRegions(read_regions), + /*writes=*/UnionRegions(write_regions), + /*name_hint=*/outer_block_name, + /*body=*/SeqStmt(seq_body), + /*init=*/Optional(NullOpt))); +} +/*! \brief Rewrite the body of `lca`: the children holding target blocks become one realize. */ +class BlockizeRewriter : public StmtMutator { + public: + static Stmt Rewrite(const StmtSRef& lca, const Array& blocks, + const BlockRealize& blockized) { + BlockizeRewriter rewriter(lca, blocks, blockized); + return rewriter(GetRef(lca->stmt)); + } + + private: + explicit BlockizeRewriter(const StmtSRef& lca, const Array& blocks, + const BlockRealize& blockized) + : lca_(lca), blocks_(blocks), blockized_(blockized) {} + /*! \brief Put `blockized_` where the first target child was and drop the other ones. */ + Stmt RewriteSeq(const Stmt& stmt) { + const SeqStmtNode* seq = stmt.as(); + ICHECK(seq) << "Target blocks must not be nested with each other!"; + int idx_start = -1; + int found_cnt = 0; + int last_found_idx = -1; + size_t cur_idx = 0; + Array new_seq; + for (const Stmt& it : 
seq->seq) { + target_in_ = false; /* set by VisitStmt_(BlockNode) if a target is inside `it` */ + Stmt stmt = StmtMutator::VisitStmt(it); + if (target_in_) { + if (idx_start == -1) { + idx_start = cur_idx; + new_seq.push_back(blockized_); + } else { + ICHECK_EQ(last_found_idx, cur_idx - 1) << "Target blocks must be consecutive!"; + } + last_found_idx = cur_idx; + ++found_cnt; /* NOTE(review): counted but never read in this class */ + } else { + new_seq.push_back(it); + } + ++cur_idx; + } + if (new_seq.size() == 1) return new_seq[0]; /* no SeqStmt wrapper for one stmt */ + return SeqStmt(new_seq, seq->span); + } + /*! \brief Rebuild the body of the LCA loop; other loops go through the default mutator. */ + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == lca_->stmt) { + return For(loop->loop_var, loop->min, loop->extent, loop->kind, RewriteSeq(loop->body), + loop->thread_binding, loop->annotations, loop->span); + } + return StmtMutator::VisitStmt_(loop); + } + /*! \brief Rebuild the body of the LCA block; flag target blocks, leave others untouched. */ + Stmt VisitStmt_(const BlockNode* block) final { + if (block == lca_->stmt) { + return Block(block->iter_vars, block->reads, block->writes, block->name_hint, + RewriteSeq(block->body), block->init, block->alloc_buffers, block->match_buffers, + block->annotations, block->span); + } + for (const StmtSRef& block_sref : blocks_) { + if (block_sref->stmt == block) { + target_in_ = true; + break; + } + } + return GetRef(block); + } + + StmtSRef lca_; + Array blocks_; + BlockRealize blockized_; + bool target_in_ = false; +}; +/* Multi-block Blockize: build the outer block, splice it under the LCA, refresh scope info. */ +StmtSRef Blockize(ScheduleState self, const Array& blocks, bool preserve_unit_iters) { + Map block_sref_reuse; + auto lca = GetSRefLowestCommonAncestor(blocks); + BlockRealize blockized = + BlockizeBlocks(self, blocks, lca, &block_sref_reuse, preserve_unit_iters); + auto new_root = BlockizeRewriter::Rewrite(lca, blocks, blockized); + self->Replace(lca, new_root, block_sref_reuse); + StmtSRef result = self->stmt2ref.at(blockized->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); + self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + return result; +} + void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin, bool 
preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed @@ -636,13 +863,19 @@ struct BlockizeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool preserve_unit_iters) { - return sch->Blockize(loop_rv, preserve_unit_iters.operator bool()); + static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool preserve_unit_iters) { + if (auto loop = target.as()) { /* a loop: blockize its subtree */ + return sch->Blockize(loop.value(), preserve_unit_iters.operator bool()); + } else if (auto blocks = target.as>()) { /* a list of blocks */ + return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); + } + LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" << target->GetTypeKey(); } - static String UnpackedAsPython(Array outputs, String loop_rv, Bool preserve_unit_iters) { + static String UnpackedAsPython(Array outputs, ObjectRef target, + Bool preserve_unit_iters) { PythonAPICall py("blockize"); - py.Input("loop", loop_rv); + py.Input("target", target); py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); py.SingleOutput(outputs); return py.Str(); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 8663ac2b9736..56d0d1efa906 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -228,7 +228,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") .set_body_method(&ScheduleNode::UnsafeSetDType); /******** (FFI) Blockize & Tensorize ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") - .set_body_method(&ScheduleNode::Blockize); + .set_body_typed([](Schedule self, ObjectRef target, bool preserve_unit_iters) { + if (auto loop_rv = target.as()) { /* a loop: blockize its subtree */ + return self->Blockize(loop_rv.value(), preserve_unit_iters); + } else if (auto blocks = target.as>()) { /* a list of blocks */ + return self->Blockize(blocks.value(), preserve_unit_iters); + } + LOG(FATAL) << 
"Unsupported target type: " << target->GetTypeKey(); + }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 4d820078e527..ceeeacb335c7 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -558,6 +558,17 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_i return new_block; } +BlockRV TracedScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) { + BlockRV new_block = ConcreteScheduleNode::Blockize(blocks, preserve_unit_iters); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{blocks}, /* the whole array is one input */ + /*attrs=*/{Bool(preserve_unit_iters)}, + /*outputs=*/{new_block})); + return new_block; +} + void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 16ec86f22709..2d47ee9aff12 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -111,6 +111,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; + BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) final; void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final; /******** Schedule: Annotation ********/ diff --git 
a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 030e47ac581d..111448ea5791 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -493,7 +493,7 @@ def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool: sch.parallel(new_loops[4]) sch.unroll(new_loops[5]) # TODO(nverke): Add compute optimizations here. - sch.blockize(loop=oc_i) + sch.blockize(target=oc_i) sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN) diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index 4d22c4ff8854..78b2fdbf3d66 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -2150,7 +2150,7 @@ def apply_trace(sch): l28, l29 = sch.split(loop=l21, factors=[None, 16], preserve_unit_iters=True) l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b1) sch.reorder(l34, l36, l29, l27, l25) - b38 = sch.blockize(loop=l29) + b38 = sch.blockize(target=l29) sch.annotate( block_or_loop=b38, ann_key="meta_schedule.auto_tensorize", @@ -2243,7 +2243,7 @@ def apply_trace(sch): l95, l96 = sch.split(loop=l91, factors=[None, 16], preserve_unit_iters=True) l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b86) sch.reorder(l102, l96, l94) - b104 = sch.blockize(loop=l96) + b104 = sch.blockize(target=l96) sch.annotate( block_or_loop=b104, ann_key="meta_schedule.auto_tensorize", @@ -2308,7 +2308,7 @@ def apply_trace(sch): l157, ) = sch.get_loops(block=b129) sch.reorder(l156, l144, l142) - b158 = sch.blockize(loop=l144) + b158 = sch.blockize(target=l144) sch.annotate( block_or_loop=b158, ann_key="meta_schedule.auto_tensorize", @@ -2351,7 +2351,7 @@ def apply_trace(sch): l191, ) = sch.get_loops(block=b159) sch.reorder(l190, l176, 
l174) - b192 = sch.blockize(loop=l176) + b192 = sch.blockize(target=l176) sch.annotate( block_or_loop=b192, ann_key="meta_schedule.auto_tensorize", @@ -2554,7 +2554,7 @@ def apply_trace(sch): l34, l35 = sch.split(loop=l26, factors=[None, 16], preserve_unit_iters=True) l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 = sch.get_loops(block=b1) sch.reorder(l42, l43, l44, l45, l46, l35, l33) - b48 = sch.blockize(loop=l35) + b48 = sch.blockize(target=l35) sch.annotate(block_or_loop=b48, ann_key="meta_schedule.auto_tensorize", ann_val=VNNI_INTRIN) l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b48) v59, v60, v61, v62 = sch.sample_perfect_tile( @@ -3119,7 +3119,7 @@ def apply_trace(sch: Schedule) -> None: l22, l23 = sch.split(loop=l15, factors=[None, 16], preserve_unit_iters=True) l24, l25, l26, l27, l28, l29, l30, l31 = sch.get_loops(block=b1) sch.reorder(l28, l30, l23, l21, l19) - b32 = sch.blockize(loop=l23) + b32 = sch.blockize(target=l23) sch.annotate( block_or_loop=b32, ann_key="meta_schedule.auto_tensorize", @@ -3212,7 +3212,7 @@ def apply_trace(sch: Schedule) -> None: l89, l90 = sch.split(loop=l85, factors=[None, 16], preserve_unit_iters=True) l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b80) sch.reorder(l96, l90, l88) - b98 = sch.blockize(loop=l90) + b98 = sch.blockize(target=l90) sch.annotate( block_or_loop=b98, ann_key="meta_schedule.auto_tensorize", @@ -3277,7 +3277,7 @@ def apply_trace(sch: Schedule) -> None: l151, ) = sch.get_loops(block=b123) sch.reorder(l150, l138, l136) - b152 = sch.blockize(loop=l138) + b152 = sch.blockize(target=l138) sch.annotate( block_or_loop=b152, ann_key="meta_schedule.auto_tensorize", @@ -3320,7 +3320,7 @@ def apply_trace(sch: Schedule) -> None: l185, ) = sch.get_loops(block=b153) sch.reorder(l184, l170, l168) - b186 = sch.blockize(loop=l170) + b186 = sch.blockize(target=l170) sch.annotate( block_or_loop=b186, ann_key="meta_schedule.auto_tensorize", diff --git 
a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index cd4ce663e58e..d151e4b43809 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -305,5 +305,54 @@ def after_single_elementwise_int64_blockize_preserve_unit_iters( verify_trace_roundtrip(sch=s, mod=single_elementwise_int64) +def test_blockize_blocks(): + @T.prim_func + def blocks_func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None: + for m in T.serial(6): + for i, j in T.grid(3, 1): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + + for i, j in T.grid(128, 64): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj + 64]) + T.writes(B[vi, vj + 64]) + B[vi, vj + 64] = A[vi, vj + 64] * 3.0 + # Expected: one outer block under loop m holds B and C; B drops its unit iterator vj. + @T.prim_func + def after_blocks_blockize( + A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32") + ) -> None: + for m in range(6): + with T.block("outer_B_C_"): + vi_o = T.axis.spatial(1, 0) + vj_o = T.axis.spatial(1, 0) + T.reads(A[0:128, 0:128]) + T.writes(B[0:128, 0:128]) + for i, j in T.grid(3, 1): + with T.block("B"): + vi_i = T.axis.spatial(3, i) + T.reads(A[vi_i, 0]) + T.writes(B[vi_i, 0]) + B[vi_i, 0] = A[vi_i, 0] * T.float32(2) + for i, j in T.grid(128, 64): + with T.block("C"): + vi_i, vj_i = T.axis.remap("SS", [i, j]) + T.reads(A[vi_i, vj_i + 64]) + T.writes(B[vi_i, vj_i + 64]) + B[vi_i, vj_i + 64] = A[vi_i, vj_i + 64] * T.float32(3) + # Blockize B and C together, then check the result and the trace round trip. + s = tir.Schedule(blocks_func, debug_mask="all") + blocks = [s.get_block("B"), s.get_block("C")] + s.blockize(blocks, preserve_unit_iters=False) + expected = after_blocks_blockize + tvm.ir.assert_structural_equal(s.mod["main"], expected) + verify_trace_roundtrip(sch=s, mod=blocks_func) + + if __name__ == "__main__": tvm.testing.main()