[OpenCLML] Transposed convolution support and other fixes (apache#14767)
* [OpenCLML] Transposed convolution support and other fixes

Added support for transposed convolution.
Added epsilon support for the batch_norm op (available from CLML v3.0).
Fixed the CLML version query bug.

* Addressed review comments and fixed various compilation warnings.

* Fixed lint issues.
srkreddy1238 authored May 10, 2023
1 parent 3829ebb commit cca7d78
Showing 8 changed files with 271 additions and 104 deletions.
apps/cpp_rtvm/tvm_runner.h (3 changes: 1 addition & 2 deletions)
@@ -38,8 +38,7 @@ namespace runtime {
 /*!
  * \brief various meta information related to the compiled TVM model.
  */
-typedef struct {
- public:
+typedef struct _TVMMetaInfo {
   int n_inputs;
   int n_outputs;
   std::map<std::string, std::pair<std::vector<int>, std::string>> input_info;
cmake/modules/LibInfo.cmake (1 change: 1 addition & 0 deletions)
@@ -120,6 +120,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_VULKAN="${USE_VULKAN}"
     TVM_INFO_USE_CLML="${USE_CLML}"
     TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}"
+    TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
     TVM_INFO_USE_UMA="${USE_UMA}"
     TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
python/tvm/relay/op/contrib/clml.py (45 changes: 44 additions & 1 deletion)
@@ -37,7 +37,7 @@
 def clml_sdk_version():
     """Utility function to get clml version version"""
 
-    return tvm.support.libinfo().get("TVM_CLML_VERSION", 2)
+    return int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2))
 
 
 def is_clml_runtime_enabled():
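Note: a hedged sketch (not part of this commit) of why the int() cast matters. tvm.support.libinfo() returns its values as strings, while the fallback default here is the integer 2, so a numeric gate on the SDK version, for example to enable the v3.0 batch_norm epsilon attribute, could otherwise end up comparing a str against an int.

from tvm.relay.op.contrib.clml import clml_sdk_version

# clml_sdk_version() now always returns an int (defaulting to 2), so a plain
# numeric comparison is safe whether the value came from libinfo() (a string,
# populated from CLML_VERSION_MAJOR) or from the fallback default.
if clml_sdk_version() >= 3:
    print("CLML v3.0+ detected: v3-only features (e.g. batch_norm epsilon) can be used")
else:
    print("CLML v2.x detected: falling back to the older feature set")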
@@ -155,6 +155,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     seq = tvm.transform.Sequential(
         [
             transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
+            transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "OIHW"]}),
             transform.AlterOpLayout(),
             transform.FoldConstant(),
         ]
@@ -203,6 +204,22 @@ def conv_pattern():
         pattern = pattern.optional(is_op("clip"))
         return pattern
 
+    def conv_transpose_pattern():
+        """Create a transposed convolution pattern."""
+        pattern = is_op("nn.conv2d_transpose")(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        pattern = pattern.optional(
+            lambda x: is_tuple_get_item(
+                is_op("nn.batch_norm")(
+                    x, is_constant(), is_constant(), is_constant(), is_constant()
+                )
+            )
+        )
+        pattern = pattern.optional(is_op("nn.relu"))
+        pattern = pattern.optional(is_op("clip"))
+        return pattern
+
     def pad_conv_pattern():
         """Create a pad with convolution pattern."""
         pattern = is_op("nn.pad")(wildcard(), is_constant())
@@ -300,6 +317,31 @@ def check_conv(extract):
             return False
         return True
 
+    def check_conv_transpose(extract):
+        """Check transposed conv pattern is supported by CLML."""
+        call = extract
+        if isinstance(call, tvm.relay.expr.TupleGetItem):
+            call = call.tuple_value
+        elif call.op.name == "nn.relu":
+            call = call.args[0]
+            if isinstance(call, tvm.relay.expr.TupleGetItem):
+                call = call.tuple_value
+        elif call.op.name == "clip":
+            if call.attrs["a_min"] != 0.0 or call.attrs["a_max"] != 6.0:
+                return False
+            call = call.args[0]
+            if isinstance(call, tvm.relay.expr.TupleGetItem):
+                call = call.tuple_value
+
+        while call.op.name != "nn.conv2d_transpose":
+            call = call.args[0]
+
+        attrs = call.attrs
+        if attrs.data_layout != "NCHW":
+            return False
+
+        return True
+
     def check_binary_op(extract):
         call = extract
         if len(call.args[1].checked_type.shape) > 0:
@@ -340,6 +382,7 @@ def check_default_op(extract):
     return [
         ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
+        ("clml.conv2d_transpose", conv_transpose_pattern(), check_conv_transpose),
         ("clml.dense", dense_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
         ("clml.concat", concat_pattern(), check_concat_op),
src/relay/backend/contrib/clml/codegen.cc (40 changes: 24 additions & 16 deletions)
@@ -83,7 +83,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "CLML JSON runtime only supports composite functions.";
     const std::string name = comp.value();
     std::shared_ptr<JSONGraphNode> json_node;
-    if (name == "clml.conv2d" || name == "clml.pad_conv2d") {
+    if (name == "clml.conv2d" || name == "clml.pad_conv2d" || name == "clml.conv2d_transpose") {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "clml.batch_norm") {
       json_node = CreateBatchNormJSONNode(cn);
@@ -169,7 +169,10 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
       current_call = current_call->args[0].as<CallNode>();
     }
     // Enforce a convolution node exists at this point during traversal
-    ICHECK(backend::IsOp(current_call, "nn.conv2d"));
+    if (!backend::IsOp(current_call, "nn.conv2d") &&
+        !backend::IsOp(current_call, "nn.conv2d_transpose")) {
+      LOG(FATAL) << "Can't find primary op in Convolution node";
+    }
     nodes.conv = current_call;
     if (!current_call->args.empty() && current_call->args[0]->IsInstance<CallNode>()) {
       current_call = current_call->args[0].as<CallNode>();
@@ -189,22 +192,27 @@
   std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
     CompositeConvNode nodes = UnpackCompositeConvolution(cn);
 
-    const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
-    ICHECK(conv_attr);
-
     std::string name;
     std::string name_prefix = "nn";
 
-    // Distinguish between normal and depth-wise convolution
-    if (conv_attr->channels.defined() &&
-        tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
-        conv_attr->groups != 1) {
-      name = "depthwise_conv2d";
-      ICHECK(conv_attr->kernel_layout == "IOHW")
-          << "Kernel layout must be IHWO, has the module been pre-processed correctly?";
-    } else {
-      name = "conv2d";
-      ICHECK(conv_attr->kernel_layout == "OIHW")
+    if (backend::IsOp(nodes.conv, "nn.conv2d")) {
+      const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
+      ICHECK(conv_attr);
+      if (conv_attr->channels.defined() &&
+          tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
+          conv_attr->groups != 1) {
+        name = "depthwise_conv2d";
+        ICHECK(conv_attr->kernel_layout == "IOHW")
+            << "Kernel layout must be IHWO, has the module been pre-processed correctly?";
+      } else {
+        name = "conv2d";
+        ICHECK(conv_attr->kernel_layout == "OIHW")
+            << "Kernel layout must be OHWI, has the module been pre-processed correctly?";
+      }
+    } else if (backend::IsOp(nodes.conv, "nn.conv2d_transpose")) {
+      name = "conv2d_transpose";
+      const auto* conv_transpose_attr = nodes.conv->attrs.as<Conv2DTransposeAttrs>();
+      ICHECK(conv_transpose_attr);
+      ICHECK(conv_transpose_attr->kernel_layout == "OIHW")
+          << "Kernel layout must be OHWI, has the module been pre-processed correctly?";
     }

(diffs for the remaining changed files were not loaded)
