Skip to content

Commit

Permalink
[relay][simplify_expr]: Add pass to remove trivial transpose ops (apa…
Browse files Browse the repository at this point in the history
…che#14858)

[relay][simplify_expr]: Add pattern to remove trivial transpose ops
  • Loading branch information
f2013519 authored May 16, 2023
1 parent eb1ea97 commit f6bbe94
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 36 deletions.
111 changes: 75 additions & 36 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,38 @@ class SimplifyCastClip : public DFPatternRewrite {
DFPattern clip_, cast_;
};

/*!
* \brief Return the axis order for layout transform and transpose
* ops.
*/
static std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) {
std::vector<int> attr_axes;
if (auto attr = call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i].IntValue();
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
}
} else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(call->op)->name;
}
return std::move(attr_axes);
}

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -316,19 +348,7 @@ class SimplifyTranspose : public DFPatternRewrite {
it++;
}

// Check if the transpose is still required
bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
}

if (need_transpose) {
return MakeTranspose(x, axes);
}
return x;
return MakeTranspose(x, axes);
}

String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
Expand Down Expand Up @@ -431,32 +451,50 @@ class SimplifyTranspose : public DFPatternRewrite {
return Downcast<Call>(output_layout_trans);
}

std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
std::vector<int> attr_axes;
if (auto attr = call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i].IntValue();
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
private:
/*! \brief Pattern input */
DFPattern x_;
};

/*!
* \brief SimplifyNoOpTranspose matches the pattern of transpose or
* layout transform ops which do not change the layout or rank and
* removes the op.
*/
class SimplifyNoOpTranspose : public DFPatternRewrite {
public:
SimplifyNoOpTranspose() {
x_ = IsWildcard();
auto trans1 = IsOp("transpose") || IsOp("layout_transform");
pattern_ = trans1({x_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
Call trans_call = Downcast<Call>(post);

// Do not remove ops which change rank
if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
if (attr->src_layout != attr->dst_layout) {
return post;
}
} else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
}

int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
auto axes = GetTransposeAxisOrder(trans_call, ndim);

bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(call->op)->name;
}
return std::move(attr_axes);

if (!need_transpose) return x;

return post;
}

private:
Expand Down Expand Up @@ -1037,6 +1075,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
composer.AddRewrite<SimplifyNoOpTranspose>();
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
Expand Down
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,27 @@ def expected10():
y = relay.nn.relu(y)
return relay.Function([x], y)

def before11():
"""
Remove trivial no op transpose ops
Input:
op1 -> relay.transpose(x, axes=[0, 1, 2, 3]) -> op2
Simplified:
op1 -> op2
"""
x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
y = relay.transpose(x, axes=[0, 1, 2, 3])
y = relay.nn.relu(y)
y = relay.layout_transform(y, "NCHW", "NCHW")
return relay.Function([x], y)

def expected11():
x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
y = relay.nn.relu(x)
return relay.Function([x], y)

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
Expand All @@ -277,6 +298,7 @@ def expected10():
[before8(), expected8()],
[before9(), expected9()],
[before10(), expected10()],
[before11(), expected11()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
Expand Down

0 comments on commit f6bbe94

Please sign in to comment.