
Commit

add tests in diopi_tests
Wrench-Git committed Jul 31, 2024
1 parent 57266f3 commit 981ea99
Showing 3 changed files with 66 additions and 1 deletion.
21 changes: 21 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
@@ -4058,6 +4058,27 @@
# ],
# ),
# ),

    'foreach_op': dict(
        name=["_foreach_mul", "_foreach_add"],
        interface=["torch"],
        para=dict(
            scalar=[1.0, 5, 2.0, -1.2, 3, 10, 8, -0.5, 0, -2],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ["self"],
                    "shape": ((), (10,), (10, 2, 5), (20,), (10, 5, 1), (20, 3, 4, 5), (20, 2, 3, 4, 5),
                              (0,), (0, 10), (5, 0, 9)),
                    "gen_fn": 'Genfunc.randn',
                    "dtype": [np.float32, np.float16, np.float64],
                    "gen_policy": 'gen_tensor_list',
                    "gen_num_range": [1, 5]
                },
            ],
        ),
    ),

'tril': dict(
name=["tril"],
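The new 'foreach_op' entry drives _foreach_add and _foreach_mul through the stock torch interface, pairing each scalar with a generated list of 1 to 5 random tensors (gen_policy 'gen_tensor_list', gen_num_range [1, 5]). A minimal sketch of the reference semantics these cases check, assuming PyTorch is available; the tensor list below is illustrative, with shapes picked from the config:

import torch

# Hypothetical tensor list; the config generates between 1 and 5 tensors per case.
tensors = [torch.randn(10), torch.randn(10, 2, 5), torch.randn(0, 10)]

added = torch._foreach_add(tensors, 5)      # elementwise: [t + 5 for t in tensors]
scaled = torch._foreach_mul(tensors, -1.2)  # elementwise: [t * -1.2 for t in tensors]
assert all(torch.allclose(a, t + 5) for a, t in zip(added, tensors))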
3 changes: 2 additions & 1 deletion diopi_test/python/configs/model_config/generate_config.py
@@ -49,7 +49,8 @@
'bitwise_or', 'sigmoid', 'erf', 'matmul', 'addcmul', 'std',
'arange', 'log2', 'sign', 'eq', 'nonzero', 'triangular_solve',
'ne', 'mul', 'linspace', 'index_fill', 'atan', 'le', 'sgn',
'logical_and', 'permute', 'div', 'log10', 'roll', 'ge', 'lt', 'any'],
'logical_and', 'permute', 'div', 'log10', 'roll', 'ge', 'lt', 'any',
'_foreach_add', '_foreach_mul'],
'torch.nn.functional': ['conv2d', 'batch_norm'],
'torch.Tensor': ['fill_', 'repeat', 'unfold', 'copy_', 'expand'],
'CustomizedTest': ['linalgqr', 'adadelta', 'cast_np', 'batch_norm_elemt',
43 changes: 43 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
@@ -1643,6 +1643,49 @@ def clip_grad_norm_(tensors, max_norm, norm_type=2.0, error_if_nonfinite=False):

return out.value

def _foreach_add(self, scalar):
    # All tensors in the list share one device context.
    ctx = self[0].context()
    num_tensors = len(self)
    func = check_function("diopiForeachaddScalar")
    input_tensors = [TensorP(tensor) for tensor in self]
    # Allocate one output tensor per input with matching shape and dtype.
    out_tensorV = [Tensor(self[i].size(), self[i].get_dtype()) for i in range(num_tensors)]
    out_tensors = [TensorP(out_tensor) for out_tensor in out_tensorV]
    # The scalar operand may be passed as a Tensor; otherwise box it as a Scalar.
    other = scalar if isinstance(scalar, Tensor) else Scalar(scalar)
    ret = func(
        ctx,
        out_tensors,
        input_tensors,
        num_tensors,
        other
    )
    check_returncode(ret)

    return out_tensorV

def _foreach_mul(self, scalar):
    # All tensors in the list share one device context.
    ctx = self[0].context()
    num_tensors = len(self)
    func = check_function("diopiForeachmulScalar")
    input_tensors = [TensorP(tensor) for tensor in self]
    # Allocate one output tensor per input with matching shape and dtype.
    out_tensorV = [Tensor(self[i].size(), self[i].get_dtype()) for i in range(num_tensors)]
    out_tensors = [TensorP(out_tensor) for out_tensor in out_tensorV]
    # The scalar operand may be passed as a Tensor; otherwise box it as a Scalar.
    other = scalar if isinstance(scalar, Tensor) else Scalar(scalar)
    ret = func(
        ctx,
        out_tensors,
        input_tensors,
        num_tensors,
        other
    )
    check_returncode(ret)

    return out_tensorV

def batch_norm(
    input,
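Both wrappers follow the same pattern: wrap each input in a TensorP handle, allocate one output tensor per input with matching shape and dtype, box the scalar operand as a Scalar unless a Tensor is supplied, and dispatch to the corresponding diopiForeach*Scalar kernel before checking the return code. A minimal sketch of the per-list expectation the conformance comparison boils down to, using numpy arrays as stand-ins for device tensors (the helper names below are hypothetical, not part of the test framework):

import numpy as np

def expected_foreach_add(arrays, scalar):
    # Reference result: the scalar add is applied to each array independently.
    return [a + scalar for a in arrays]

def expected_foreach_mul(arrays, scalar):
    # Reference result: the scalar multiply is applied to each array independently.
    return [a * scalar for a in arrays]

inputs = [np.random.randn(10).astype(np.float32),
          np.random.randn(20, 3, 4, 5).astype(np.float32)]
ref = expected_foreach_add(inputs, -0.5)
assert all(np.allclose(r, a - 0.5) for r, a in zip(ref, inputs))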
