@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }
 
-// Resizes and/or checks whether a list is of the given size. The list is only
-// resized if its size is 1. If it's empty, it's replaced with the provided
-// default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
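Note for reviewers: the deleted `CheckIntList` broadcast a singleton list to the requested length and fell back to `def` when the list was empty. A minimal sketch of its observable behavior (illustrative calls, not part of the change):

    CheckIntList({3}, /*length=*/2, "stride");         // -> {3, 3}: singleton is broadcast
    CheckIntList({}, /*length=*/2, "stride", {1, 1});  // -> {1, 1}: empty list takes the default
    CheckIntList({1, 2, 3}, /*length=*/2, "stride");   // aborts: XLA_CHECK_EQ(3, 2) fails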
@@ -666,6 +645,75 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }
 
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target) {
+  if (inputs.empty()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "custom_call(", target, "): expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::string& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "custom_call(", target,
+        "): expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+template <class F>
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with Lazy IR of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` Lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
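The new `CustomCallImpl` centralizes the validation, IR-value collection, output-shape construction, and output materialization that `custom_call` and `tpu_custom_call` previously duplicated; each variant now only supplies a `make_node` callback that builds its concrete IR node from the gathered values and shapes. A minimal sketch of the contract, where `MyCustomCall` is a hypothetical node type:

    // Hypothetical variant built on CustomCallImpl (MyCustomCall is illustrative).
    absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> my_custom_call(
        const std::vector<absl_nonnull XLATensorPtr>& inputs,
        const std::vector<std::vector<int64_t>>& output_shapes,
        const std::vector<at::ScalarType>& output_dtypes) {
      return CustomCallImpl(
          inputs, /*target=*/"my_custom_call", output_shapes, output_dtypes,
          [&](const std::vector<torch::lazy::Value>& values,
              const std::vector<xla::Shape>& output_xla_shapes) {
            return torch_xla::MakeNode<MyCustomCall>(
                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes));
          });
    }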
@@ -886,40 +934,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
           torch::lazy::Value(node, 1)};
 }
 
-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
     const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /*make_node=*/
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
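Because `custom_call` now returns `absl::StatusOr`, call sites must unwrap the result instead of receiving the vector directly. A sketch of an updated caller, assuming a `GetValueOrThrow`-style helper at the boundary (the helper name is an assumption, not part of this diff):

    // Sketch: unwrap the StatusOr where an exception-based API is expected.
    std::vector<absl_nonnull XLATensorPtr> outputs = GetValueOrThrow(
        custom_call(inputs, target, output_shapes, output_dtypes,
                    has_side_effect, backend_config, api_version,
                    frontend_attributes));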
@@ -954,37 +988,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }
 
-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, payload, output_shapes, output_dtypes,
+                     /*make_node=*/
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<TpuCustomCall>(
+                           values,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           payload);
+                     }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo
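With this refactor, the two argument checks surface as recoverable `absl::Status` errors rather than `XLA_CHECK` aborts. A sketch of how a caller can now branch on failure (variable names are illustrative):

    auto result = tpu_custom_call(inputs, payload, output_shapes, output_dtypes);
    if (!result.ok()) {
      // e.g. "custom_call(<payload>): expected at least 1 input tensor."
      return result.status();
    }
    std::vector<absl_nonnull XLATensorPtr> outputs = std::move(result).value();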