@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }

-// Resizes and/or checks whether a list is of the given size. The list is only
-// resized if its size is 1. If it's empty, it's replaced with the provided
-// default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
@@ -666,6 +645,92 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }

+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target) {
+  if (inputs.empty()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, ": expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::optional<std::string>& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        op, ": expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+template <class F>
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with the Lazy IR of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` Lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace

 ////////////////////////////////////////////////////////////////////////////////
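For orientation, the sketch below shows how the new `CustomCallImpl` helper is meant to be driven through its `make_node` callback. `my_custom_op` and `MyCustomNode` are hypothetical names used only for illustration and are not part of this change; the real call sites (`custom_call()` and `tpu_custom_call()`) appear in the hunks below.

// Sketch only: `MyCustomNode` is a hypothetical Lazy IR node type. The
// `make_node` callback receives the already-converted input values and output
// shapes, so it only has to build the node itself; the input/output checks and
// the XLATensorPtr wrapping are handled by CustomCallImpl.
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> my_custom_op(
    const std::vector<absl_nonnull XLATensorPtr>& inputs,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes) {
  return CustomCallImpl(
      inputs, /*target=*/std::nullopt, output_shapes, output_dtypes,
      /*make_node=*/
      [&](const std::vector<torch::lazy::Value>& values,
          const std::vector<xla::Shape>& output_xla_shapes) {
        return torch_xla::MakeNode<MyCustomNode>(
            values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes));
      });
}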
@@ -886,40 +951,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
                           torch::lazy::Value(node, 1)};
 }

-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
     const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /*make_node=*/
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
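Because `custom_call()` now returns `absl::StatusOr<...>` instead of aborting via `XLA_CHECK`, callers are expected to unwrap the result explicitly. The snippet below is a minimal sketch of that calling convention using standard `absl::StatusOr` handling; the argument values are placeholders and not taken from this diff.

// Sketch only: error handling at a hypothetical call site.
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> result =
    custom_call(inputs, target, output_shapes, output_dtypes,
                /*has_side_effect=*/false, /*backend_config=*/"",
                /*api_version=*/0, /*frontend_attributes=*/{});
if (!result.ok()) {
  return result.status();  // propagate the error instead of crashing
}
std::vector<absl_nonnull XLATensorPtr> outputs = std::move(result).value();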
@@ -954,37 +1005,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }

-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(
+          inputs, /*target=*/std::nullopt, output_shapes, output_dtypes,
+          /*make_node=*/
+          [&](const std::vector<torch::lazy::Value>& values,
+              const std::vector<xla::Shape>& output_xla_shapes) {
+            return torch_xla::MakeNode<TpuCustomCall>(
+                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                payload);
+          }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo