Skip to content

Commit

Permalink
[ARITH] Enhance IterMapSimplify for symbolic (apache#14547)
Browse files Browse the repository at this point in the history
This PR refactors and enhances DetectIterMap and IterMapSimplify
to enable symbolic shape simplification. Specifically, we add
a routine to combine multiple IterSplitExpr into one if they
come from the same source.

It is helpful to distinguish iterator from normal constants
in the simplification process. IterMapSimplify takes advantage
of these information.

This improvements is helpful to simplify the indices in flattened buffer
when there is symbolic shape involved and normal simplifier.

Also updated FlattenBuffer to take benefit of the enhanced simplifier.
Test cases are added.

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
tqchen and junrushao authored Apr 12, 2023
1 parent 1c52e63 commit 17f7db1
Show file tree
Hide file tree
Showing 25 changed files with 2,833 additions and 2,460 deletions.
3 changes: 2 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,13 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param check_level The iter mapping checking level.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with unit extents are simplified
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, IterMapLevel check_level,
bool simplify_trivial_iterators = true);
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import (
detect_iter_map,
iter_map_simplify,
normalize_iter_map_to_expr,
subspace_divide,
inverse_affine_iter_map,
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,49 @@ def detect_iter_map(
)


def iter_map_simplify(
indices,
input_iters,
predicate=True,
check_level=IterMapLevel.Surjective,
simplify_trivial_iterators=True,
):
"""Simplify the indices using iter map detection.
Parameters
----------
indices : List[PrimExpr]
The input indices
input_iters : Map[Var, Range]
The domain of each input iterators.
predicate : PrimExpr
The predicate constraints on the input iterators
check_level : Union[str, IterMapLevel]
Checking level of iteration mapping
simplify_trivial_iterators: bool
If true, iterators with extent of 1 will be replaced with a
constant value.
Returns
-------
results : IterMapResult
The iter map matching result.
The result's .indices is empty array if no match can be found.
"""
if isinstance(check_level, str):
check_level = IterMapLevel.from_str(check_level)
elif check_level is None:
check_level = IterMapLevel.NoCheck
return _ffi_api.IterMapSimplify(
indices, input_iters, predicate, check_level, simplify_trivial_iterators
)


def normalize_iter_map_to_expr(expr):
"""Given an IterMapExpr, transform it to normal PrimExpr
Expand Down
10 changes: 8 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "const_fold.h"
#include "pattern_match.h"
#include "product_normal_form.h"
#include "rewrite_simplify.h"

namespace tvm {
Expand Down Expand Up @@ -808,12 +809,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
}

// normal path.
// this only happens when b is symbolic
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {

PrimExpr ret = MulAndNormalize(a, b);
const MulNode* mul = ret.as<MulNode>();

if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return Mul(a, b);
return ret;
}
}

Expand Down
14 changes: 10 additions & 4 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ namespace arith {
using namespace tir;

Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
analyzer_->Bind(op->loop_var, dom);
iter_vars_.Set(op->loop_var, dom);
return StmtExprMutator::VisitStmt_(op);
}

Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
for (const auto& iter_var : op->iter_vars) {
analyzer_->Bind(iter_var->var, iter_var->dom);
iter_vars_.Set(iter_var->var, iter_var->dom);
}
return StmtExprMutator::VisitStmt_(op);
}
Expand Down Expand Up @@ -75,7 +79,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
Optional<Stmt> else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
WithRecordIterPredicate(real_condition, [&] { then_case = this->VisitStmt(op->then_case); });
}
if (op->else_case) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
Expand All @@ -102,7 +106,9 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value);
analyzer_->Bind(iv->var, dom);
iter_vars_.Set(iv->var, dom);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt;
} else {
Expand Down Expand Up @@ -135,7 +141,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
PrimExpr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = this->VisitExpr(op->args[1]);
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
Expand Down
29 changes: 28 additions & 1 deletion src/arith/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <utility>
Expand Down Expand Up @@ -63,8 +64,34 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
protected:
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
// note however that iter map analysis are usually more
// expensive and we only encourage doing them during
// necessary cases like layout remapping
/*! \brief Recorded loop iterators */
Map<Var, Range> iter_vars_;
/*! \brief iterator predicates */
Array<PrimExpr> iter_predicates_;
/*!
* \brief Run callback while trying to record iter predicate
* \param conditon Condition to be checked.
* \param callback The callback to be called.
*/
template <typename FLambda>
void WithRecordIterPredicate(PrimExpr condition, FLambda callback) {
auto f_use_itervar = [this](const tir::VarNode* v) {
return iter_vars_.count(GetRef<tir::Var>(v));
};
// simple heuristics for detecting predicate
if (tir::UsesVar(condition, f_use_itervar)) {
iter_predicates_.push_back(condition);
callback();
iter_predicates_.pop_back();
} else {
callback();
}
}
};

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
Loading

0 comments on commit 17f7db1

Please sign in to comment.