Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add foreach_mul foreach_norm, foreach_unscale in cuda #1318

Merged
merged 7 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4058,6 +4058,44 @@
# ],
# ),
# ),

# Config for the out-of-place foreach scalar ops (_foreach_mul / _foreach_add):
# each scalar in `para` is paired with a generated list of 1-5 tensors
# (gen_policy='gen_tensor_list') of the corresponding shape/dtype, including
# scalar (rank-0) and zero-sized shapes.
'pointwise_binary_foreach_op': dict(
    name=["_foreach_mul","_foreach_add"],
    interface=["torch"],
    para=dict(
        scalar=[1.0, 5, 2.0, -1.2, 3, 10, 8, -0.5, 0, -2],
    ),
    tensor_para=dict(
        args=[
            {
                "ins": ["self"],
                "shape": ((), (10,), (10, 2, 5), (20,), (10, 5, 1), (20, 3, 4, 5), (20, 2, 3, 4, 5),
                 (0,), (0, 10), (5, 0, 9)),
                "gen_fn": 'Genfunc.randn',
                "dtype": [np.float32, np.float16, np.float64],
                "gen_policy": 'gen_tensor_list',
                # each case draws a tensor list whose length is in [1, 5]
                "gen_num_range": [1, 5]
            },
        ],
    ),
),

# Config for _foreach_norm: computes a per-tensor norm over a generated list
# of 1-5 random floating-point tensors (gen_policy='gen_tensor_list').
'foreach_norm': dict(
    name=['_foreach_norm'],
    interface=['torch'],
    tensor_para=dict(
        args=[
            {
                "ins": ["self"],
                "shape": ((256, 512, 1, 1),(8, 1, 4),(256, 64, 1, 1),(10, 1, 4),(256, 128, 1, 1),(16, 1, 4),(256, 256, 1, 1),(3, 1, 4)),
                "dtype": [np.float32, np.float64, np.float16],
                "gen_fn": 'Genfunc.randn',
                "gen_policy": 'gen_tensor_list',
                # each case draws a tensor list whose length is in [1, 5]
                "gen_num_range": [1, 5]
            },
        ],
    ),
),

