Skip to content

Commit

Permalink
Add custom tolerance option for onnx_test_runner (microsoft#13683)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>

### Description
Add a `-t` option for `onnx_test_runner` to allow users to specify
custom tolerance values when running ONNX models.


### Motivation and Context
For some backends, the default tolerance of 1-e5 is too tight to pass
accuracy checks with ONNX model zoo reference values, especially if only
one or two values are mismatched. Having a custom option will allow
different backends to specify their own custom tolerance when running
these models.

Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv authored Jan 24, 2023
1 parent 7b6d880 commit 81120e9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
13 changes: 13 additions & 0 deletions onnxruntime/core/platform/path_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ using PATH_CHAR_TYPE = ORTCHAR_T;
template <typename T>
long OrtStrtol(const T* nptr, T** endptr);

template <typename T>
double OrtStrtod(const T* nptr, T** endptr);

/**
* Convert a C string to ssize_t(or ptrdiff_t)
* @return the converted integer value.
Expand Down Expand Up @@ -85,6 +88,16 @@ inline long OrtStrtol<wchar_t>(const wchar_t* nptr, wchar_t** endptr) {
return wcstol(nptr, endptr, 10);
}

template <>
inline double OrtStrtod<char>(const char* nptr, char** endptr) {
return strtod(nptr, endptr);
}

template <>
inline double OrtStrtod<wchar_t>(const wchar_t* nptr, wchar_t** endptr) {
return wcstod(nptr, endptr);
}

namespace onnxruntime {

/**
Expand Down
23 changes: 20 additions & 3 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ void usage() {
"\t-p: Pause after launch, can attach debugger and continue\n"
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
"\t-d [device_id]: Specifies the device id for multi-device (e.g. GPU). The value should > 0\n"
"\t-t: Specify custom relative tolerance values for output value comparison. default: 1e-5\n"
"\t-a: Specify custom absolute tolerance values for output value comparison. default: 1e-5\n"
"\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n"
"\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
"\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n"
Expand All @@ -60,9 +62,13 @@ void usage() {
OrtGetApiBase()->GetVersionString());
}

static TestTolerances LoadTestTolerances(bool enable_cuda, bool enable_openvino) {
static TestTolerances LoadTestTolerances(bool enable_cuda, bool enable_openvino, bool useCustom, double atol, double rtol) {
TestTolerances::Map absolute_overrides;
TestTolerances::Map relative_overrides;
if (useCustom)
{
return TestTolerances(atol, rtol, absolute_overrides, relative_overrides);
}
std::ifstream overrides_ifstream(ConcatPathComponent<ORTCHAR_T>(
ORT_TSTR("testdata"), ORT_TSTR("onnx_backend_test_series_overrides.jsonc")));
if (!overrides_ifstream.good()) {
Expand Down Expand Up @@ -142,6 +148,9 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool enable_rocm = false;
bool enable_migraphx = false;
bool enable_xnnpack = false;
bool override_tolerance = false;
double atol = 1e-5;
double rtol = 1e-5;
int device_id = 0;
GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL;
bool user_graph_optimization_level_set = false;
Expand All @@ -154,7 +163,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool pause = false;
{
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:xvo:d:i:pz"))) != -1) {
while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:i:pz"))) != -1) {
switch (ch) {
case 'A':
enable_cpu_mem_arena = false;
Expand Down Expand Up @@ -225,6 +234,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
return -1;
}
break;
case 't':
override_tolerance = true;
rtol = OrtStrtod<PATH_CHAR_TYPE>(optarg, nullptr);
break;
case 'a':
override_tolerance = true;
atol = OrtStrtod<PATH_CHAR_TYPE>(optarg, nullptr);
break;
case 'x':
execution_mode = ExecutionMode::ORT_PARALLEL;
break;
Expand Down Expand Up @@ -589,7 +606,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");

std::vector<ITestCase*> tests;
LoadTests(data_dirs, whitelisted_test_cases,
LoadTestTolerances(enable_cuda, enable_openvino),
LoadTestTolerances(enable_cuda, enable_openvino, override_tolerance, atol, rtol),
all_disabled_tests,
[&owned_tests, &tests](std::unique_ptr<ITestCase> l) {
tests.push_back(l.get());
Expand Down

0 comments on commit 81120e9

Please sign in to comment.