Skip to content

Commit

Permalink
Fix handling of nodes that get assigned to kMSInternalNHWCDomain when loading an ORT format model. (microsoft#20379)
Browse files Browse the repository at this point in the history

Fix handling of nodes that get assigned to kMSInternalNHWCDomain when loading an ORT format model. The ORT format model doesn't contain information about kMSInternalNHWCDomain since it is set during layout transformation. Fall back to known domains instead.
  • Loading branch information
edgchen1 authored Apr 23, 2024
1 parent c7de4de commit 3270a00
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 36 deletions.
28 changes: 27 additions & 1 deletion onnxruntime/core/framework/kernel_type_str_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,36 @@ namespace fb = flatbuffers;

namespace onnxruntime {

// Looks up `op_id` in `map`.
// If `op_id` uses kMSInternalNHWCDomain and has no exact entry, falls back to
// looking up the same op type/version in the domains that layout transformation
// may have replaced (kOnnxDomain, kMSDomain). kMSInternalNHWCDomain is assigned
// by ORT during layout transformation, so ORT format models only contain kernel
// type string information keyed by the original domain.
// Returns map.end() if no matching entry is found.
static OpKernelTypeStrMap::const_iterator LookUpOpId(const OpIdentifier& op_id,
                                                     const OpKernelTypeStrMap& map) {
  if (const auto exact_it = map.find(op_id); exact_it != map.end()) {
    return exact_it;
  }

  if (op_id.domain == kMSInternalNHWCDomain) {
    // Try each domain that kMSInternalNHWCDomain may map back to.
    for (const auto fallback_domain : {std::string_view{kOnnxDomain}, std::string_view{kMSDomain}}) {
      const auto fallback_op_id = OpIdentifier{std::string{fallback_domain}, op_id.op_type, op_id.since_version};
      if (const auto fallback_it = map.find(fallback_op_id); fallback_it != map.end()) {
        return fallback_it;
      }
    }
  }

  return map.end();
}

Status KernelTypeStrResolver::ResolveKernelTypeStr(const Node& node, std::string_view kernel_type_str,
gsl::span<const ArgTypeAndIndex>& resolved_args) const {
const auto op_id = utils::MakeOpId(node);
const auto op_it = op_kernel_type_str_map_.find(op_id);
const auto op_it = LookUpOpId(op_id, op_kernel_type_str_map_);
ORT_RETURN_IF(op_it == op_kernel_type_str_map_.end(), "Failed to find op_id: ", op_id);
const auto& type_str_map = op_it->second;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,49 +201,59 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
}

// Verifies that static-kernel Conv nodes are converted to the internal NHWC domain,
// for both an ONNX model and an ORT format model (the latter exercises the
// kMSInternalNHWCDomain fallback in kernel type string resolution).
TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
  auto run_test = [&](const ORTCHAR_T* model_path) {
    SCOPED_TRACE("model path: " + ToUTF8String(model_path));

    SessionOptions so;
    // set this if you want to manually inspect the optimized model
    // so.optimized_model_filepath = ORT_MODEL_FOLDER "squeezenet/model.test_output.onnx";
    InferenceSessionWrapper session(so, GetEnvironment());

    const std::unordered_set<std::string> supported_ops{"Conv", "Clip"};
    auto ep = std::make_unique<InternalTestingExecutionProvider>(supported_ops,
                                                                 std::unordered_set<std::string>{},
                                                                 DataLayout::NHWC);
    ep->EnableStaticKernels();
    ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(ep)));

    ASSERT_STATUS_OK(session.Load(model_path));
    ASSERT_STATUS_OK(session.Initialize());

    const auto& graph = session.GetGraph();

    // all Conv nodes should have been converted to NHWC versions
    for (const auto& node : graph.Nodes()) {
      if (node.OpType() == "Conv") {
        ASSERT_EQ(node.Domain(), kMSInternalNHWCDomain);
      }
    }

    TensorShape input_shape_x{1, 3, 224, 224};
    std::vector<float> input_x(input_shape_x.Size(), 1.f);
    OrtValue ml_value_x;
    CreateMLValue<float>(input_shape_x.GetDims(), input_x.data(), OrtMemoryInfo(), &ml_value_x);

    NameMLValMap feeds;
    feeds.insert(std::make_pair("data_0", ml_value_x));

    // prepare outputs
    std::vector<std::string> output_names;
    output_names.push_back("softmaxout_1");
    std::vector<OrtValue> fetches;

    // the test EP's static NHWC kernels are stubs, so Run is expected to fail
    // with this specific placeholder message once the Conv kernel is reached
    ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
                                        "Non-zero status code returned while running Conv node. Name:'Conv' "
                                        "Status Message: TODO: add NHWC implementation here.");
  };

  // the internal NHWC domain supports opset 11 and later
  const ORTCHAR_T* onnx_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx";
  run_test(onnx_model_path);

  // Note: Using ORT format model with runtime optimizations so that the Conv nodes are preserved in the graph,
  // not converted into FusedConv nodes. The InternalTestingExecutionProvider handles Conv nodes.
  const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.with_runtime_opt.ort";
  run_test(ort_model_path);
}

// This test can be deprecated now as the code logic has been changed so the model is not applicable
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
model_opset11.with_runtime_opt.ort was generated by running the following command:
python -m onnxruntime.tools.convert_onnx_models_to_ort ./model_opset11.onnx --optimization_style Runtime

0 comments on commit 3270a00

Please sign in to comment.