diff --git a/DIOPI-TEST/python/conformance/conformance_test.py b/DIOPI-TEST/python/conformance/conformance_test.py index 0082b8b36..477cd10dc 100644 --- a/DIOPI-TEST/python/conformance/conformance_test.py +++ b/DIOPI-TEST/python/conformance/conformance_test.py @@ -66,11 +66,12 @@ def allclose(cfg: dict, tensor1: np.ndarray, tensor2: np.ndarray, sum_to_compare passed = np.allclose(tensor1, tensor2, rtol, atol, True) if record: save_precision(cfg, tensor1, tensor2, passed, var_name) - if not passed and logger.level == 10: + if not passed: sum1 = tensor1.sum() sum2 = tensor2.sum() mask = np.isclose(tensor1, tensor2, rtol, atol, True) max_diff = np.abs(tensor1 - tensor2).max() + logger.info(f"Max of diff is {max_diff}.") logger.debug(f"Sum of {var_name} is {sum1}, Sum of {var_name}_ref is {sum2}, Max of diff is {max_diff}. \ \n" + f"{var_name} is {tensor1},\n{var_name}_ref is {tensor2},\nMask is {mask}\n") return passed @@ -198,7 +199,7 @@ def test_normal(mean, std, size=None): out_numpy = out_numpy.flatten() p_value = stats.kstest(out_numpy, 'norm', args=(mean, std))[1] # pytorch use 0.0001, but stats.kstest use 0.05 as threshold - assert p_value > 0.001, "failed to execute normal" + assert p_value > 0.0005, "failed to execute normal" def test_normal_(input, mean, std, shape=None): from scipy import stats @@ -208,13 +209,45 @@ def test_normal_(input, mean, std, shape=None): p_value = stats.kstest(out_numpy, 'norm', args=(mean, std))[1] assert p_value > 0.05, "failed to execute normal_" + def test_multinomial(input, num_samples, replacement): + out = F.multinomial(input, num_samples, replacement) + out_numpy = out.numpy() + has_duplicates = False + if len(out.size()) == 2: + has_duplicates = len(out_numpy[0]) != len(set(out_numpy[0])) + else: + has_duplicates = len(out_numpy) != len(set(out_numpy)) + if not replacement: + assert has_duplicates is False, "failed to execute multinomial" + out_numpy = out_numpy.flatten() + assert len(out_numpy) % num_samples == 0, "failed 
def config_to_format_string(data, indent=0):
    """Render a nested config object as an indented, YAML-like string.

    Args:
        data: A dict, list/tuple, or any other object. Dicts and sequences
            are walked recursively; anything else is rendered with ``str``.
        indent: Number of spaces placed before each key or list marker that
            is emitted at this nesting level.

    Returns:
        The rendered text. Every dict entry and list item starts with a
        newline, so the result of a non-empty container begins with "\\n".
        Dict entries whose value is ``None`` or an empty list/dict/str are
        omitted. Values stored under the keys "shape" and "value" are
        rendered on a single line via ``str`` instead of being expanded.
    """
    pad = " " * indent

    if isinstance(data, dict):
        pieces = []
        for key, value in data.items():
            # Check emptiness by type instead of `value == []`: equality
            # against a literal is elementwise for array-like values and
            # their truth value is ambiguous. Empty tuples are deliberately
            # kept because `shape=()` is meaningful.
            if value is None or (isinstance(value, (list, dict, str)) and not value):
                continue
            rendered = str(value) if key in ("shape", "value") else value
            pieces.append(f"\n{pad}{key}: " + config_to_format_string(rendered, indent + 2))
        return "".join(pieces)

    if isinstance(data, (list, tuple)):
        return "".join(
            f"\n{pad}- " + config_to_format_string(item, indent + 2)
            for item in data
        )

    return f"{data}"
as e: logger.error(f"NotImplemented: {e}") continue @@ -299,10 +342,21 @@ def run(func_name, model_name, filter_dtype_str_list): try: grad_input = eval(f"F.{cfg_func_name}_backward(**kwargs, **backward_para)") - # import pdb;pdb.set_trace() passed = compare_with_gen_output(grad_input, data['cfg'], backward_out_reference) - logger.info(f"Run diopi_functions.{cfg_func_name}_backward succeed") \ - if passed else logger.error(f"Run diopi_functions.{cfg_func_name}_backward failed", tag=test_tag, info=tensor_info) + if passed: + logger.info(f"Run diopi_functions.{cfg_func_name}_backward succeed") + else: + logger.error(f"Run diopi_functions.{cfg_func_name}_backward failed", tag=test_tag, info=tensor_info) + if debug_level > 0: + logger.error("failed config:\n%s", config_to_format_string(data['cfg'])) + if debug_level > 1: + logger.error("failed arguments:") + for key, arg in kwargs.items(): + logger.error(f"{key}: {arg}") + for key, arg in backward_para.items(): + logger.error(f"{key}: {arg}") + logger.error(f"grad_reference:\n{backward_out_reference}") + logger.error(f"grad:\n{grad_input}") write_precision(data["cfg"], cfg_func_name + '_bp', passed) except FunctionNotImplementedError as e: logger.error(f"NotImplemented: {e}") diff --git a/DIOPI-TEST/python/conformance/device_config_helper.py b/DIOPI-TEST/python/conformance/device_config_helper.py new file mode 100644 index 000000000..fcb069d27 --- /dev/null +++ b/DIOPI-TEST/python/conformance/device_config_helper.py @@ -0,0 +1,124 @@ +import copy +from .config import _must_be_the_type, _must_exist, expand_cfg_by_name + + +class Skip: + def __init__(self, value): + self.value = value + + +def _must_be_the_list_or_tuple_of_type(cfg_path: str, cfg_dict: dict, required_type, cfg_keys: list) -> None: + if isinstance(required_type, (list, tuple)): + types_str = "" + for i in required_type: + types_str += i.__name__ + types_str += ' or ' + types_str = types_str[:-4] + else: + types_str = required_type.__name__ + + err = f"key 
%s should be the list or tuple of {types_str} in {cfg_path}" + for key in cfg_keys: + if key in cfg_dict.keys(): + assert isinstance(cfg_dict[key], (list, tuple)), err % key + for v in cfg_dict[key]: + assert isinstance(v, required_type), err % key + + +def check_configs_format(cfgs_dict: dict): + for case_k, case_v in cfgs_dict.items(): + domain = f"device_configs.{case_k}" + _must_be_the_type(domain, case_v, list, ["dtype"]) + if "dtype" in case_v.keys(): + _must_be_the_list_or_tuple_of_type(domain, case_v, Skip, ["dtype"]) + + _must_exist(domain, case_v, ['name']) + _must_be_the_type(domain, case_v, list, ['name']) + + if "tensor_para" in case_v.keys(): + _must_be_the_type(domain, case_v, dict, ['tensor_para']) + _must_exist(domain + ".tensor_para", case_v["tensor_para"], ["args"]) + _must_be_the_type(domain + ".tensor_para", case_v["tensor_para"], + (list, tuple), ['args']) + domain_tmp = domain + ".tensor_para.args" + for arg in case_v["tensor_para"]['args']: + _must_exist(domain_tmp, arg, ["ins"]) + _must_be_the_list_or_tuple_of_type(domain_tmp, arg, Skip, ['shape', 'value', 'dtype']) + + if "para" in case_v.keys(): + _must_be_the_type(domain, case_v, dict, ['para']) + dict_obj = case_v["para"] + _must_be_the_list_or_tuple_of_type(domain + ".para", dict_obj, Skip, + [i for i in dict_obj.keys()]) + + +def expand_tensor_paras_args_by_ins(cfgs_dict): + ''' + [ + { + "ins": ['x1', 'x2'], + "shape": [(2, 3, 16), (4, 32, 7, 7)], + }, + ] + ====> + { + 'x1':{ + "ins": ['x1'], + "shape": [(2, 3, 16), (4, 32, 7, 7)], + }, + 'x2':{ + "ins": ['x2'], + "shape": [(2, 3, 16), (4, 32, 7, 7)], + }, + } + ''' + for cfg_name in cfgs_dict: + tensor_para_args = cfgs_dict[cfg_name]["tensor_para"]["args"] + tmp_tensor_para_args = {} + for arg in tensor_para_args: + assert isinstance(arg["ins"], (list, tuple)) + for in_name in arg["ins"]: + tmp_tensor_para_args[in_name] = copy.deepcopy(arg) + tmp_tensor_para_args[in_name]["ins"] = [in_name] + 
def format_cfg(cases):
    """Fill in the optional sections of every case, in place.

    After this call each case dict is guaranteed to contain
    ``"tensor_para"`` (a dict with an ``"args"`` entry) and ``"para"``.
    Sections that are already present are left untouched.

    Args:
        cases: Mapping of case name to case config dict; mutated in place.
    """
    for case in cases.values():
        # Defaults: {} for tensor_para and para, [] for tensor_para["args"].
        case.setdefault("tensor_para", {}).setdefault("args", [])
        case.setdefault("para", {})
"gen_fn": Genfunc.randn, + }, + { + "ins": ["batch2"], + "shape": ((32, 32, 16), (32, 8, 32), (168, 38, 64)), + "gen_fn": Genfunc.randn, + }, + ] + ), + ), + 'conv_2d': dict( name=["conv2d"], atol=1e-3, @@ -121,6 +151,7 @@ "ins": ['input'], "shape": ((2, 4096), (64, 28, 28), (32, 64, 112, 112), (64, 3, 7, 28, 28)), + "requires_grad": [True], "dtype": [Dtype.float32, Dtype.float64], "gen_fn": Genfunc.randn, }, @@ -341,7 +372,7 @@ 'pointwise_op': dict( name=['abs', 'cos', 'erf', 'exp', 'floor', - 'neg', 'sin', 'sqrt', 'logical_not'], + 'neg', 'sin', 'sqrt', 'logical_not', 'rsqrt'], interface=['torch'], is_inplace=True, dtype=[Dtype.float16, Dtype.float32, Dtype.float64], @@ -360,7 +391,7 @@ 'pointwise_op_int_without_inplace': dict( name=['abs', 'cos', 'erf', 'exp', - 'neg', 'sin', 'sqrt', 'logical_not'], + 'neg', 'sin', 'sqrt', 'logical_not', 'rsqrt'], interface=['torch'], dtype=[Dtype.int16, Dtype.int32, Dtype.int64, Dtype.uint8, Dtype.int8], tensor_para=dict( @@ -394,7 +425,7 @@ ), 'pointwise_op_bool': dict( - name=['cos', 'erf', 'exp', 'sin', 'sqrt'], + name=['cos', 'erf', 'exp', 'sin', 'sqrt', 'rsqrt'], interface=['torch'], dtype=[Dtype.bool], tensor_para=dict( @@ -429,7 +460,7 @@ ), 'pointwise_op_abs_input': dict( - name=['log', 'log2', 'log10', 'sqrt'], + name=['log', 'log2', 'log10', 'sqrt', 'rsqrt'], interface=['torch'], is_inplace=True, dtype=[Dtype.float16, Dtype.float32, Dtype.float64], @@ -554,6 +585,23 @@ ), ), + 'silu': dict( + name=["silu"], + is_inplace=True, + tensor_para=dict( + args=[ + { + "ins": ['input'], + "requires_grad": [True], + "shape": ((182400,), (20267, 80), (8, 200, 304), + (32, 16, 1, 1), (16, 32, 130, 130)), + "dtype": [Dtype.float32, Dtype.float64], + "gen_fn": Genfunc.randn, + }, + ], + ), + ), + 'pow_float_tensor': dict( name=['pow'], interface=['torch'], @@ -2167,10 +2215,9 @@ args=[ { "ins": ['input'], - "shape": ((2, 4096), (32, 49, 256), (2, 16, 64, 64), - (1, 2304, 1, 1, 1)), + "shape": ((2, 4096), (32, 49, 256), (2, 16, 64, 
64), (1, 2304, 1, 1, 1)), "dtype": [Dtype.float32, Dtype.float64], - "gen_fn": Genfunc.randn, + "gen_fn": Genfunc.positive, }, ], ), @@ -2190,7 +2237,7 @@ "shape": ((2, 4096), (32, 49, 256), (2, 16, 64, 64), (1, 2304, 1, 1, 1)), "dtype": [Dtype.float32, Dtype.float64], - "gen_fn": Genfunc.randn, + "gen_fn": Genfunc.positive, }, ], ), @@ -2209,7 +2256,7 @@ "ins": ['input'], "shape": ((32, 49, 256), (32, 16, 64, 64)), "dtype": [Dtype.float32, Dtype.float64], - "gen_fn": Genfunc.randn, + "gen_fn": Genfunc.positive, }, ], ), @@ -4294,4 +4341,24 @@ ), ), + 'multinomial': dict( + name=["multinomial"], + interface=['torch'], + no_output_ref=True, + para=dict( + num_samples=[6, 60, 200, 128], + replacement=[True, True, False, False], + ), + tensor_para=dict( + gen_fn=Genfunc.positive, + args=[ + { + "ins": ['input'], + "shape": ((8, ), (16, 64,), (128, 256,), (256, 128,)), + "dtype": [Dtype.float32, Dtype.float64, Dtype.float64, Dtype.float64], + }, + ], + ), + ), + } diff --git a/DIOPI-TEST/python/conformance/diopi_functions.py b/DIOPI-TEST/python/conformance/diopi_functions.py index e63126959..e564833a6 100644 --- a/DIOPI-TEST/python/conformance/diopi_functions.py +++ b/DIOPI-TEST/python/conformance/diopi_functions.py @@ -1,6 +1,7 @@ # Copyright (c) 2023, DeepLink. 
def baddbmm(input, batch1, batch2, beta, alpha, inplace=False) -> Tensor:
    """Call the DIOPI batched add-matmul kernel.

    All three tensors must be 3-D. ``batch1`` and ``batch2`` must agree on
    the contraction dimension, and all three must share the same batch size.
    Presumably follows ``torch.baddbmm`` semantics; confirm against the
    DIOPI spec.

    Args:
        input: Tensor that is combined with the batched product.
        batch1: First batch of matrices.
        batch2: Second batch of matrices.
        beta: Scalar passed to the kernel as a C double.
        alpha: Scalar passed to the kernel as a C double.
        inplace: When true, ``diopiBaddbmmInp`` writes into ``input``.

    Returns:
        ``input`` itself when ``inplace`` is true, otherwise a new tensor
        allocated with ``raw_like(input)``.
    """
    input_shape = list(input.size())
    assert (len(input_shape) == 3), 'input must be 3d tensor'
    batch1_shape = list(batch1.size())
    assert (len(batch1_shape) == 3), 'batch1 must be 3d tensor'
    batch2_shape = list(batch2.size())
    assert (len(batch2_shape) == 3), 'batch2 must be 3d tensor'

    same_batch = input_shape[0] == batch1_shape[0] and input_shape[0] == batch2_shape[0]
    assert (batch1_shape[2] == batch2_shape[1] and same_batch), 'invalid args'
    # NOTE(review): this also accepts batch2_shape[2] == 1 with a wider input,
    # never compares dim 1 of input and batch1, and the out-of-place result is
    # sized like `input` rather than like the product. Looks too permissive;
    # confirm the intended broadcasting rules before tightening.
    last_dim_ok = (input_shape[2] == batch2_shape[2]
                   or input_shape[2] == 1
                   or batch2_shape[2] == 1)
    assert (last_dim_ok), 'invalid args'

    if not inplace:
        out = raw_like(input)
        func = check_function("diopiBaddbmm")
        ret = func(input.context_handle, out.tensor_handle, input.tensor_handle,
                   batch1.tensor_handle, batch2.tensor_handle,
                   c_double(beta), c_double(alpha))
        check_returncode(ret)
        return out

    func = check_function("diopiBaddbmmInp")
    ret = func(input.context_handle, input.tensor_handle, batch1.tensor_handle,
               batch2.tensor_handle, c_double(beta), c_double(alpha))
    check_returncode(ret)
    return input
