GatherElementsGrad CPU Kernel and TopKGrad CPU/CUDA Kernel (microsoft#5511)

* TopKGrad CPU kernel

* Use Scatter for GatherElementsGrad and TopKGrad.

* Roll back ConvGrad change.

Co-authored-by: Vincent Wang <[email protected]>
centwang authored Oct 21, 2020
1 parent 6c2162e commit b48f596
Showing 11 changed files with 320 additions and 59 deletions.
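
The thread running through the files below: GatherElements' data gradient is a scatter that must accumulate. For a 2-D GatherElements along axis 0 the forward op computes out[i][j] = data[indices[i][j]][j], so the backward pass routes dOut back through the same indices, and where an index repeats the contributions add. A minimal sketch of that math (illustrative helper, not the ORT kernel itself):

#include <cstddef>
#include <cstdint>
#include <vector>

// Data gradient of a 2-D GatherElements along axis 0 (hand-rolled sketch).
// Duplicate indices accumulate, which is why this commit funnels
// GatherElementsGrad through the Scatter loop with an "add" functor.
void GatherElementsGradAxis0(const std::vector<std::vector<int64_t>>& indices,
                             const std::vector<std::vector<float>>& d_out,
                             std::vector<std::vector<float>>& d_data) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    for (std::size_t j = 0; j < indices[i].size(); ++j) {
      d_data[indices[i][j]][j] += d_out[i][j];  // accumulate, don't overwrite
    }
  }
}
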
55 changes: 50 additions & 5 deletions onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -5,6 +5,9 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/common.h"
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h"
#endif

namespace onnxruntime {

@@ -50,8 +53,15 @@ ONNX_CPU_OPERATOR_KERNEL(
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);

-template <class Tin, class Tdata>
-Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
template <class T>
struct Func_Assignment {
void operator()(T* a, const T* b) const {
*a = *b;
}
};

template <class Tin, class Tdata, typename FuncT>
Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
const TensorShape& input_data_shape = data_input->Shape();
const Tin* indices_data_raw = indices_input->template Data<Tin>();
@@ -157,7 +167,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
}
}

-dst_base[dst_offset] = update_data[index];
func(dst_base + dst_offset, update_data + index);

if (++index == num_indices) {
break;
@@ -181,12 +191,12 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co

template <class T, class... Args>
inline Status CopyInt32Index(Args&&... args) {
-return CopyScatterData<int32_t, T>(std::forward<Args>(args)...);
return CopyScatterData<int32_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
}

template <class T, class... Args>
inline Status CopyInt64Index(Args&&... args) {
-return CopyScatterData<int64_t, T>(std::forward<Args>(args)...);
return CopyScatterData<int64_t, T>(Func_Assignment<T>(), std::forward<Args>(args)...);
}

Status Scatter::Compute(OpKernelContext* context) const {
@@ -245,4 +255,39 @@ Status Scatter::Compute(OpKernelContext* context) const {
return status;
}

#ifdef ENABLE_TRAINING

namespace contrib {

template <class T>
struct Func_Add {
void operator()(T* a, const T* b) const {
*a = *a + *b;
}
};

template <class Tin, class Tdata>
Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
return CopyScatterData<Tin, Tdata>(Func_Add<Tdata>(), data_output, indices_input, updates_input, axis, data_output);
}

#define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \
template Status GatherElementsGradImpl<Tin, Tdata>( \
const Tensor* indices_input, \
const Tensor* updates_input, \
const int64_t axis, \
Tensor* data_output)

#define GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(Tdata) \
GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(int32_t, Tdata); \
GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(int64_t, Tdata);

GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(float)
GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(double)

} // namespace contrib

#endif

} // namespace onnxruntime
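
The functor parameter is the whole trick: Scatter instantiates CopyScatterData with Func_Assignment, while GatherElementsGradImpl instantiates it with Func_Add, so duplicate destination offsets accumulate instead of overwriting. A simplified driver (assuming pre-flattened destination offsets, not the real indexing loop) showing the difference:

#include <cstddef>
#include <cstdint>

// Apply each update at its destination via the functor: assignment for
// Scatter, addition for GatherElementsGrad.
template <class T, class FuncT>
void ScatterInto(T* dst, const int64_t* offsets, const T* updates,
                 std::size_t n, const FuncT& func) {
  for (std::size_t k = 0; k < n; ++k) {
    func(dst + offsets[k], updates + k);
  }
}

// With dst = {0, 0, 0}, offsets = {1, 1, 2}, updates = {1, 1, 1}:
//   Func_Assignment<float> yields dst = {0, 1, 1}  (last write wins)
//   Func_Add<float>        yields dst = {0, 2, 1}  (duplicates accumulate)
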
@@ -111,12 +111,13 @@ static int CompactInputIndicesDims(
eff_indices_dims.push_back(indices_dims[axis]);
int new_axis = (int)(eff_input_dims.size());
if (axis > 0) {
-if (could_continue_merge) {
if (!could_continue_merge) {
eff_input_dims.push_back(1);
eff_indices_dims.push_back(1);
could_continue_merge = true;
}
int i = axis - 1;
-for (; i >= 0 && could_continue_merge; --i) {
for (; i >= 0; --i) {
if (input_dims[i] == indices_dims[i]) {
eff_input_dims.back() *= input_dims[i];
eff_indices_dims.back() *= indices_dims[i];
Expand All @@ -130,7 +131,7 @@ static int CompactInputIndicesDims(
eff_indices_dims.pop_back();
}
if (!could_continue_merge) {
-for (; i >= 0 && could_continue_merge; --i) {
for (; i >= 0; --i) {
eff_input_dims.push_back(input_dims[i]);
eff_indices_dims.push_back(indices_dims[i]);
}
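
The fix above changes CompactInputIndicesDims so that once a dimension mismatch ends the merge, the remaining leading dims are still pushed individually; the old loop conditions bailed out of the whole loop instead. A simplified illustration of the merge rule (an assumption-level sketch, not the ORT helper, which also treats the axis dim specially):

#include <cstddef>
#include <cstdint>
#include <vector>

// Fold adjacent dims where the input and indices extents agree, so the
// scatter/gather runs over fewer, larger effective dimensions.
std::vector<int64_t> FoldEqualDims(const std::vector<int64_t>& input_dims,
                                   const std::vector<int64_t>& indices_dims) {
  std::vector<int64_t> folded{1};
  bool merging = true;
  for (std::size_t i = 0; i < input_dims.size(); ++i) {
    if (merging && input_dims[i] == indices_dims[i]) {
      folded.back() *= input_dims[i];   // extents agree: keep folding
    } else {
      folded.push_back(input_dims[i]);  // mismatch: stop folding, but keep
      merging = false;                  // pushing every remaining dim
    }
  }
  return folded;
}

// e.g. input {2, 3, 4} vs. indices {2, 3, 2}: the leading {2, 3} folds to 6,
// leaving an effective input shape of {6, 4}.
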
20 changes: 20 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -80,6 +80,26 @@ TEST(Scatter, ThreeDimsWithAxis_0) {
scatter_three_dim_with_axis_0("ScatterElements", 11);
}

static void scatter_three_dim_with_axis_negative_2(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", -2);

test.AddInput<int64_t>("data", {2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8});
test.AddInput<int64_t>("indices", {2, 1, 2},
{0, 1, 1, 0});
test.AddInput<int64_t>("updates", {2, 1, 2},
{11, 12, 13, 14});
test.AddOutput<int64_t>("y", {2, 2, 2},
{11, 2, 3, 12, 5, 14, 13, 8});
test.Run();
}

TEST(Scatter, ThreeDimsWithAxisNegative_2) {
scatter_three_dim_with_axis_negative_2("Scatter", 9);
scatter_three_dim_with_axis_negative_2("ScatterElements", 11);
}
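
The new test exercises a negative axis. Negative axes count from the back, and the kernels normalize them before scattering (ORT's common headers do this via HandleNegativeAxis; the one-line equivalent below is shown as an assumption for illustration):

#include <cstdint>

inline int64_t NormalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis;  // e.g. NormalizeAxis(-2, 3) == 1
}

// For the rank-3 {2, 2, 2} tensors above, axis -2 normalizes to 1, so the
// test scatters along the middle dimension.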

static void scatter_three_dim_with_axis_2(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 2);
14 changes: 14 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1470,5 +1470,19 @@ IMPLEMENT_GRADIENT_BUILDER(GetFlattenGradient) {
};
}

IMPLEMENT_GRADIENT_BUILDER(GetTopKGradient) {
// TopK's default axis is -1, which is different from GatherElements.
auto attributes = SrcNodeAttributes();
auto axis = utils::HasInt(attributes.at("axis")) ? attributes.at("axis").i() : -1;
return std::vector<NodeDef>{
NodeDef("Shape",
{I(0)},
{IA("x_shape")}),
NodeDef(OpDef{"GatherElementsGrad", kMSDomain, 1},
{GO(0), IA("x_shape"), O(1)},
{GI(0)},
{MakeAttribute("axis", axis)})};
}

} // namespace training
} // namespace onnxruntime
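
GetTopKGradient wires TopK's second output (the selected indices) into GatherElementsGrad: the incoming dY is scattered into a zero tensor of x's shape at exactly the positions TopK selected. A hand-worked 1-D sketch (simplified; not the registered kernel):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> TopKGrad1D(const std::vector<float>& d_values,
                              const std::vector<int64_t>& topk_indices,
                              std::size_t input_size) {
  std::vector<float> d_x(input_size, 0.0f);  // zeros of x's shape
  for (std::size_t k = 0; k < topk_indices.size(); ++k) {
    d_x[topk_indices[k]] += d_values[k];  // route each dY entry home
  }
  return d_x;
}

// x = {1, 4, 3, 2}, k = 2  ->  values {4, 3}, indices {1, 2}.
// d_values = {10, 20}      ->  d_x = {0, 10, 20, 0}; elements that never
// made the top k get zero gradient.
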
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -67,6 +67,7 @@ DECLARE_GRADIENT_BUILDER(GetRecvGradient)
DECLARE_GRADIENT_BUILDER(GetExpandGradient)
DECLARE_GRADIENT_BUILDER(GetExpGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenGradient)
DECLARE_GRADIENT_BUILDER(GetTopKGradient)

} // namespace training
} // namespace onnxruntime
@@ -98,6 +98,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient);
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient);
REGISTER_GRADIENT_BUILDER("TopK", GetTopKGradient);
};

} // namespace training
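
Registering the builder is what makes the new gradient reachable: when the training graph builder walks the forward graph and meets a TopK node, it looks the op name up and invokes GetTopKGradient to emit the subgraph above. A hypothetical miniature of that dispatch (stub types; not ORT's actual registry API):

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct NodeDefStub {};  // stand-in for orttraining's NodeDef
using GradientBuilderFn = std::function<std::vector<NodeDefStub>()>;

std::unordered_map<std::string, GradientBuilderFn>& GradientRegistry() {
  static std::unordered_map<std::string, GradientBuilderFn> registry;
  return registry;  // op name -> builder, consulted per forward node
}

// After this commit: GradientRegistry()["TopK"] produces the
// Shape + GatherElementsGrad subgraph emitted by GetTopKGradient.
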
149 changes: 98 additions & 51 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -1707,57 +1707,6 @@ TEST(GradientCheckerTest, GatherNDGrad_unique_float_data) {
}
}

-TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) {
-float max_error;
-GradientChecker<float, float, float> gradient_checker;
-OpDef op_def{"GatherElements", kOnnxDomain, 11};
-
-TensorInfo data_info({3, 3}, true);
-TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
-std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}};
-
-TensorInfo y_info({2, 3}, true);
-int64_t axis = 0;
-
-gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
-{MakeAttribute("axis", axis)});
-EXPECT_IS_TINY(max_error);
-}
-
-TEST(GradientCheckerTest, GatherElementsGradWithoutDuplicateUpdate) {
-float max_error;
-GradientChecker<float, float, float> gradient_checker;
-OpDef op_def{"GatherElements", kOnnxDomain, 11};
-
-TensorInfo data_info({3, 3}, true);
-TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
-std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}};
-
-TensorInfo y_info({2, 3}, true);
-int64_t axis = 0;
-
-gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
-{MakeAttribute("axis", axis)});
-EXPECT_IS_TINY(max_error);
-}
-
-TEST(GradientCheckerTest, GatherElementsGradAxisWithDuplicateUpdate) {
-float max_error;
-GradientChecker<float, float, float> gradient_checker;
-OpDef op_def{"GatherElements", kOnnxDomain, 11};
-
-TensorInfo data_info({3, 3}, true);
-TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
-std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}};
-
-TensorInfo y_info({2, 3}, true);
-int64_t axis = 1;
-
-gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
-{MakeAttribute("axis", axis)});
-EXPECT_IS_TINY(max_error);
-}
-
TEST(GradientCheckerTest, LayerNormGrad) {
GradientChecker<float, float, float> gradient_checker;
{
@@ -2008,6 +1957,104 @@ TEST(GradientCheckerTest, ExpandGrad) {
}
}

TEST(GradientCheckerTest, GatherElementsGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GatherElements", kOnnxDomain, 11};

{
// GatherElementsGradWithDuplicateUpdate
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}};

TensorInfo y_info({2, 3}, true);
int64_t axis = 0;

gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}

{
// GatherElementsGradWithoutDuplicateUpdate
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}};

TensorInfo y_info({2, 3}, true);
int64_t axis = 0;

gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}

{
// GatherElementsGradAxisWithDuplicateUpdate
TensorInfo data_info({3, 3}, true);
TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}};

TensorInfo y_info({2, 3}, true);
int64_t axis = 1;

gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}

{
// GatherElementsGradWithAxisInMiddle
TensorInfo data_info({2, 2, 2}, true);
TensorInfo indice_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}};

TensorInfo y_info({2, 1, 2}, true);
int64_t axis = 1;

gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
}

TEST(GradientCheckerTest, TopKGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"TopK", kOnnxDomain, 11};

{
TensorInfo x_info({2, 2, 2}, true);
TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1}};
TensorInfo y1_info({2, 2, 1}, true);
TensorInfo y2_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {}, true, true);
EXPECT_IS_TINY(max_error);
}

{
TensorInfo x_info({2, 2, 2}, true);
TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1}};
TensorInfo y1_info({2, 1, 2}, true);
TensorInfo y2_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {MakeAttribute("axis", int64_t(-2))}, true, true);
EXPECT_IS_TINY(max_error);
}

{
TensorInfo x_info({3, 3}, true);
TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {2}};
TensorInfo y1_info({3, 2}, true);
TensorInfo y2_info({3, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {}, true, true);
EXPECT_IS_TINY(max_error);
}
}

} // namespace test
} // namespace onnxruntime

@@ -40,6 +40,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSo
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxPoolGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherElementsGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad);
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
@@ -142,6 +143,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxPoolGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad)>,
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
