@@ -645,34 +645,51 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
645645 return absl::OkStatus ();
646646}
647647
648+ // This check is used for both `custom_call()` and `tpu_custom_call()`.
649+ //
650+ // The `target` parameter is `std::nullopt` whenever it's being called from
651+ // a `tpu_custom_call()` context.
648652absl::Status CheckCustomCallNonEmptyInputs (
649653 const std::vector<absl_nonnull XLATensorPtr>& inputs,
650- const std::string& target) {
654+ const std::optional<std:: string> & target) {
651655 if (inputs.empty ()) {
652- return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
653- " custom_call(" , target, " ): expected at least 1 input tensor." )));
656+ std::string op = target.has_value ()
657+ ? absl::StrCat (" custom_call(" , *target, " )" )
658+ : " tpu_custom_call" ;
659+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (
660+ absl::StrCat (op, " : expected at least 1 input tensor." )));
654661 }
655662 return absl::OkStatus ();
656663}
657664
665+ // This check is used for both `custom_call()` and `tpu_custom_call()`.
666+ //
667+ // The `target` parameter is `std::nullopt` whenever it's being called from
668+ // a `tpu_custom_call()` context.
658669absl::Status CheckCustomCallOutputPropertiesSize (
659670 const std::vector<std::vector<int64_t >>& output_shapes,
660671 const std::vector<at::ScalarType>& output_dtypes,
661- const std::string& target) {
672+ const std::optional<std:: string> & target) {
662673 if (output_shapes.size () != output_dtypes.size ()) {
674+ std::string op = target.has_value ()
675+ ? absl::StrCat (" custom_call(" , *target, " )" )
676+ : " tpu_custom_call()" ;
663677 return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
664- " custom_call(" , target,
665- " ): expected the given output shapes (size=" , output_shapes.size (),
678+ op, " : expected the given output shapes (size=" , output_shapes.size (),
666679 " ) to be of the same size as the given output dtypes (size=" ,
667680 output_dtypes.size (), " )." )));
668681 }
669682 return absl::OkStatus ();
670683}
671684
685+ // This check is used for both `custom_call()` and `tpu_custom_call()`.
686+ //
687+ // The `target` parameter is `std::nullopt` whenever it's being called from
688+ // a `tpu_custom_call()` context.
672689template <class F >
673690absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl (
674691 const std::vector<absl_nonnull XLATensorPtr>& inputs,
675- const std::string& target,
692+ const std::optional<std:: string> & target,
676693 const std::vector<std::vector<int64_t >>& output_shapes,
677694 const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
678695 XLA_RETURN_IF_ERROR (CheckCustomCallNonEmptyInputs (inputs, target));
@@ -995,15 +1012,15 @@ absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
9951012 const std::vector<at::ScalarType>& output_dtypes) {
9961013 XLA_ASSIGN_OR_RETURN (
9971014 std::vector<absl_nonnull XLATensorPtr> outputs,
998- CustomCallImpl (inputs, payload, output_shapes, output_dtypes,
999- /* make_node = */
1000- [&]( const std::vector<torch::lazy::Value>& values,
1001- const std::vector<xla::Shape >& output_xla_shapes) {
1002- return torch_xla::MakeNode<TpuCustomCall>(
1003- values,
1004- xla::ShapeUtil::MakeTupleShape (output_xla_shapes),
1005- payload);
1006- }));
1015+ CustomCallImpl (
1016+ inputs, /* target = */ std:: nullopt , output_shapes, output_dtypes,
1017+ /* make_node= */
1018+ [&]( const std::vector<torch::lazy::Value >& values,
1019+ const std::vector<xla::Shape>& output_xla_shapes) {
1020+ return torch_xla::MakeNode<TpuCustomCall>(
1021+ values, xla::ShapeUtil::MakeTupleShape (output_xla_shapes),
1022+ payload);
1023+ }));
10071024
10081025 XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get ();
10091026 if (graph_executor->UseEagerMode ()) {
0 commit comments