def check_device_para_and_tensor_para(cfg_dict, device_cfg_dict):
    """Warn about device-config entries that match nothing in the base config.

    The function only logs warnings; it neither modifies its arguments nor
    returns a value.

    Args:
        cfg_dict: One case from diopi_configs. Reads ``cfg_dict["para"]``
            (dict of value lists) and ``cfg_dict["tensor_para"]["args"]``
            (list of per-tensor dicts, each optionally carrying ``"ins"``).
        device_cfg_dict: The matching case from device_configs. Reads
            ``device_cfg_dict["para"]`` and
            ``device_cfg_dict["tensor_para"]["args"]``, the latter being a
            dict keyed by input name.
    """
    para_dict = cfg_dict["para"]
    for dk, dv in device_cfg_dict["para"].items():
        if dk not in para_dict:
            continue
        known_values = para_dict[dk]
        for x in dv:
            if x not in known_values:
                logger.warn(f"Para {x} of key {dk} in device_configs not found in diopi_configs. Ignored.")

    args_list = cfg_dict["tensor_para"]["args"]
    device_tensor_paras_dict = device_cfg_dict["tensor_para"]["args"]
    # `in_name` rather than `input`, to avoid shadowing the builtin.
    for in_name, device_arg in device_tensor_paras_dict.items():
        in_found = False
        for arg in args_list:
            if "ins" not in arg:
                continue
            ins = arg["ins"]
            if in_name not in ins:
                continue
            in_found = True
            for key in ["dtype", "shape", "value"]:
                if key in device_arg and key in arg:
                    for dv in device_arg[key]:
                        if dv not in arg[key]:
                            # This branch is the "not found" case; the message
                            # previously claimed the value was found.
                            logger.warn(f"Tensor para {dv} of key {key} in device_configs not found in diopi_configs for ins {ins}. Ignored.")
        if not in_found:
            logger.warn(f"Input name {in_name} in device_configs not found in diopi_configs. Ignored.")
