diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 5b4f3ace730e7..2c047676cc7ee 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -5,6 +5,9 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/providers/common.h" +#ifdef ENABLE_TRAINING +#include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h" +#endif namespace onnxruntime { @@ -50,8 +53,15 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Scatter); -template -Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input, +template +struct Func_Assignment { + void operator()(T* a, const T* b) const { + *a = *b; + } +}; + +template +Status CopyScatterData(const FuncT& func, const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input, const int64_t axis, Tensor* data_output) { const TensorShape& input_data_shape = data_input->Shape(); const Tin* indices_data_raw = indices_input->template Data(); @@ -157,7 +167,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co } } - dst_base[dst_offset] = update_data[index]; + func(dst_base + dst_offset, update_data + index); if (++index == num_indices) { break; @@ -181,12 +191,12 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co template inline Status CopyInt32Index(Args&&... args) { - return CopyScatterData(std::forward(args)...); + return CopyScatterData(Func_Assignment(), std::forward(args)...); } template inline Status CopyInt64Index(Args&&... args) { - return CopyScatterData(std::forward(args)...); + return CopyScatterData(Func_Assignment(), std::forward(args)...); } Status Scatter::Compute(OpKernelContext* context) const { @@ -245,4 +255,39 @@ Status Scatter::Compute(OpKernelContext* context) const { return status; } +#ifdef ENABLE_TRAINING + +namespace contrib { + +template +struct Func_Add { + void operator()(T* a, const T* b) const { + *a = *a + *b; + } +}; + +template +Status GatherElementsGradImpl(const Tensor* indices_input, const Tensor* updates_input, + const int64_t axis, Tensor* data_output) { + return CopyScatterData(Func_Add(), data_output, indices_input, updates_input, axis, data_output); +} + +#define GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(Tin, Tdata) \ + template Status GatherElementsGradImpl( \ + const Tensor* indices_input, \ + const Tensor* updates_input, \ + const int64_t axis, \ + Tensor* data_output) + +#define GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(Tdata) \ + GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(int32_t, Tdata); \ + GATHER_ELEMENTS_GRAD_IMPL_SPECIALIZED(int64_t, Tdata); + +GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(float) +GATHER_ELEMENTS_GRAD_IMPL_TDATA_SPECIALIZED(double) + +} // namespace contrib + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu index da185a257a57d..25f96169fffd6 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.cu @@ -111,12 +111,13 @@ static int CompactInputIndicesDims( eff_indices_dims.push_back(indices_dims[axis]); int new_axis = (int)(eff_input_dims.size()); if (axis > 0) { - if (could_continue_merge) { + if (!could_continue_merge) { eff_input_dims.push_back(1); eff_indices_dims.push_back(1); + could_continue_merge = true; } int i = axis - 1; - for (; i >= 0 && could_continue_merge; --i) { + for (; i >= 0; --i) { if (input_dims[i] == indices_dims[i]) { eff_input_dims.back() *= input_dims[i]; eff_indices_dims.back() *= indices_dims[i]; @@ -130,7 +131,7 @@ static int CompactInputIndicesDims( eff_indices_dims.pop_back(); } if (!could_continue_merge) { - for (; i >= 0 && could_continue_merge; --i) { + for (; i >= 0; --i) { eff_input_dims.push_back(input_dims[i]); eff_indices_dims.push_back(indices_dims[i]); } diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 5cc5afca49f09..cb95e773334cd 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -80,6 +80,26 @@ TEST(Scatter, ThreeDimsWithAxis_0) { scatter_three_dim_with_axis_0("ScatterElements", 11); } +static void scatter_three_dim_with_axis_negative_2(const char* op_name, int op_version) { + OpTester test(op_name, op_version); + test.AddAttribute("axis", -2); + + test.AddInput("data", {2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8}); + test.AddInput("indices", {2, 1, 2}, + {0, 1, 1, 0}); + test.AddInput("updates", {2, 1, 2}, + {11, 12, 13, 14}); + test.AddOutput("y", {2, 2, 2}, + {11, 2, 3, 12, 5, 14, 13, 8}); + test.Run(); +} + +TEST(Scatter, ThreeDimsWithAxisNegative_2) { + scatter_three_dim_with_axis_negative_2("Scatter", 9); + scatter_three_dim_with_axis_negative_2("ScatterElements", 11); +} + static void scatter_three_dim_with_axis_2(const char* op_name, int op_version) { OpTester test(op_name, op_version); test.AddAttribute("axis", 2); diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index a11de95f3bfaa..1a04100a9e66c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1470,5 +1470,19 @@ IMPLEMENT_GRADIENT_BUILDER(GetFlattenGradient) { }; } +IMPLEMENT_GRADIENT_BUILDER(GetTopKGradient) { + // TopK's default axis is -1, which is different from GatherElements. + auto attributes = SrcNodeAttributes(); + auto axis = utils::HasInt(attributes.at("axis")) ? attributes.at("axis").i() : -1; + return std::vector{ + NodeDef("Shape", + {I(0)}, + {IA("x_shape")}), + NodeDef(OpDef{"GatherElementsGrad", kMSDomain, 1}, + {GO(0), IA("x_shape"), O(1)}, + {GI(0)}, + {MakeAttribute("axis", axis)})}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index c6643335c3d30..3c6dc7bf5f2f5 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -67,6 +67,7 @@ DECLARE_GRADIENT_BUILDER(GetRecvGradient) DECLARE_GRADIENT_BUILDER(GetExpandGradient) DECLARE_GRADIENT_BUILDER(GetExpGradient) DECLARE_GRADIENT_BUILDER(GetFlattenGradient) +DECLARE_GRADIENT_BUILDER(GetTopKGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 3261d3cec56f2..f6f5218bc60a7 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -98,6 +98,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Expand", GetExpandGradient); REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient); REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient); + REGISTER_GRADIENT_BUILDER("TopK", GetTopKGradient); }; } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index b51ac8067b42a..1afe5e85c5c09 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1707,57 +1707,6 @@ TEST(GradientCheckerTest, GatherNDGrad_unique_float_data) { } } -TEST(GradientCheckerTest, GatherElementsGradWithDuplicateUpdate) { - float max_error; - GradientChecker gradient_checker; - OpDef op_def{"GatherElements", kOnnxDomain, 11}; - - TensorInfo data_info({3, 3}, true); - TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}}; - - TensorInfo y_info({2, 3}, true); - int64_t axis = 0; - - gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, - {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); -} - -TEST(GradientCheckerTest, GatherElementsGradWithoutDuplicateUpdate) { - float max_error; - GradientChecker gradient_checker; - OpDef op_def{"GatherElements", kOnnxDomain, 11}; - - TensorInfo data_info({3, 3}, true); - TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}}; - - TensorInfo y_info({2, 3}, true); - int64_t axis = 0; - - gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, - {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); -} - -TEST(GradientCheckerTest, GatherElementsGradAxisWithDuplicateUpdate) { - float max_error; - GradientChecker gradient_checker; - OpDef op_def{"GatherElements", kOnnxDomain, 11}; - - TensorInfo data_info({3, 3}, true); - TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); - std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}}; - - TensorInfo y_info({2, 3}, true); - int64_t axis = 1; - - gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, - {MakeAttribute("axis", axis)}); - EXPECT_IS_TINY(max_error); -} - TEST(GradientCheckerTest, LayerNormGrad) { GradientChecker gradient_checker; { @@ -2008,6 +1957,104 @@ TEST(GradientCheckerTest, ExpandGrad) { } } +TEST(GradientCheckerTest, GatherElementsGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"GatherElements", kOnnxDomain, 11}; + + { + // GatherElementsGradWithDuplicateUpdate + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 0, 2, 0, 0}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 0; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); + } + + { + // GatherElementsGradWithoutDuplicateUpdate + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 2, 2, 2}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 0; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); + } + + { + // GatherElementsGradAxisWithDuplicateUpdate + TensorInfo data_info({3, 3}, true); + TensorInfo indice_info({2, 3}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 1, 1, 1, 1, 1}}; + + TensorInfo y_info({2, 3}, true); + int64_t axis = 1; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); + } + + { + // GatherElementsGradWithAxisInMiddle + TensorInfo data_info({2, 2, 2}, true); + TensorInfo indice_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}; + + TensorInfo y_info({2, 1, 2}, true); + int64_t axis = 1; + + gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); + } +} + +TEST(GradientCheckerTest, TopKGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"TopK", kOnnxDomain, 11}; + + { + TensorInfo x_info({2, 2, 2}, true); + TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1}}; + TensorInfo y1_info({2, 2, 1}, true); + TensorInfo y2_info({2, 2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {}, true, true); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo x_info({2, 2, 2}, true); + TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1}}; + TensorInfo y1_info({2, 1, 2}, true); + TensorInfo y2_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {MakeAttribute("axis", int64_t(-2))}, true, true); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo x_info({3, 3}, true); + TensorInfo k_info({1}, false, nullptr, DataTypeImpl::GetTensorType()); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8, 9}, {2}}; + TensorInfo y1_info({3, 2}, true); + TensorInfo y2_info({3, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + gradient_checker.ComputeGradientError(op_def, {x_info, k_info}, {y1_info, y2_info}, &max_error, x_datas, {}, true, true); + EXPECT_IS_TINY(max_error); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index e18ad4bd4de02..0cd521eadbf7a 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -40,6 +40,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, LogSo class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, AveragePoolGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxPoolGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherElementsGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad); // REVIEW(mzs): ConstEigenVectorArrayMap.cast, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // REVIEW(mzs): ConstEigenVectorArrayMap.cast(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("Tind", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + GatherElementsGrad); + +#define TYPED_GRAD_FUNCTION_CALL(T) \ + if (T_type == DataTypeImpl::GetType()) { \ + if (Tind_type == DataTypeImpl::GetType()) { \ + return GatherElementsGradImpl(indices_tensor, dY, axis, dX); \ + } \ + if (Tind_type == DataTypeImpl::GetType()) { \ + return GatherElementsGradImpl(indices_tensor, dY, axis, dX); \ + } \ + } + +Status GatherElementsGrad::Compute(OpKernelContext* context) const { + const auto* dY = context->Input(0); + const Tensor* shape = context->Input(1); + const TensorShape data_shape(shape->template Data(), shape->Shape().Size()); + + const int axis = static_cast(HandleNegativeAxis(axis_, data_shape.NumDimensions())); + + const auto* indices_tensor = context->Input(2); + + const auto& indices_dims = indices_tensor->Shape().GetDims(); + const auto& dY_dims = dY->Shape().GetDims(); + if (indices_dims.size() != dY_dims.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Indices and dY must have the same rank"); + } + + for (size_t i = 0; i < indices_dims.size(); ++i) { + if (indices_dims[i] != dY_dims[i]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices vs dY dimensions differs at position=", i, + " ", indices_dims[i], " vs ", dY_dims[i]); + } + } + + // According to the spec the rank of ind/upd shall be the same as output(data) + // and we also want to make sure that the dimensions of the of the ind/upd do not + // exceed that of the output + const auto& output_dims = data_shape.GetDims(); + if (output_dims.size() != indices_dims.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices must have the same rank as Output. Indices rank=", + indices_dims.size(), ". Output rank=", output_dims.size()); + } + + for (size_t i = 0; i < output_dims.size(); ++i) { + if (output_dims[i] < indices_dims[i]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, + " is greater than Output dim=", output_dims[i]); + } + } + + Tensor* dX = context->Output(0, data_shape); + ORT_ENFORCE(dX); + memset(dX->MutableDataRaw(), 0, dX->SizeInBytes()); + + MLDataType T_type = dY->DataType(); + MLDataType Tind_type = indices_tensor->DataType(); + TYPED_GRAD_FUNCTION_CALL(float); + TYPED_GRAD_FUNCTION_CALL(double); + + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for T or Tind not supported yet in GatherElementsGrad."); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.h b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.h new file mode 100644 index 0000000000000..4ba184c9b4054 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +class GatherElementsGrad final : public OpKernel { + public: + GatherElementsGrad(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("axis", &axis_, static_cast(0)); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t axis_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h new file mode 100644 index 0000000000000..0e9dd85e3921b --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +Status GatherElementsGradImpl(const Tensor* indices_input, + const Tensor* updates_input, + const int64_t axis, + Tensor* data_output); + +} // namespace cuda +} // namespace onnxruntime