'tril': dict(
name=["tril"],
Expand Down
3 changes: 2 additions & 1 deletion diopi_test/python/configs/model_config/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
'bitwise_or', 'sigmoid', 'erf', 'matmul', 'addcmul', 'std',
'arange', 'log2', 'sign', 'eq', 'nonzero', 'triangular_solve',
'ne', 'mul', 'linspace', 'index_fill', 'atan', 'le', 'sgn',
'logical_and', 'permute', 'div', 'log10', 'roll', 'ge', 'lt', 'any'],
'logical_and', 'permute', 'div', 'log10', 'roll', 'ge', 'lt', 'any',
'_foreach_add', '_foreach_mul', '_foreach_norm'],
'torch.nn.functional': ['conv2d', 'batch_norm'],
'torch.Tensor': ['fill_', 'repeat', 'unfold', 'copy_', 'expand'],
'CustomizedTest': ['linalgqr', 'adadelta', 'cast_np', 'batch_norm_elemt',
Expand Down
62 changes: 62 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,68 @@ def clip_grad_norm_(tensors, max_norm, norm_type=2.0, error_if_nonfinite=False):

return out.value

def _foreach_add(self, scalar):
    """Out-of-place foreach add: returns [t + scalar for t in self].

    Args:
        self: non-empty list of input Tensors.
        scalar: a Python number (boxed into a diopi Scalar) or a Tensor.

    Returns:
        A list of newly allocated Tensors holding the per-tensor results.
    """
    ctx = self[0].context()
    num_tensors = len(self)
    func = check_function("diopiForeachaddScalar")
    # list([...]) was redundant — a comprehension already builds a list.
    input_tensors = [TensorP(tensor) for tensor in self]
    out_tensorV = [Tensor(tensor.size(), tensor.get_dtype()) for tensor in self]
    out_tensors = [TensorP(out_tensor) for out_tensor in out_tensorV]
    # Tensor operands pass through unchanged; plain numbers are wrapped as Scalar.
    other = scalar if isinstance(scalar, Tensor) else Scalar(scalar)
    ret = func(ctx, out_tensors, input_tensors, num_tensors, other)
    check_returncode(ret)

    return out_tensorV

def _foreach_mul(self, scalar):
    """Out-of-place foreach multiply: returns [t * scalar for t in self].

    Args:
        self: non-empty list of input Tensors.
        scalar: a Python number (boxed into a diopi Scalar) or a Tensor.

    Returns:
        A list of newly allocated Tensors holding the per-tensor results.
    """
    ctx = self[0].context()
    num_tensors = len(self)
    func = check_function("diopiForeachmulScalar")
    # list([...]) was redundant — a comprehension already builds a list.
    input_tensors = [TensorP(tensor) for tensor in self]
    out_tensorV = [Tensor(tensor.size(), tensor.get_dtype()) for tensor in self]
    out_tensors = [TensorP(out_tensor) for out_tensor in out_tensorV]
    # Tensor operands pass through unchanged; plain numbers are wrapped as Scalar.
    other = scalar if isinstance(scalar, Tensor) else Scalar(scalar)
    ret = func(ctx, out_tensors, input_tensors, num_tensors, other)
    check_returncode(ret)

    return out_tensorV

def _foreach_norm(self, ord=2):
    """Per-tensor norm over a tensor list, mirroring torch._foreach_norm.

    Args:
        self: non-empty list of input Tensors.
        ord: order of the norm (default 2, matching the previous hard-coded
             behavior, so existing callers are unaffected).

    Returns:
        A list of 0-dim Tensors, one norm per input tensor.
    """
    ctx = self[0].context()
    num_tensors = len(self)
    func = check_function("diopiForeachnormScalar")
    input_tensors = [TensorP(tensor) for tensor in self]
    # Each norm is a scalar, so outputs are 0-dim tensors of the input dtype.
    out_tensorV = [Tensor([], tensor.get_dtype()) for tensor in self]
    out_tensors = [TensorP(out_tensor) for out_tensor in out_tensorV]
    other = Scalar(ord)
    ret = func(ctx, out_tensors, input_tensors, num_tensors, other)
    check_returncode(ret)

    return out_tensorV

def batch_norm(
input,
Expand Down
94 changes: 94 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,28 @@ diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp
return diopiSuccess;
}

// Out-of-place foreach add: outs[i] = inputs[i] + other for each tensor in the list.
diopiError_t diopiForeachaddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                   const diopiScalar_t* other) {
    // Validate the output list up front, consistent with diopiForeachmulScalar.
    DIOPI_CHECK_PTR(outs);
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
    auto atOther = impl::aten::buildAtScalar(other);
    auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_add, atInputs, atOther);
    // Copy each ATen result back into the caller-provided DIOPI output tensors.
    // int64_t index matches the type of inputSize (was int).
    for (int64_t i = 0; i < inputSize; i++) {
        impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
    }

    return diopiSuccess;
}

// In-place foreach add: each tensor in `inputs` is incremented by the scalar `other`.
diopiError_t diopiForeachaddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other) {
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atTensors, inputs, inputSize)
    const auto atAddend = impl::aten::buildAtScalar(other);
    // The in-place variant writes results into the input tensors themselves,
    // so no copy-back step is needed (unlike the out-of-place overload).
    CALL_ATEN_CUDA_FUNC(_foreach_add_, atTensors, atAddend);
    return diopiSuccess;
}

diopiError_t diopiSub(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other,
const diopiScalar_t* alpha) {
impl::aten::setCurStream(ctx);
Expand Down Expand Up @@ -1247,6 +1269,64 @@ diopiError_t diopiMulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t inp
return diopiSuccess;
}