skipped_index: + skipped_index.append(idx) + paras_list = [paras_list[i] for i in range(len(paras_list)) if i not in skipped_index] + tensor_paras_list = [tensor_paras_list[i] for i in range(len(tensor_paras_list)) if i not in skipped_index] + if len(tensor_paras_list) != 0: arg_dtype_num = 0 for arg in cfg_dict["tensor_para"]["args"]: @@ -131,7 +179,11 @@ def expand_cfg_all(paras_list, tensor_paras_list, cfg_dict, filter_dtype_list) - for arg in tmp_cfg_dict["tensor_para"]["args"]: if arg.get("dtype") is not None: entry_dtype = arg["dtype"][i] - if entry_dtype in filter_dtype_list: + arg_filter_dtype_list = [] + if device_config is not None: + if arg["ins"] in device_tensor_paras and "dtype" in device_tensor_paras[arg["ins"]]: + arg_filter_dtype_list = device_tensor_paras[arg["ins"]]["dtype"] + if entry_dtype in filter_dtype_list or entry_dtype in arg_filter_dtype_list: filter_dtype = True break else: @@ -155,9 +207,11 @@ def expand_cfg_all(paras_list, tensor_paras_list, cfg_dict, filter_dtype_list) - return cfg_expand_list -def expand_cfg_by_all_options(cfg_dict: dict, filter_dtype_list: list) -> list: +def expand_cfg_by_all_options(cfg_dict: dict, filter_dtype_list: list, device_config: dict = None) -> list: + if device_config: + check_device_para_and_tensor_para(cfg_dict, device_config) paras_list, tensor_paras_list = expand_cfg_by_para(cfg_dict) - cfg_expand_list = expand_cfg_all(paras_list, tensor_paras_list, cfg_dict, filter_dtype_list) + cfg_expand_list = expand_cfg_all(paras_list, tensor_paras_list, cfg_dict, filter_dtype_list, device_config) return cfg_expand_list @@ -311,14 +365,23 @@ class GenInputData(object): ''' @staticmethod - def run(func_name, model_name, filter_dtype_str_list): - + def run(func_name, model_name, filter_dtype_str_list, impl_folder): if model_name != "": diopi_config = "model_config." 
+ model_name + "_config" configs = Config.process_configs(eval(diopi_config)) else: configs = Config.process_configs(diopi_configs) + src_path = os.path.join(impl_folder, "device_configs.py") + use_device_configs = os.path.isfile(src_path) + if use_device_configs: + dst_path = os.path.join(_cur_dir, "device_configs.py") + os.symlink(src_path, dst_path) + from .device_configs import device_configs + os.unlink(dst_path) + from .device_config_helper import DeviceConfig + device_configs = DeviceConfig.process_configs(device_configs) + inputs_dir_path = os.path.join(_cur_dir, "../data/" + model_name + "/inputs") if not os.path.exists(inputs_dir_path): os.makedirs(inputs_dir_path) @@ -331,7 +394,19 @@ def run(func_name, model_name, filter_dtype_str_list): continue logger.info(f"Generate benchmark input data for diopi_functions.{cfg_func_name}") filter_dtype_list = get_filter_dtype_list(filter_dtype_str_list) - cfg_expand_list = expand_cfg_by_all_options(configs[cfg_name], filter_dtype_list) + + if use_device_configs and cfg_name in device_configs: + device_config = device_configs[cfg_name] + if 'dtype' in device_config: + filter_dtype_list.extend(x for x in device_config['dtype'] if x not in filter_dtype_list) + tol_keys_list = ['atol', 'rtol', 'atol_half', 'rtol_half'] + for key in tol_keys_list: + if key in device_config: + configs[cfg_name][key] = device_config[key] + cfg_expand_list = expand_cfg_by_all_options(configs[cfg_name], filter_dtype_list, device_config) + else: + cfg_expand_list = expand_cfg_by_all_options(configs[cfg_name], filter_dtype_list) + cfg_counter += len(cfg_expand_list) gen_and_dump_data(inputs_dir_path, cfg_name, cfg_expand_list, cfg_save_dict) diff --git a/DIOPI-TEST/python/main.py b/DIOPI-TEST/python/main.py index 0713db3b8..b0eb2b8b6 100644 --- a/DIOPI-TEST/python/main.py +++ b/DIOPI-TEST/python/main.py @@ -24,6 +24,10 @@ def parse_args(): help='Whether to use nhwc layout for 3-dim Tensor') parser.add_argument('--four_bytes', 
#!/bin/bash
# Build one DIOPI implementation backend.
#
# Usage: build_impl.sh BACKEND
#   BACKEND is one of: cuda, torch, torch_dyload, torch_no_runtime,
#   camb_pytorch, camb_pytorch_no_runtime, camb, camb_no_runtime,
#   hip_pytorch, mmcv_ext.
# Exits non-zero when the build fails or BACKEND is not recognised.
set -e

# Configure and build in a fresh ./build directory; every argument is passed
# on to cmake. Runs in a subshell so the caller's working directory is kept.
cmake_build() {
    (rm -rf build && mkdir build && cd build \
        && cmake .. "$@" && make -j4) \
    || exit -1  # bash reports this as status 255, as the script always did
}

# Print the cmake option that points at the installed torch cmake package.
torch_prefix_opt() {
    echo "-DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')"
}

case "$1" in
    cuda)
        cmake_build -DCUDA_ARCH_AUTO=ON -DIMPL_OPT=CUDA;;
    torch)
        cmake_build -DIMPL_OPT=TORCH "$(torch_prefix_opt)";;
    torch_dyload)
        cmake_build -DIMPL_OPT=TORCH -DDYLOAD=ON "$(torch_prefix_opt)";;
    torch_no_runtime)
        cmake_build -DIMPL_OPT=TORCH -DRUNTIME=OFF "$(torch_prefix_opt)";;
    camb_pytorch)
        cmake_build -DIMPL_OPT=camb_pytorch;;
    camb_pytorch_no_runtime)
        cmake_build -DIMPL_OPT=camb_pytorch -DRUNTIME=OFF;;
    camb)
        cmake_build -DIMPL_OPT=CAMB;;
    camb_no_runtime)
        cmake_build -DIMPL_OPT=CAMB -DRUNTIME=OFF;;
    hip_pytorch)
        cmake_build -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DIMPL_OPT=TORCH -DHIP=ON;;
    mmcv_ext)
        (cd third_party/mmcv_diopi && rm -rf build && mkdir build \
            && MMCV_WITH_DIOPI=1 MMCV_WITH_OPS=1 python setup.py build_ext -i) \
        || exit -1;;
    *)
        # An unknown backend is an error: report it and fail instead of
        # falling through to the final `exit 0`.
        echo -e "[ERROR] Incorrect compilation option:" "$1" >&2
        exit 1;;
esac
exit 0
#!/bin/bash
# Run the mmcv op tests that are enabled for one device.
#
# Usage: test_mmcv_ext.sh DEVICE      (DEVICE is CUDA or MLU)
# Must be started from the directory that contains third_party/mmcv_diopi.
# Requires bash: the script uses [[ ]] and arrays.
set -e

if [ $# -ne 1 ]
then
    echo "Usage: test_mmcv_ext.sh DEVICE"
    exit 1
fi

DEVICE=${1}

if [[ $DEVICE == "CUDA" ]]; then
    MMCV_TEST_LIST=(
        test_active_rotated_filter.py
        test_assign_score_withk.py
        test_bbox.py
        test_deform_roi_pool.py
        test_knn.py
        test_convex_iou.py
        test_min_area_polygons.py
        test_prroi_pool.py
        test_chamfer_distance.py
        test_border_align.py
    )
elif [[ $DEVICE == "MLU" ]]; then
    # No mmcv tests are enabled for MLU yet.
    MMCV_TEST_LIST=()
else
    echo "DEVICE $DEVICE not supported!"
    exit 1
fi

cd third_party/mmcv_diopi
# Prepend the mmcv checkout; add the separator only when PYTHONPATH already
# has content, so no empty path entry is created.
export PYTHONPATH="${PWD}${PYTHONPATH:+:$PYTHONPATH}"
cd tests/test_ops

# set -e stops the run at the first failing test file.
for elem in "${MMCV_TEST_LIST[@]}"
do
    python -m pytest "$elem"
done