diff --git a/apps/cpp_rtvm/tvm_runner.h b/apps/cpp_rtvm/tvm_runner.h
index 926e009c4c2e..47717c3ecfd2 100644
--- a/apps/cpp_rtvm/tvm_runner.h
+++ b/apps/cpp_rtvm/tvm_runner.h
@@ -38,8 +38,7 @@ namespace runtime {
 /*!
  * \brief various meta information related to the compiled TVM model.
  */
-typedef struct {
-  public:
+typedef struct _TVMMetaInfo {
   int n_inputs;
   int n_outputs;
   std::map<std::string, std::pair<std::vector<int64_t>, std::string>> input_info;
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 116c6e4fc72e..35e246cd8f16 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -120,6 +120,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_VULKAN="${USE_VULKAN}"
     TVM_INFO_USE_CLML="${USE_CLML}"
     TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}"
+    TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
     TVM_INFO_USE_UMA="${USE_UMA}"
     TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 7ddb77ce75b7..608c8a2a1b73 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -37,7 +37,7 @@
 
 def clml_sdk_version():
     """Utility function to get clml version"""
-    return tvm.support.libinfo().get("TVM_CLML_VERSION", 2)
+    return int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2))
 
 
 def is_clml_runtime_enabled():
@@ -155,6 +155,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
         seq = tvm.transform.Sequential(
             [
                 transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
+                transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "OIHW"]}),
                 transform.AlterOpLayout(),
                 transform.FoldConstant(),
             ]
@@ -203,6 +204,22 @@ def conv_pattern():
         pattern = pattern.optional(is_op("clip"))
         return pattern
 
+    def conv_transpose_pattern():
+        """Create a transposed convolution pattern."""
+        pattern = is_op("nn.conv2d_transpose")(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        pattern = pattern.optional(
+            lambda x: is_tuple_get_item(
+                is_op("nn.batch_norm")(
+                    x, is_constant(), is_constant(), is_constant(), is_constant()
+                )
+            )
+        )
+        pattern = pattern.optional(is_op("nn.relu"))
+        pattern = pattern.optional(is_op("clip"))
+        return pattern
+
     def pad_conv_pattern():
         """Create a pad with convolution pattern."""
         pattern = is_op("nn.pad")(wildcard(), is_constant())
@@ -300,6 +317,31 @@ def check_conv(extract):
             return False
         return True
 
+    def check_conv_transpose(extract):
+        """Check transposed conv pattern is supported by CLML."""
+        call = extract
+        if isinstance(call, tvm.relay.expr.TupleGetItem):
+            call = call.tuple_value
+        elif call.op.name == "nn.relu":
+            call = call.args[0]
+            if isinstance(call, tvm.relay.expr.TupleGetItem):
+                call = call.tuple_value
+        elif call.op.name == "clip":
+            if call.attrs["a_min"] != 0.0 or call.attrs["a_max"] != 6.0:
+                return False
+            call = call.args[0]
+            if isinstance(call, tvm.relay.expr.TupleGetItem):
+                call = call.tuple_value
+
+        while call.op.name != "nn.conv2d_transpose":
+            call = call.args[0]
+
+        attrs = call.attrs
+        if attrs.data_layout != "NCHW":
+            return False
+
+        return True
+
     def check_binary_op(extract):
         call = extract
         if len(call.args[1].checked_type.shape) > 0:
@@ -340,6 +382,7 @@ def check_default_op(extract):
     return [
         ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
+        ("clml.conv2d_transpose", conv_transpose_pattern(), check_conv_transpose),
        ("clml.dense", dense_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
         ("clml.concat", concat_pattern(), check_concat_op),
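Note on the pattern table above: `conv_transpose_pattern` mirrors `conv_pattern` — an `nn.conv2d_transpose` with constant weights, followed by optional bias/add, batch_norm (consumed through a tuple_get_item), and a trailing relu or clip. A minimal sketch of how such a pattern matches with TVM's dataflow pattern language; the shapes and the trimmed two-optional chain are illustrative, not part of this patch:

```python
import numpy as np
from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

# A small conv2d_transpose + bias_add + relu chain to match against.
x = relay.var("x", shape=(1, 64, 32, 32), dtype="float32")
w = relay.const(np.zeros((64, 16, 4, 4), dtype="float32"))
b = relay.const(np.zeros((16,), dtype="float32"))
y = relay.nn.conv2d_transpose(x, w, channels=16, kernel_size=(4, 4))
y = relay.nn.relu(relay.nn.bias_add(y, b))

# Same structure as conv_transpose_pattern(), trimmed to two optional ops.
pat = is_op("nn.conv2d_transpose")(wildcard(), is_constant())
pat = pat.optional(lambda p: is_op("nn.bias_add")(p, is_constant()))
pat = pat.optional(is_op("nn.relu"))
assert pat.match(y)  # also matches when bias_add/relu are absent
```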
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index d8ca791ad8c4..069e11dac5ff 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -83,7 +83,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "CLML JSON runtime only supports composite functions.";
     const std::string name = comp.value();
     std::shared_ptr<JSONGraphNode> json_node;
-    if (name == "clml.conv2d" || name == "clml.pad_conv2d") {
+    if (name == "clml.conv2d" || name == "clml.pad_conv2d" || name == "clml.conv2d_transpose") {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "clml.batch_norm") {
       json_node = CreateBatchNormJSONNode(cn);
@@ -169,7 +169,10 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
       current_call = current_call->args[0].as<CallNode>();
     }
     // Enforce a convolution node exists at this point during traversal
-    ICHECK(backend::IsOp(current_call, "nn.conv2d"));
+    if (!backend::IsOp(current_call, "nn.conv2d") &&
+        !backend::IsOp(current_call, "nn.conv2d_transpose")) {
+      LOG(FATAL) << "Can't find primary op in Convolution node";
+    }
     nodes.conv = current_call;
     if (!current_call->args.empty() && current_call->args[0]->IsInstance<CallNode>()) {
       current_call = current_call->args[0].as<CallNode>();
@@ -189,22 +192,27 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
   std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
     CompositeConvNode nodes = UnpackCompositeConvolution(cn);
 
-    const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
-    ICHECK(conv_attr);
-
     std::string name;
     std::string name_prefix = "nn";
-
-    // Distinguish between normal and depth-wise convolution
-    if (conv_attr->channels.defined() &&
-        tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
-        conv_attr->groups != 1) {
-      name = "depthwise_conv2d";
-      ICHECK(conv_attr->kernel_layout == "IOHW")
-          << "Kernel layout must be IOHW, has the module been pre-processed correctly?";
-    } else {
-      name = "conv2d";
-      ICHECK(conv_attr->kernel_layout == "OIHW")
+    if (backend::IsOp(nodes.conv, "nn.conv2d")) {
+      const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
+      ICHECK(conv_attr);
+      // Distinguish between normal and depth-wise convolution
+      if (conv_attr->channels.defined() &&
+          tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
+          conv_attr->groups != 1) {
+        name = "depthwise_conv2d";
+        ICHECK(conv_attr->kernel_layout == "IOHW")
+            << "Kernel layout must be IOHW, has the module been pre-processed correctly?";
+      } else {
+        name = "conv2d";
+        ICHECK(conv_attr->kernel_layout == "OIHW")
+            << "Kernel layout must be OIHW, has the module been pre-processed correctly?";
+      }
+    } else if (backend::IsOp(nodes.conv, "nn.conv2d_transpose")) {
+      name = "conv2d_transpose";
+      const auto* conv_transpose_attr = nodes.conv->attrs.as<Conv2DTransposeAttrs>();
+      ICHECK(conv_transpose_attr);
+      ICHECK(conv_transpose_attr->kernel_layout == "OIHW")
           << "Kernel layout must be OIHW, has the module been pre-processed correctly?";
     }
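All three composite names now funnel into `CreateCompositeConvJSONNode`, which keys its behavior off the primary op located by `UnpackCompositeConvolution`. A rough way to confirm from Python which composites were formed after partitioning — the helper name is mine, and the `keys()`/`in` access assumes the usual `DictAttrs` accessors:

```python
from tvm import relay

def composite_names(mod):
    """Collect the 'Composite' attribute of every function in the module."""
    found = []

    def visit(expr):
        if isinstance(expr, relay.Function) and expr.attrs is not None:
            if "Composite" in expr.attrs.keys():
                found.append(str(expr.attrs["Composite"]))

    for gvar in mod.get_global_vars():
        relay.analysis.post_order_visit(mod[gvar], visit)
    return found

# After mod = clml.partition_for_clml(mod, params):
#     assert "clml.conv2d_transpose" in composite_names(mod)
```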
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index 992665ee697c..7c716e68763b 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -156,12 +156,12 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_int majorVersions[MAX_VERSIONS];
     cl_int minorVersions[MAX_VERSIONS];
     cl_uint numVersions = 0;
-    result = clQueryMLInterfaceVersionsQCOM(NULL, NULL, 0, &numVersions);
+    result = clQueryMLInterfaceVersionsQCOM(nullptr, nullptr, 0, &numVersions);
     ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result;
     ICHECK(numVersions > 0u);
     ICHECK(numVersions <= MAX_VERSIONS);
 
-    result = clQueryMLInterfaceVersionsQCOM(majorVersions, minorVersions, numVersions, NULL);
+    result = clQueryMLInterfaceVersionsQCOM(majorVersions, minorVersions, numVersions, nullptr);
     ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result;
 
     for (cl_uint i = 0; i < numVersions; ++i) {
@@ -171,7 +171,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         break;
       }
     }
-    ICHECK(h_ClmlIntf != NULL)
+    ICHECK(h_ClmlIntf != nullptr)
         << "clGetMLInterfaceVxQCOM:" << result
         << " Perhaps there is a mismatch between the CLML SDK version and the target supported version:"
        << majorVersions[numVersions - 1];
@@ -239,24 +239,24 @@ class CLMLRuntime : public JSONRuntimeBase {
   void CopyDataToCLMLTensor(std::shared_ptr<cl_ml_tensor_memory_desc_qcom> tensor, void* data,
                             cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM) {
     cl_int result = 0;
-    cl_event evt = NULL;
+    cl_event evt = nullptr;
     result = h_ClmlIntf->clEnqueueWriteMLTensorDataQCOM(workspace->GetQueue(tentry->device), data,
                                                         layout, tensor->tensor, tensor->memory,
-                                                        0,      // n waitlist
-                                                        NULL,   // waitlist
-                                                        &evt);  // event
-    ICHECK((evt != NULL) && result == CL_SUCCESS) << "clEnqueueWriteMLTensorDataQCOM:" << result;
+                                                        0,        // n waitlist
+                                                        nullptr,  // waitlist
+                                                        &evt);    // event
+    ICHECK((evt != nullptr) && result == CL_SUCCESS) << "clEnqueueWriteMLTensorDataQCOM:" << result;
   }
 
   void CopyDataFromCLMLTensor(std::shared_ptr<cl_ml_tensor_memory_desc_qcom> tensor, void* data,
                               cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM) {
     cl_int result = 0;
-    cl_event readEvent = NULL;
+    cl_event readEvent = nullptr;
     // Read the output tensor
     result = h_ClmlIntf->clEnqueueReadMLTensorDataQCOM(workspace->GetQueue(tentry->device),
                                                        tensor->tensor, tensor->memory, data, layout,
                                                        0,            // n waitlist
-                                                       NULL,         // waitlist
+                                                       nullptr,      // waitlist
                                                        &readEvent);  // event
     ICHECK(result == CL_SUCCESS) << "clEnqueueReadMLTensorDataQCOM:" << result;
@@ -289,7 +289,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       } else if (kDLOpenCL == data_entry_[eid]->device.device_type) {
         layer_.in_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
-        cl_event cpy_evt = NULL;
+        cl_event cpy_evt = nullptr;
         cl_event* evt = &cpy_evt;
         if (workspace->IsProfiling(tentry->device)) {
           evts.resize(evts.size() + 1);
@@ -297,7 +297,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         }
         result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM(
             queue, layer_.in_placeholder[i]->tensor, layer_.in_placeholder[i]->memory,
-            layer_.inputs[i]->tensor, layer_.inputs[i]->memory, 0, NULL, evt);
+            layer_.inputs[i]->tensor, layer_.inputs[i]->memory, 0, nullptr, evt);
         ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result;
       } else {
         DLDataType tvm_dtype = const_cast<DLTensor*>(data_entry_[eid])->dtype;
@@ -326,14 +326,14 @@ class CLMLRuntime : public JSONRuntimeBase {
         cl_event* evt = &(evts.back());
         result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
-                                               this->layer_.descriptorSet, 0, NULL, evt);
+                                               this->layer_.descriptorSet, 0, nullptr, evt);
         t->Stop();
         duration += t->SyncAndGetElapsedNanos();
         LOG(WARNING) << "Layer:" << this->layer_.layer_names[i]
                      << " Duration:" << t->SyncAndGetElapsedNanos();
       } else {
         result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
-                                               this->layer_.descriptorSet, 0, NULL, NULL);
+                                               this->layer_.descriptorSet, 0, nullptr, nullptr);
       }
       ICHECK(result == CL_SUCCESS) << "clEnqueueMLOpQCOM:" << result;
     }
@@ -354,7 +354,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       } else if (kDLOpenCL == data_entry_[eid]->device.device_type) {
         layer_.out_placeholder[i]->memory = static_cast<cl_mem>(
             ((cl::BufferDescriptor*)const_cast<DLTensor*>(data_entry_[eid])->data)->buffer);
-        cl_event cpy_evt = NULL;
+        cl_event cpy_evt = nullptr;
         cl_event* evt = &cpy_evt;
         if (workspace->IsProfiling(tentry->device)) {
           evts.resize(evts.size() + 1);
@@ -362,7 +362,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         }
         result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM(
             queue, layer_.outputs[i]->tensor, layer_.outputs[i]->memory,
-            layer_.out_placeholder[i]->tensor, layer_.out_placeholder[i]->memory, 0, NULL, evt);
+            layer_.out_placeholder[i]->tensor, layer_.out_placeholder[i]->memory, 0, nullptr, evt);
         ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result;
       } else {
         DLDataType tvm_dtype = const_cast<DLTensor*>(data_entry_[eid])->dtype;
@@ -408,6 +408,10 @@ class CLMLRuntime : public JSONRuntimeBase {
         auto out = CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_DEPTHWISE_QCOM);
         this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
         this->layer_.func_outs.push_back(out);
+      } else if ("nn.conv2d_transpose" == op_name) {
+        auto out = CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_TRANSPOSE_QCOM);
+        this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
+        this->layer_.func_outs.push_back(out);
       } else if ("nn.relu6" == op_name) {
         auto out = CreateReLULayer(&layer_, node, CL_ACTIVATION_RELU6);
         this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
@@ -522,14 +526,14 @@ class CLMLRuntime : public JSONRuntimeBase {
       LOG(WARNING) << "CLML Tuning:" << this->layer_.layer_names[i];
       result = h_ClmlIntf->clTuneMLOpQCOM(workspace->GetQueue(tentry->device),
                                           this->layer_.function[i], this->layer_.descriptorSet,
-                                          this->tuning_cache, NULL);
+                                          this->tuning_cache, nullptr);
       ICHECK(result == CL_SUCCESS) << "clTuneMLOpQCOM:" << result;
     }
     cl::OpenCLWorkspace::Global()->EnableQueueProfiling(tentry->device, false);
 
     size_t cache_len_bytes = 0;
     size_t len_ret = 0;
-    result = h_ClmlIntf->clSaveMLTuningCacheQCOM(tuning_cache, 0, NULL, &cache_len_bytes);
+    result = h_ClmlIntf->clSaveMLTuningCacheQCOM(tuning_cache, 0, nullptr, &cache_len_bytes);
     ICHECK(result == CL_SUCCESS) << "clSaveMLTuningCacheQCOM:" << result;
 
     std::vector<char> saved_cache(cache_len_bytes, 0);
@@ -574,7 +578,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     std::vector<cl_ml_tensor_memory_desc_qcom> out_tensorMemDescs;
     cl_ml_tensor_mem_desc_set_qcom descriptorSet;
     std::vector<std::string> layer_names;
-    cl_ml_tensor_qcom unusedTensor = NULL;
+    cl_ml_tensor_qcom unusedTensor = nullptr;
   };
 
   struct tensor_dims_t {
@@ -586,11 +590,11 @@ class CLMLRuntime : public JSONRuntimeBase {
     size_t reqd_size = 0;
     cl_device_id device_id =
         workspace->GetCLDeviceID(workspace->GetThreadEntry()->device.device_id);
-    result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, NULL, &reqd_size);
+    result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, nullptr, &reqd_size);
     ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result;
 
     std::vector<char> buf(reqd_size);
-    result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, reqd_size, buf.data(), NULL);
+    result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, reqd_size, buf.data(), nullptr);
     ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo:" << result;
 
     std::string extensions(buf.data());
@@ -608,7 +612,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_tensor_desc_qcom desc = {
         dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }};
     result =
-        h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc, &tensor);
+        h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, &tensor);
     ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
     (void)result;
     return tensor;
@@ -618,14 +622,14 @@ class CLMLRuntime : public JSONRuntimeBase {
                                  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> pTensorMemDesc) {
     uint32_t size = 0;
     cl_int result = CL_OUT_OF_HOST_MEMORY;
-    cl_mem buffer = NULL;
+    cl_mem buffer = nullptr;
 
     result = h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->contexts[platform_id],
                                                      pTensorMemDesc->tensor, &size);
     ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result;
 
     buffer =
-        clCreateBuffer(workspace->contexts[platform_id], CL_MEM_READ_WRITE, size, NULL, &result);
+        clCreateBuffer(workspace->contexts[platform_id], CL_MEM_READ_WRITE, size, nullptr, &result);
     ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result;
 
     pTensorMemDesc->memory = buffer;
@@ -803,7 +807,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     } else {
       cl_ml_tensor_desc_qcom desc = {};
       desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-      result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc,
+      result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc,
                                                 &layer_.unusedTensor);
       ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
       bias->tensor = layer_.unusedTensor;
@@ -820,17 +824,17 @@ class CLMLRuntime : public JSONRuntimeBase {
                                              0,
                                              cl_arithmetic_mode};
 
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     if (!has_bn) {
       if (!has_act) {
         result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM(
-            workspace->contexts[platform_id], 0, &conv_desc, input->tensor, weight->tensor,
-            bias->tensor, output->tensor, &op, NULL);
+            workspace->contexts[platform_id], nullptr, &conv_desc, input->tensor, weight->tensor,
+            bias->tensor, output->tensor, &op, nullptr);
         ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
       } else {
         result = h_ClmlIntf->clCreateMLOpFusedConvolutionActivationForwardQCOM(
-            workspace->contexts[platform_id], 0, &conv_desc, &act_desc, input->tensor,
-            weight->tensor, bias->tensor, NULL, output->tensor, &op, tuning_cache);
+            workspace->contexts[platform_id], nullptr, &conv_desc, &act_desc, input->tensor,
+            weight->tensor, bias->tensor, nullptr, output->tensor, &op, tuning_cache);
         ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
       }
       layer_.func_ins.push_back(input);
@@ -857,15 +861,15 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode};
     if (!has_act) {
       result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormForwardQCOM(
-          workspace->contexts[platform_id], 0, &conv_desc, &bn_desc, input->tensor,
+          workspace->contexts[platform_id], nullptr, &conv_desc, &bn_desc, input->tensor,
           weight->tensor, bias->tensor, output->tensor, bn_mean->tensor, bn_var->tensor,
           bn_scale->tensor, bn_bias->tensor, &op, tuning_cache);
       ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
     } else {
       result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM(
-          workspace->contexts[platform_id], 0, &conv_desc, &bn_desc, &act_desc, input->tensor,
-          weight->tensor, bias->tensor, output->tensor, NULL, bn_mean->tensor, bn_var->tensor,
-          bn_scale->tensor, bn_bias->tensor, &op, tuning_cache);
+          workspace->contexts[platform_id], nullptr, &conv_desc, &bn_desc, &act_desc,
+          input->tensor, weight->tensor, bias->tensor, output->tensor, nullptr, bn_mean->tensor,
+          bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op, tuning_cache);
       ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
     }
@@ -885,7 +889,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       CachedLayer* layer, const JSONGraphNode& node,
       cl_activation_function_qcom clml_act_type = CL_ACTIVATION_RELU) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -898,12 +902,12 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_ml_tensor_desc_qcom desc = {};
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc,
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc,
                                               &layer_.unusedTensor);
     ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result;
 
     result = h_ClmlIntf->clCreateMLOpActivationForwardQCOM(
-        workspace->contexts[platform_id], 0, &act_desc, input->tensor, layer_.unusedTensor,
+        workspace->contexts[platform_id], nullptr, &act_desc, input->tensor, layer_.unusedTensor,
         output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Activation Error:" << result;
@@ -922,13 +926,20 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBatchNormLayer(CachedLayer* layer,
                                                                       const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+
+    std::vector<cl_ml_op_properties_qcom> opProperties;
+    opProperties.push_back(CL_ML_BATCH_NORM_OP_EPSILON_QCOM);
+    opProperties.push_back(*reinterpret_cast<cl_ml_op_properties_qcom*>(&epsilon));
+    opProperties.push_back(CL_ML_OP_PROPERTY_LIST_END_QCOM);
+
     auto bn_dims = get_tensor_dims(nodes_[node.GetInputs()[1].id_]);
     std::vector<size_t> bn_shape = {1, 1, 1, 1};
     bn_shape[axis] = bn_dims.n;
@@ -950,8 +961,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpBatchNormForwardQCOM(
-        workspace->contexts[platform_id], 0, &bn_desc, input->tensor, bn_mean->tensor,
-        bn_var->tensor, bn_scale->tensor, bn_bias->tensor, output->tensor, &op, tuning_cache);
+        workspace->contexts[platform_id], opProperties.data(), &bn_desc, input->tensor,
+        bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, output->tensor, &op,
+        tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Batchnorm Error:" << result;
 
     layer->function.push_back(op);
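The property list built in the batch-norm hunk above appears to be a flat array of integer-typed entries terminated by CL_ML_OP_PROPERTY_LIST_END_QCOM, so the float epsilon has to be bit-cast into the integer slot rather than value-converted; that is what the reinterpret_cast does. The same bit-exact packing in Python, purely for illustration:

```python
import struct

def float_bits(value):
    """Reinterpret an IEEE-754 float32 as its uint32 bit pattern."""
    return struct.unpack("<I", struct.pack("<f", value))[0]

eps = 0.0003
bits = float_bits(eps)  # the payload pushed into opProperties
# The bit pattern round-trips to the same float32 value; a value cast
# (e.g. int(0.0003) == 0) would lose the epsilon entirely.
assert struct.unpack("<f", struct.pack("<I", bits))[0] == \
       struct.unpack("<f", struct.pack("<f", eps))[0]
```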
@@ -970,7 +982,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreatePoolingLayer(CachedLayer* layer,
                                                                     const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -998,13 +1010,13 @@ class CLMLRuntime : public JSONRuntimeBase {
     };
 
     cl_ml_tensor_desc_qcom desc = {};
-    cl_ml_tensor_qcom unusedTensor = NULL;
+    cl_ml_tensor_qcom unusedTensor = nullptr;
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc,
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc,
                                               &unusedTensor);
     ICHECK(unusedTensor && result == CL_SUCCESS) << ":" << result;
 
-    result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(workspace->contexts[platform_id], 0,
+    result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(workspace->contexts[platform_id], nullptr,
                                                         &pool_desc, input->tensor, unusedTensor,
                                                         output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
@@ -1025,7 +1037,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateGlobalPoolingLayer(
       CachedLayer* layer, const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1047,12 +1059,12 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_ml_tensor_desc_qcom desc = {};
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc,
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc,
                                               &layer_.unusedTensor);
     ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result;
 
     result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(
-        workspace->contexts[platform_id], 0, &pool_desc, input->tensor, layer_.unusedTensor,
+        workspace->contexts[platform_id], nullptr, &pool_desc, input->tensor, layer_.unusedTensor,
         output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
@@ -1070,7 +1082,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateSoftMaxLayer(CachedLayer* layer,
                                                                     const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1083,8 +1095,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM,
                                                CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode};
 
-    result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(workspace->contexts[platform_id], 0, &softmax_desc,
-                                                 input->tensor, output->tensor, &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(workspace->contexts[platform_id], nullptr,
+                                                 &softmax_desc, input->tensor, output->tensor, &op,
+                                                 tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1101,7 +1114,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreatePadLayer(CachedLayer* layer,
                                                                 const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1129,7 +1142,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         {clml_padding[0], clml_padding[1], clml_padding[2], clml_padding[3], 0, 0, 0, 0},
         cl_arithmetic_mode};
 
-    result = h_ClmlIntf->clCreateMLOpPadQCOM(workspace->contexts[platform_id], 0, &pad_desc,
+    result = h_ClmlIntf->clCreateMLOpPadQCOM(workspace->contexts[platform_id], nullptr, &pad_desc,
                                              input->tensor, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Pad Error:" << result;
@@ -1147,15 +1160,15 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBatchFlattenLayer(
       CachedLayer* layer, const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
-    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], 0, input->tensor,
-                                                 output->tensor, &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], nullptr,
+                                                 input->tensor, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1172,15 +1185,15 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateReshapeLayer(CachedLayer* layer,
                                                                     const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
-    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], 0, input->tensor,
-                                                 output->tensor, &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], nullptr,
+                                                 input->tensor, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1198,7 +1211,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateConcatLayer(CachedLayer* layer,
                                                                    const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     std::vector<JSONGraphNodeEntry> input_ = node.GetInputs();
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
@@ -1214,8 +1227,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     }
 
     cl_ml_op_concat_desc_qcom concatDesc = {axis, (cl_uint)inputSize, cl_arithmetic_mode};
-    result = h_ClmlIntf->clCreateMLOpConcatQCOM(workspace->contexts[platform_id], 0, &concatDesc,
-                                                concatInputs, output->tensor, &op, tuning_cache);
+    result =
+        h_ClmlIntf->clCreateMLOpConcatQCOM(workspace->contexts[platform_id], nullptr, &concatDesc,
+                                           concatInputs, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Concat Error:" << result;
 
     layer->function.push_back(op);
@@ -1234,7 +1248,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateDenseLayer(CachedLayer* layer,
                                                                   const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1254,7 +1268,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     } else {
       cl_ml_tensor_desc_qcom desc = {};
       desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-      result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], NULL, &desc,
+      result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc,
                                                 &layer_.unusedTensor);
       ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
       bias->tensor = layer_.unusedTensor;
@@ -1273,8 +1287,8 @@ class CLMLRuntime : public JSONRuntimeBase {
         cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM(
-        workspace->contexts[platform_id], 0, &conv_desc, input->tensor, weight->tensor,
-        bias->tensor, output->tensor, &op, NULL);
+        workspace->contexts[platform_id], nullptr, &conv_desc, input->tensor, weight->tensor,
+        bias->tensor, output->tensor, &op, nullptr);
     ICHECK(op && result == CL_SUCCESS) << "Fully Connected Error:" << result;
 
     layer->function.push_back(op);
@@ -1291,7 +1305,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateClipLayer(CachedLayer* layer,
                                                                  const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1304,7 +1318,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_op_clip_desc_qcom clip_desc = {
         CL_CLIP_BY_VALUE_QCOM, {{a_max}, CL_FLOAT}, {{a_min}, CL_FLOAT}, cl_arithmetic_mode};
 
-    result = h_ClmlIntf->clCreateMLOpClipQCOM(workspace->contexts[platform_id], 0, &clip_desc,
+    result = h_ClmlIntf->clCreateMLOpClipQCOM(workspace->contexts[platform_id], nullptr, &clip_desc,
                                               input->tensor, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Clip Error:" << result;
@@ -1322,7 +1336,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBinaryLayer(CachedLayer* layer,
                                                                    const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1346,9 +1360,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_op_binary_desc_qcom add_desc = {
         binary_op, {{1.0}, CL_FLOAT}, {{1.0}, CL_FLOAT}, {{0.0}, CL_FLOAT}, cl_arithmetic_mode};
 
-    result = h_ClmlIntf->clCreateMLOpBinaryQCOM(workspace->contexts[platform_id], 0, &add_desc,
-                                                input_a->tensor, input_b->tensor, output->tensor,
-                                                &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpBinaryQCOM(workspace->contexts[platform_id], nullptr,
+                                                &add_desc, input_a->tensor, input_b->tensor,
+                                                output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << op_name << " Node Error:" << result;
 
     layer_.func_ins.push_back(input_a);
@@ -1366,7 +1380,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateDepthToSpaceLayer(
       CachedLayer* layer, const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
    DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1376,9 +1390,9 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_uint block_size = std::stoi(node.GetAttr<std::vector<std::string>>("block_size")[0]);
     cl_ml_op_depthtospace_desc_qcom dtos_desc = {block_size, cl_arithmetic_mode};
-    result =
-        h_ClmlIntf->clCreateMLOpDepthToSpaceQCOM(workspace->contexts[platform_id], 0, &dtos_desc,
-                                                 input->tensor, output->tensor, &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpDepthToSpaceQCOM(workspace->contexts[platform_id], nullptr,
+                                                      &dtos_desc, input->tensor, output->tensor,
+                                                      &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "DepthToSpace Layer Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1395,7 +1409,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateResizeLayer(CachedLayer* layer,
                                                                    const JSONGraphNode& node) {
     cl_int result = 0;
-    cl_ml_op_qcom op = NULL;
+    cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
@@ -1405,7 +1419,7 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_bool align_corners = std::stoi(node.GetAttr<std::vector<std::string>>("align_corners")[0]);
     cl_ml_op_resize_bilinear_desc_qcom resize_desc = {align_corners, false, cl_arithmetic_mode};
-    result = h_ClmlIntf->clCreateMLOpResizeBilinearQCOM(workspace->contexts[platform_id], 0,
+    result = h_ClmlIntf->clCreateMLOpResizeBilinearQCOM(workspace->contexts[platform_id], nullptr,
                                                         &resize_desc, input->tensor, output->tensor,
                                                         &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Resize Layer Error:" << result;
@@ -1422,12 +1436,12 @@ class CLMLRuntime : public JSONRuntimeBase {
   CachedLayer layer_;
 
   // CLML Context
-  GET_ML_API_INTERFACE* h_ClmlIntf = NULL;
-  cl::OpenCLWorkspace* workspace = NULL;
-  cl::OpenCLThreadEntry* tentry = NULL;
+  GET_ML_API_INTERFACE* h_ClmlIntf = nullptr;
+  cl::OpenCLWorkspace* workspace = nullptr;
+  cl::OpenCLThreadEntry* tentry = nullptr;
   cl_device_id device_id;
   cl_platform_id platform_id;
-  cl_ml_tuningcache_qcom tuning_cache = NULL;
+  cl_ml_tuningcache_qcom tuning_cache = nullptr;
   bool is_tuning_run;
   char* tuning_file;
 #else
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 009c701d2a99..96e8279b6775 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -336,6 +336,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"USE_VITIS_AI", TVM_INFO_USE_VITIS_AI},
      {"USE_VULKAN", TVM_INFO_USE_VULKAN},
       {"USE_CLML", TVM_INFO_USE_CLML},
+      {"TVM_CLML_VERSION", TVM_INFO_USE_TVM_CLML_VERSION},
       {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR},
       {"USE_UMA", TVM_INFO_USE_UMA},
       {"USE_VERILATOR", TVM_INFO_USE_VERILATOR},
diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py
index be2bbc7f8a71..1b9cbdac63b5 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -135,6 +135,7 @@ def build_module(mod, target, target_host, params=None, enable_clml=True, tune_l
     with autotvm.apply_history_best(tune_log):
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             if enable_clml:
+                mod = clml.preprocess_module(mod)
                 mod = clml.partition_for_clml(mod, params)
             relay.backend.te_compiler.get().clear()
             return relay.build(mod, target=target, target_host=target_host, params=params)
@@ -210,6 +211,7 @@ def verify_codegen(
     if isinstance(mod, tvm.relay.expr.Call):
         mod = tvm.IRModule.from_expr(mod)
     with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        mod = clml.preprocess_module(mod)
         mod = clml.partition_for_clml(mod, params)
         tvm_op_count = get_cpu_op_count(mod)
         assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
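Both test helpers now call `clml.preprocess_module` before `clml.partition_for_clml`, since the patterns and codegen assume the layouts produced by the ConvertLayout passes (NCHW data, OIHW kernels) — conv2d_transpose included. A condensed sketch of that call order for a standalone build; the helper name is mine and it simply mirrors `build_module` above:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib import clml

def build_for_clml(mod, params, target, target_host):
    """Layout-normalize, partition for CLML, then build."""
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        mod = clml.preprocess_module(mod)  # ConvertLayout + AlterOpLayout + FoldConstant
        mod = clml.partition_for_clml(mod, params)
        return relay.build(mod, target=target, target_host=target_host, params=params)
```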
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index b8177435a0dc..6cb90e7af00f 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -257,15 +257,114 @@ def test_conv2d(device, dtype):
     verify_codegen(func, exp_codegen, device, params)
 
 
+def _get_conv2d_transpose_expected_codegen(
+    dshape, kshape, channels, kernel_size, strides, padding, dilation, dtype, output_shape
+):
+    attrs = {
+        "channels": [[str(channels)]],
+        "data_layout": [["NCHW"]],
+        "kernel_layout": [["OIHW"]],
+        "groups": [["1"]],
+        "dilation": [[str(p) for p in dilation]],
+        "num_inputs": "2",
+        "num_outputs": "1",
+        "padding": [[str(p) for p in padding]],
+        "kernel_size": [[str(p) for p in kernel_size]],
+        "shape": [[list(output_shape)]],
+        "dtype": [[dtype]],
+        "strides": [[str(s) for s in strides]],
+        "out_dtype": [[""]],
+        "out_layout": [[""]],
+        "output_padding": [["0", "0"]],
+    }
+
+    kshape = [kshape[1], kshape[0], kshape[2], kshape[3]]
+
+    exp_codegen = [
+        {
+            "op": "input",
+            "name": "",
+            "attrs": {"shape": [[list(dshape)]], "dtype": [[str(dtype)]]},
+        },
+        {
+            "op": "const",
+            "name": "",
+            "attrs": {"shape": [[list(kshape)]], "dtype": [[str(dtype)]]},
+        },
+        {
+            "op": "kernel",
+            "name": "nn.conv2d_transpose",
+            "inputs": [[0, 0, 0], [1, 0, 0]],
+            "attrs": attrs,
+        },
+    ]
+    return exp_codegen
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_conv2d_transpose(device, dtype):
+    trials = [
+        [(1, 256, 100, 100), (256, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 1)],
+        [(1, 64, 200, 200), (64, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 1)],
+        [(1, 64, 400, 400), (64, 16, 4, 4), 16, (4, 4), (2, 2), (1, 1, 1, 1)],
+    ]
+    for (dshape, kshape, channels, kernel_size, strides, padding) in trials:
+        x = relay.var("input", shape=dshape, dtype=dtype)
+        input_arr = tvm.nd.array(np.random.uniform(-1, 1, dshape).astype(dtype))
+        w = relay.var("wt", shape=kshape, dtype=dtype)
+        weight_arr = tvm.nd.array(np.random.uniform(-1, 1, kshape).astype(dtype))
+        inputs = {
+            "input": input_arr,
+        }
+        params = {
+            "wt": weight_arr,
+        }
+        y = relay.nn.conv2d_transpose(
+            x,
+            w,
+            channels=channels,
+            kernel_size=kernel_size,
+            strides=strides,
+            padding=padding,
+            kernel_layout="IOHW",
+            data_layout="NCHW",
+        )
+        func = relay.Function([x, w], y)
+        mod = IRModule.from_expr(func)
+
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        args = (
+            dshape,
+            kshape,
+            channels,
+            kernel_size,
+            strides,
+            padding,
+            (1, 1),
+            dtype,
+            opencl_out[0].shape,
+        )
+        exp_codegen = _get_conv2d_transpose_expected_codegen(*args)
+        verify_codegen(mod, exp_codegen, device, params)
+
+
 @pytest.mark.parametrize("dtype", ["float16"])
 @tvm.testing.requires_openclml
 def test_batchnorm(device, dtype):
-    if tvm.support.libinfo().get("TVM_CLML_VERSION", 2) < 3:
-        print("Skip due to unsupported CLML version")
+    if clml.clml_sdk_version() < 3:
+        print("Skip due to unsupported CLML version:", clml.clml_sdk_version())
         return
 
     in_shape = (1, 8, 64, 64)
     channels = 8
 
+    np.random.seed(8)
+
     input_arr = tvm.nd.array(np.random.uniform(-1, 1, in_shape).astype(dtype))
     inp = relay.var("a", shape=in_shape, dtype=dtype)
     gamma_arr = tvm.nd.array(np.random.uniform(-1, 1, (channels)).astype(dtype))
@@ -280,7 +379,7 @@ def test_batchnorm(device, dtype):
 
     params = {}
 
-    func = relay.nn.batch_norm(inp, gamma, beta, mean, variance, axis=1, epsilon=0.0001)[0]
+    func = relay.nn.batch_norm(inp, gamma, beta, mean, variance, axis=1, epsilon=0.0003)[0]
     mod = IRModule.from_expr(func)
 
     inputs = {
@@ -291,7 +390,7 @@ def test_batchnorm(device, dtype):
     clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
 
     tvm.testing.assert_allclose(
-        clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-5, atol=1e-5
+        clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
     )
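For reference, the trial shapes in `test_conv2d_transpose` follow the standard transposed-convolution output formula; a quick sanity check (the helper is illustrative, not part of the patch):

```python
def conv2d_transpose_out_hw(in_hw, kernel, strides, padding, output_padding=(0, 0)):
    """H/W of an NCHW nn.conv2d_transpose; padding is (top, left, bottom, right)."""
    h = (in_hw[0] - 1) * strides[0] - (padding[0] + padding[2]) + kernel[0] + output_padding[0]
    w = (in_hw[1] - 1) * strides[1] - (padding[1] + padding[3]) + kernel[1] + output_padding[1]
    return (h, w)

# First trial: 100x100 input, 4x4 kernel, stride 2, pad (1, 1, 1, 1) -> 200x200,
# which is exactly the input resolution of the second trial one level up the decoder.
assert conv2d_transpose_out_hw((100, 100), (4, 4), (2, 2), (1, 1, 1, 1)) == (200, 200)
```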