diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 9cad4ffd6989..efa26a31d02e 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -346,6 +346,27 @@ class RewriteSimplifier { /*! \brief Return the currently enabled extensions */ TVM_DLL Extension GetEnabledExtensions() const; + /*! \brief Return the statistics counters */ + TVM_DLL ObjectRef GetStatsCounters() const; + + /*! \brief Reset the statistics counters */ + TVM_DLL void ResetStatsCounters(); + + /*! \brief Set the maximum allowed number of rewrite steps + * + * By default, the simplifier may perform as many steps as are + * required. If a positive limit is set, then the simplifier will + * throw an exception when exceeding that number of rewrite steps. + * This allows tests to guard against performance regressions. + * + * Note: To maintain accurate usage counters, `Analyzer` instances + * should be re-used wherever possible. For example, TIR + * transformations should declare a single `Analyzer` that is used + * throughout the pass, and utility functions should receive an + * `Analyzer*` from their calling scope. + */ + TVM_DLL void SetMaximumRewriteSteps(int64_t maximum); + private: friend class Analyzer; friend class ConstraintContext; diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index b72c6ab2eb7e..559d9cc43c79 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -96,6 +96,8 @@ def __init__(self): self._modular_set = _mod("modular_set") self._simplify = _mod("Simplify") self._rewrite_simplify = _mod("rewrite_simplify") + self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats") + self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats") self._canonical_simplify = _mod("canonical_simplify") self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") @@ -167,6 +169,13 @@ def rewrite_simplify(self, expr): """ return self._rewrite_simplify(expr) + @property + def rewrite_simplify_stats(self): + return self._get_rewrite_simplify_stats() + + def reset_rewrite_simplify_stats(self): + self._reset_rewrite_simplify_stats() + def canonical_simplify(self, expr): """Simplify expression via canonicalization. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f744bed4f4f4..722a2cd00e75 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -228,6 +228,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu } else if (name == "rewrite_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "get_rewrite_simplify_stats") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); }); } else if (name == "canonical_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index acd74b7031e7..40088fd963d7 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -58,25 +58,33 @@ using namespace tir; // macro for doing simple rewrite #define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret)) { \ + RecordRewrite(); \ return (ResExpr).Eval(); \ } // macro for rewrite + recursively rewrite ResExpr #define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret)) { \ + RecordRewrite(); \ return RecursiveRewrite((ResExpr).Eval()); \ } // macro rewrite only if CondExor is true after match. #define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \ + RecordRewrite(); \ return (ResExpr).Eval(); \ } // macro rewrite + recursive_rewrite only if CondExor is true after match. #define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \ + RecordRewrite(); \ return RecursiveRewrite((ResExpr).Eval()); \ } @@ -211,6 +219,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val return CompareResult::kUnknown; } +PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) { + stats_.nodes_visited++; + return IRMutatorWithAnalyzer::VisitExpr(e); +} + void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { if (!can_override) { auto it = var_map_.find(var); @@ -359,6 +372,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c literal_constraints_.push_back(Not(negation)); } } + stats_.constraints_entered++; size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { ICHECK_EQ(literal_constraints_.size(), new_literal_size); @@ -2150,9 +2164,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const { return impl_->GetEnabledExtensions(); } +ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); } + +void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); } + +void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) { + impl_->SetMaximumRewriteSteps(maximum); +} + RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* ptr = node.as(); + p->stream << "RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited + << ", constraints_entered = " << ptr->constraints_entered + << ", rewrites_attempted = " << ptr->rewrites_attempted + << ", rewrites_performed = " << ptr->rewrites_performed + << ", max_recursive_depth = " << ptr->max_recursive_depth + << ", num_recursive_rewrites = " << ptr->num_recursive_rewrites << ")"; + }); + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 22e7a0b74c40..7c4b0eab2224 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -39,6 +40,41 @@ namespace arith { using namespace tir; +/* \brief Usage counters for RewriteSimplifier + * + * These are intended for debug and testing purposes, to ensure that + * PrimExpr simplifications and TIR passes do not require an excessive + */ +struct RewriteSimplifierStatsNode : Object { + int64_t nodes_visited{0}; + int64_t constraints_entered{0}; + int64_t rewrites_attempted{0}; + int64_t rewrites_performed{0}; + int64_t max_recursive_depth{0}; + int64_t num_recursive_rewrites{0}; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("nodes_visited", &nodes_visited); + v->Visit("constraints_entered", &constraints_entered); + v->Visit("rewrites_attempted", &rewrites_attempted); + v->Visit("rewrites_performed", &rewrites_performed); + v->Visit("max_recursive_depth", &max_recursive_depth); + v->Visit("num_recursive_rewrites", &num_recursive_rewrites); + } + + static constexpr const char* _type_key = "arith.RewriteSimplifierStats"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object); +}; + +struct RewriteSimplifierStats : ObjectRef { + explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) { + data_ = make_object(data); + } + + TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode); +}; + /*! * \brief Rewrite-based simplifier. * @@ -50,6 +86,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} + PrimExpr VisitExpr(const PrimExpr& e) override; + void Update(const Var& var, const PrimExpr& info, bool override_info); PrimExpr VisitExpr_(const AddNode* op) override; PrimExpr VisitExpr_(const SubNode* op) override; @@ -87,9 +125,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { /*! \brief Return the currently enabled extensions */ Extension GetEnabledExtensions() const; + RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); } + + void ResetStatsCounters() { stats_ = {}; } + + void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; } + protected: + int64_t maximum_rewrite_steps_{0}; + RewriteSimplifierStatsNode stats_; + + void RecordAttemptedRewrite() { stats_.rewrites_attempted++; } + void RecordRewrite() { + stats_.rewrites_performed++; + + ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_) + << "RewriteSimplifier exceeded maximum number of rewrites allowed (" + << maximum_rewrite_steps_ << ")"; + } + // counter to record recursive rewrite depth. - int recur_depth_{0}; + int64_t recur_depth_{0}; // internal variable map std::unordered_map var_map_; @@ -103,7 +159,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool recursively_visiting_boolean_{false}; // maximum number of recursion allowed during a single pass. - static const constexpr int kMaxRecurDepth = 5; + static const constexpr int64_t kMaxRecurDepth = 5; /*! * \brief try to compare x against val. * \param x The expression to be evaluated. @@ -177,8 +233,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // we limit maximum depth of recursive rewrite allowed to // avoid infinite loop PrimExpr RecursiveRewrite(const PrimExpr& x) { + stats_.num_recursive_rewrites++; if (recur_depth_ >= kMaxRecurDepth) return x; ++recur_depth_; + stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth); PrimExpr res = this->VisitExpr(x); --recur_depth_; return res; diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index e2b935b19046..1065ad3bf1e0 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph return buffer_touch; } -ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) - : max_revisits_(max_revisits) { +ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps, + size_t max_revisits) + : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { ControlFlowGraphBuilder::Build(this, stmt); ForwardPropagateKnownValues(); BackwardPropagateUnusedValues(); @@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional flow_fr std::unordered_map visit_count_lookup; Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( arith::RewriteSimplifier::kTransitivelyProveInequalities | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | @@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ std::unordered_map visit_count_lookup; Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( arith::RewriteSimplifier::kTransitivelyProveInequalities | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index f2e46b2478a3..35934351dce0 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -399,7 +399,8 @@ class ControlFlowGraph { public: /* \brief Extract the touch pattern from a TIR statement */ - explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5); + explicit ControlFlowGraph(const Stmt& stmt, int64_t max_simplification_steps = 0, + size_t max_revisits = 5); /* \brief Check if a write is overwritten without impacting final results * @@ -655,6 +656,9 @@ class ControlFlowGraph { /*! \brief The maximum number of revisits while flowing constraints */ size_t max_revisits_; + + /*! \brief The maximum number of revisits while flowing constraints */ + int64_t max_simplification_steps_; }; } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 4179b00a3684..7951a2befa20 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -42,6 +42,7 @@ namespace tir { struct RemoveNoOpConfigNode : public tvm::AttrsNode { bool use_dataflow_analysis; + int64_t max_simplification_steps; TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") { TVM_ATTR_FIELD(use_dataflow_analysis) @@ -49,6 +50,12 @@ struct RemoveNoOpConfigNode : public tvm::AttrsNode { "If true, known buffer values are propagated and used " "to statically prove statements as no-ops.") .set_default(false); + TVM_ATTR_FIELD(max_simplification_steps) + .describe( + "If non-zero, RewriteSimplifier will throw an error " + "after the number of steps specified. " + "For use in debug and testing purposes.") + .set_default(0); } }; @@ -291,14 +298,19 @@ Pass RemoveNoOp() { RemoveNoOpConfig config = ctx->GetConfig("tir.RemoveNoOp") .value_or(AttrsWithDefaultValues()); + if (config->use_dataflow_analysis) { - touch_pattern.emplace(f->body); + touch_pattern.emplace(f->body, config->max_simplification_steps); } arith::Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); - auto* n = f.CopyOnWrite(); - n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr); + { + auto* write_ptr = f.CopyOnWrite(); + write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer, + std::move(touch_pattern), nullptr); + } return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 133ef01ed001..00452bb5bd0e 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -86,12 +86,14 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None: class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): use_dataflow_analysis = False + max_simplification_steps = 0 def transform(self): def inner(mod): config = { "tir.RemoveNoOp": { "use_dataflow_analysis": self.use_dataflow_analysis, + "max_simplification_steps": self.max_simplification_steps, } } with tvm.transform.PassContext(config=config): @@ -319,9 +321,16 @@ class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter) Similar to TestKeepPartiallyOverwrittenLoop, except the first loop has the same predicate as the second, and can therefore be removed. + + In the past, this test has had performance regressions in which + the runtime increased from a few seconds to nearly ten minutes. + The "max_simplification_steps" parameter is set at twice the + current number of steps required, in order to prevent similar + performance regression. """ use_dataflow_analysis = True + max_simplification_steps = 200000 def before(A: T.Buffer(16, "int32")): for i in T.serial(16): @@ -347,9 +356,16 @@ class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter): loop's predicate. So long as the regions written in the first loop are a subset of those written in the second loop, they can be removed. + + In the past, this test has had performance regressions in which + the runtime increased from a few seconds to nearly ten minutes. + The "max_simplification_steps" parameter is set at twice the + current number of steps required, in order to prevent similar + performance regression. """ use_dataflow_analysis = True + max_simplification_steps = 200000 def before(A: T.Buffer(16, "int32")): for i in T.serial(16):