diff --git a/core/meta/test/testTClass.cxx b/core/meta/test/testTClass.cxx index c40f9676f16e4..535eab67bbb98 100644 --- a/core/meta/test/testTClass.cxx +++ b/core/meta/test/testTClass.cxx @@ -159,3 +159,14 @@ TEST(TClass, TemplateTemplate) // "_gnu_cxx::__common_pool_policy > >"), // 0); } + +// ROOT-10728 +TEST(TClass, CanSplitWithBaseWithCustomStreamer) +{ + gInterpreter->Declare("class CanSplitWithBaseWithCustomStreamer : public TH1D {\n" + "int a = 0;\n" + "ClassDef(CanSplitWithBaseWithCustomStreamer, 1)};"); + + auto c = TClass::GetClass("CanSplitWithBaseWithCustomStreamer"); + EXPECT_FALSE(c->CanSplit()); +} diff --git a/tmva/sofie/inc/TMVA/OperatorList.hxx b/tmva/sofie/inc/TMVA/OperatorList.hxx index 309a0fc703147..1eb72874cf15d 100644 --- a/tmva/sofie/inc/TMVA/OperatorList.hxx +++ b/tmva/sofie/inc/TMVA/OperatorList.hxx @@ -1,6 +1,7 @@ #include "TMVA/ROperator_Transpose.hxx" #include "TMVA/ROperator_Gemm.hxx" #include "TMVA/ROperator_Relu.hxx" +#include "TMVA/ROperator_Gelu.hxx" #include "TMVA/ROperator_Tanh.hxx" #include "TMVA/ROperator_LeakyRelu.hxx" #include "TMVA/ROperator_Selu.hxx" diff --git a/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx new file mode 100644 index 0000000000000..db10d1bb4c252 --- /dev/null +++ b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx @@ -0,0 +1,84 @@ +#ifndef TMVA_SOFIE_ROPERATOR_GELU +#define TMVA_SOFIE_ROPERATOR_GELU + +#include "TMVA/SOFIE_common.hxx" +#include "TMVA/ROperator.hxx" +#include "TMVA/RModel.hxx" + +#include +#include + +namespace TMVA{ +namespace Experimental{ +namespace SOFIE{ + +template +class ROperator_Gelu final : public ROperator +{ + +private: + + std::string fNX; + std::string fNY; + std::vector fShape; + +public: + ROperator_Gelu(){} + ROperator_Gelu(std::string nameX, std::string nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){ + fInputTensorNames = { fNX }; + fOutputTensorNames = { fNY }; + } + + std::vector TypeInference(std::vector input) override { + return input; + } + + std::vector> ShapeInference(std::vector> input) override { + auto ret = input; // suggest copy to compiler + return ret; + } + + void Initialize(RModel& model) override { + if (model.CheckIfTensorAlreadyExist(fNX) == false){ + throw std::runtime_error( + "TMVA SOFIE Gelu Op Input Tensor " + fNX + " is not found in model"); + } + + fShape = model.GetDimTensorShape(fNX); + + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); + if (model.Verbose()) { + std::cout << "Gelu : " << fNX << " -> " << fNY << " " + << ConvertShapeToString(fShape) << std::endl; + } + } + + std::string Generate(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()) { + throw std::runtime_error( + "TMVA SOFIE Operator Gelu called to Generate without being initialized first"); + } + + std::stringstream out; + auto length = ConvertDynamicShapeToLength(fShape); + + out << "\n//------ GELU (exact, erf-based)\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP + << "tensor_" << fNY << "[id] = " + << "tensor_" << fNX << "[id] * 0.5 * " + << "(1.0 + std::erf(tensor_" << fNX << "[id] / std::sqrt(2.0)));\n"; + out << SP << "}\n"; + + return out.str(); + } + +}; + +}//SOFIE +}//Experimental +}//TMVA + +#endif //TMVA_SOFIE_ROPERATOR_GELU diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 4ad063693fbe5..56a69f4d43e64 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -41,6 +41,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParsePool.cxx src/ParseReduce.cxx src/ParseRelu.cxx + src/ParseGelu.cxx src/ParseReshape.cxx src/ParseRNN.cxx src/ParseSelu.cxx diff --git a/tmva/sofie_parsers/src/ParseGelu.cxx b/tmva/sofie_parsers/src/ParseGelu.cxx new file mode 100644 index 0000000000000..d7c9269451de6 --- /dev/null +++ b/tmva/sofie_parsers/src/ParseGelu.cxx @@ -0,0 +1,39 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_Gelu.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseGelu = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor" + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + switch (input_type) { + case ETensorType::FLOAT: op.reset(new ROperator_Gelu(input_name, output_name)); break; + default: + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gelu does not yet support input type " + + std::to_string(static_cast(input_type))); + } + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 7b4ade2b6bc09..e3c2bc4664626 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -51,6 +51,7 @@ extern ParserFuncSignature ParseReduceProd; extern ParserFuncSignature ParseBatchNormalization; extern ParserFuncSignature ParseConstant; extern ParserFuncSignature ParseTranspose; +extern ParserFuncSignature ParseGelu; extern ParserFuncSignature ParseRelu; extern ParserFuncSignature ParseTanh; extern ParserFuncSignature ParseConv; @@ -201,6 +202,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("AveragePool", ParsePool); RegisterOperator("GlobalAveragePool", ParsePool); RegisterOperator("MaxPool", ParsePool); + RegisterOperator("Gelu", ParseGelu); RegisterOperator("Relu", ParseRelu); RegisterOperator("Reshape", ParseReshape); RegisterOperator("Flatten", ParseReshape);