// In-place foreach multiply: each tensor in `inputs` is scaled by the scalar `other`.
diopiError_t diopiForeachmulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other) {
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atTensors, inputs, inputSize)
    const auto atFactor = impl::aten::buildAtScalar(other);
    // The in-place variant writes results into the input tensors themselves,
    // so no copy-back step is needed (unlike the out-of-place overload).
    CALL_ATEN_CUDA_FUNC(_foreach_mul_, atTensors, atFactor);
    return diopiSuccess;
}

// Out-of-place foreach multiply: outs[i] = inputs[i] * other for each tensor in the list.
diopiError_t diopiForeachmulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                   const diopiScalar_t* other) {
    DIOPI_CHECK_PTR(outs);
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
    auto atOther = impl::aten::buildAtScalar(other);
    auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther);
    // Copy each ATen result back into the caller-provided DIOPI output tensors.
    // int64_t index matches the type of inputSize (was int).
    for (int64_t i = 0; i < inputSize; i++) {
        impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
    }

    return diopiSuccess;
}

// In-place foreach multiply by a tensor operand: inputs[i] *= other for each tensor.
diopiError_t diopiForeachmulInpTensor(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiConstTensorHandle_t other) {
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atTensors, inputs, inputSize)
    const auto atFactor = impl::aten::buildATen(other);
    // The in-place variant writes results into the input tensors themselves,
    // so no copy-back step is needed (unlike the out-of-place overload).
    CALL_ATEN_CUDA_FUNC(_foreach_mul_, atTensors, atFactor);
    return diopiSuccess;
}

// Out-of-place foreach multiply by a tensor operand: outs[i] = inputs[i] * other.
diopiError_t diopiForeachmulTensor(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                   const diopiConstTensorHandle_t other) {
    DIOPI_CHECK_PTR(outs);
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
    // Fixed: the original also built an `atOuts` ATen list from `outs` that was
    // never used — results are copied back via updateATen2Tensor below instead.
    auto atOther = impl::aten::buildATen(other);
    auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_mul, atInputs, atOther);
    // int64_t index matches the type of inputSize (was int).
    for (int64_t i = 0; i < inputSize; i++) {
        impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
    }

    return diopiSuccess;
}

// In-place AMP gradient unscale. Delegates to ATen's
// _amp_foreach_non_finite_check_and_unscale_, which (per its documented
// contract) multiplies each grad by inv_scale in place and flags non-finite
// values through found_inf.
diopiError_t diopiAmpForeachNonFiniteCheckAndUnscaleInp(diopiContextHandle_t ctx, diopiTensorHandle_t* scaled_grads, int64_t num_scaled_grads,
                                                        diopiTensorHandle_t found_inf, diopiConstTensorHandle_t inv_scale) {
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atGrads, scaled_grads, num_scaled_grads)
    auto atFlag = impl::aten::buildATen(found_inf);
    auto atScaleInv = impl::aten::buildATen(inv_scale);
    CALL_ATEN_CUDA_FUNC(_amp_foreach_non_finite_check_and_unscale_, atGrads, atFlag, atScaleInv);
    return diopiSuccess;
}

diopiError_t diopiGe(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
Expand Down Expand Up @@ -3230,6 +3310,20 @@ diopiError_t diopiNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
return diopiSuccess;
}

// Per-tensor norm over a tensor list: outs[i] = norm(inputs[i], ord).
diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                    const diopiScalar_t* ord) {
    DIOPI_CHECK_PTR(outs);
    impl::aten::setCurStream(ctx);
    DIOPI_IMPL_BUILD_ATEN_LIST(atInputs, inputs, inputSize)
    auto atOrd = impl::aten::buildAtScalar(ord);
    auto tempOut = CALL_ATEN_CUDA_FUNC(_foreach_norm, atInputs, atOrd);
    // Copy each scalar norm back into the caller-provided DIOPI output tensors.
    // int64_t index matches the type of inputSize (was int).
    for (int64_t i = 0; i < inputSize; i++) {
        impl::aten::updateATen2Tensor(ctx, tempOut[i], outs[i]);
    }

    return diopiSuccess;
}

diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
impl::aten::setCurStream(ctx);
Expand Down
70 changes: 70 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,25 @@ DIOPI_API diopiError_t diopiAddScalar(diopiContextHandle_t ctx, diopiTensorHandl
*/
DIOPI_API diopiError_t diopiAddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other, const diopiScalar_t* alpha);

/**
 * @brief The diopiForeachaddScalar: adds a scalar to every tensor in a tensor list.
 * @param[out] outs the output tensor list that will store the result tensors. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] ctx Context environment.
 * @param[in] inputs the input tensor list. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The scalar value to be added. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachaddScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                             const diopiScalar_t* other);

/**
 * @brief The in-place version of diopiForeachaddScalar.
 * @param[in] ctx Context environment.
 * @param[in] inputs the input tensor list; results are stored back into these tensors. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The scalar value to be added. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachaddInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other);
/**
* @brief Perform subtraction operations between tensors.
* @param[in] ctx Context environment.
Expand Down Expand Up @@ -1180,6 +1199,46 @@ DIOPI_API diopiError_t diopiMulScalar(diopiContextHandle_t ctx, diopiTensorHandl
*/
DIOPI_API diopiError_t diopiMulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* other);

/**
 * @brief The diopiForeachmulScalar: multiplies every tensor in a tensor list by a scalar.
 * @param[in] ctx Context environment.
 * @param[out] outs the output tensor list. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputs the input tensor list. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The scalar value to be multiplied. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachmulScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                             const diopiScalar_t* other);

/**
 * @brief The in-place version of diopiForeachmulScalar.
 * @param[in] ctx Context environment.
 * @param[in] inputs the input tensor list; results are stored back into these tensors. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The scalar value to be multiplied. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachmulInpScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiScalar_t* other);

/**
 * @brief The diopiForeachmulTensor: multiplies every tensor in a tensor list by a tensor operand.
 * @param[in] ctx Context environment.
 * @param[out] outs the output tensor list that will store the result tensors. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputs the input tensor list. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The tensor to be multiplied. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachmulTensor(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                             const diopiConstTensorHandle_t other);

/**
 * @brief The in-place version of diopiForeachmulTensor.
 * @param[in] ctx Context environment.
 * @param[in] inputs the input tensor list; results are stored back into these tensors. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] other The tensor to be multiplied. type = [float64, float32, float16, int64, int32, int16, int8, uint8].
 */
DIOPI_API diopiError_t diopiForeachmulInpTensor(diopiContextHandle_t ctx, diopiTensorHandle_t* inputs, int64_t inputSize, const diopiConstTensorHandle_t other);

/**
* @brief Divides each element of input tensor by the corresponding element in other tensor.
* @param[in] ctx Context environment.
Expand All @@ -1189,6 +1248,7 @@ DIOPI_API diopiError_t diopiMulInpScalar(diopiContextHandle_t ctx, diopiTensorHa
* the inputs are promoted to the default scalar type; trunc: truncate towards zero; floor: round down towards negative infinity for the result of the division.
* @param[out] out the output tensor. type = [float64, float32, float16, int64, int32, int16, int8, uint8, bool].
*/

DIOPI_API diopiError_t diopiDiv(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t other,
diopiRoundMode_t rounding_mode);

Expand Down Expand Up @@ -2888,6 +2948,16 @@ DIOPI_API diopiError_t diopiFlip(diopiContextHandle_t ctx, diopiTensorHandle_t o
*/
DIOPI_API diopiError_t diopiNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* p, diopiSize_t dim);

/**
 * @brief Returns the norm of each tensor in a given tensor list.
 * @param[in] ctx Context environment.
 * @param[out] outs the output tensor list, type=[float32, float64, float16].
 * @param[in] inputs the input tensor list, type=[float32, float64, float16].
 * @param[in] inputSize the length of the input tensor list. type = [int64].
 * @param[in] ord a scalar, the order of the norm.
 */
DIOPI_API diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputSize,
                                              const diopiScalar_t* ord);
/**
* \brief Applies Group Normalization over a mini-batch of inputs.
* @param[in] ctx Context environment.
Expand Down
Loading