From afa1f63ea54ac9aa99a467f5d1b7dd46ab1807a5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 May 2023 00:55:29 -0500 Subject: [PATCH] [TIR] Improved parameter name in DLTensor unpacking error messages (#14776) Previously, the parameter name depending only on the TIR variable name. In large IRModules, such as those used when executing end-to-end models, these parameter names may be resued across functions, and so the error message doesn't identify which `PrimFunc` should be investigated. This commit updates the parameter name to include the function name (e.g. `my_function.arg.my_param`) to help debugging in these cases. --- src/tir/transforms/make_packed_api.cc | 22 ++++++++++++------- .../aot/test_crt_forward_declarations.py | 4 ++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 2f5fa6572159..de1f17608273 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -194,12 +194,17 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - std::string param_name; - if (param->name_hint.defined() && (!param->name_hint.empty())) { - param_name = "arg." + param->name_hint; - } else { - param_name = "arg" + std::to_string(i); - } + std::string param_name = [&]() { + std::ostringstream oss; + oss << "arg"; + if (param->name_hint.defined() && (!param->name_hint.empty())) { + oss << "." << param->name_hint; + + } else { + oss << i; + } + return oss.str(); + }(); Var v_arg = Var(param_name, param->dtype); // Pluck the device API context out based on name @@ -252,11 +257,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { // to use the args that may have no let binding yet. Therefore, hoisting let // binding for args before buffer declaration is needed. for (const auto& kv : var_def) { - binder.Bind(kv.second, kv.first, kv.first->name_hint, true); + binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint, true); } for (const auto& kv : buffer_def) { - binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint); + binder.BindDLTensor(kv.second, device_type, device_id, kv.first, + name_hint + "." + kv.first->name_hint); } func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py index d1f725848e7e..17af7a5d682d 100644 --- a/tests/python/relay/aot/test_crt_forward_declarations.py +++ b/tests/python/relay/aot/test_crt_forward_declarations.py @@ -160,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner): lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0] main_source = lib_mod.get_source() - assert main_source.count("tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2 - assert main_source.count("tvmgen_default_fused_layout_transform") == 6 + assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1 + assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3 @tvm.testing.requires_corstone300