Skip to content

Commit

Permalink
add foreach norm in diopi tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrench-Git committed Jul 31, 2024
1 parent 880ddf4 commit 9438dad
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
23 changes: 22 additions & 1 deletion diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4059,7 +4059,7 @@
# ),
# ),

'foreach_op': dict(
'pointwise_binary_foreach_op': dict(
name=["_foreach_mul","_foreach_add"],
interface=["torch"],
para=dict(
Expand All @@ -4079,6 +4079,27 @@
],
),
),

'foreach_norm': dict(
name=['_foreach_norm'],
interface=['torch'],
para=dict(
p=[0, 2.5, float('inf'), -float('inf'), 2, -2, 1, 2, 0],
),
tensor_para=dict(
args=[
{
"ins": ["self"],
"shape": ((), (128, ), (384, 128), (256, 512, 1, 1), (384, 128), (384, 128),
(0,), (0, 12), (13, 0, 4)),
"dtype": [np.float32, np.float64, np.float16],
"gen_fn": 'Genfunc.randn',
"gen_policy": 'gen_tensor_list',
"gen_num_range": [1, 5]
},
],
),
),

'tril': dict(
name=["tril"],
Expand Down
2 changes: 1 addition & 1 deletion diopi_test/python/configs/model_config/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
'arange', 'log2', 'sign', 'eq', 'nonzero', 'triangular_solve',
'ne', 'mul', 'linspace', 'index_fill', 'atan', 'le', 'sgn',
'logical_and', 'permute', 'div', 'log10', 'roll', 'ge', 'lt', 'any',
'_foreach_add','_foreach_mul'],
'_foreach_add', '_foreach_mul', '_foreach_norm'],
'torch.nn.functional': ['conv2d', 'batch_norm'],
'torch.Tensor': ['fill_', 'repeat', 'unfold', 'copy_', 'expand'],
'CustomizedTest': ['linalgqr', 'adadelta', 'cast_np', 'batch_norm_elemt',
Expand Down
22 changes: 22 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,28 @@ def _foreach_mul(self, scalar):

return out_tensorV

def _foreach_norm(self, scalar):
ctx = self[0].context()
num_tensors = len(self)
func = check_function("diopiForeachnormScalar")
input_tensors = list([TensorP(input) for input in self])
out_tensorV = list([Tensor(self[i].size(),self[i].get_dtype()) for i in range(num_tensors)])
out_tensors = list([TensorP(out_tensor) for out_tensor in out_tensorV])
if isinstance(scalar, Tensor):
other = scalar
else:
other = Scalar(scalar)
ret = func(
ctx,
out_tensors,
input_tensors,
num_tensors,
other
)
check_returncode(ret)

return out_tensorV

def batch_norm(
input,
running_mean,
Expand Down
1 change: 0 additions & 1 deletion impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3318,7 +3318,6 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
auto atP = impl::aten::buildAtScalar(p);
auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_norm, atInputs, atP);
for (int i = 0; i < inputSize; i++) {
//WARN NO NEED TO COPY HERE, WE NEED FASTER UPDATE HERE
impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
}

Expand Down

0 comments on commit 9438dad

Please sign in to comment.