
Commit 9dddb3d

Fix error messages.
1 parent 8e6eab1 commit 9dddb3d

3 files changed: +84 -37 lines changed

test/stablehlo/test_stablehlo_custom_call.py

Lines changed: 0 additions & 21 deletions
@@ -180,27 +180,6 @@ def get_graph(gm: torch.fx.GraphModule, _):
     self.assertNotIn("place_to_host", bw.code)
     self.assertNotIn("place_to_device", bw.code)
 
-  def test_stablehlo_custom_call_tracing_error(self):
-
-    def test_no_input():
-      stablehlo_custom_call(tuple(), "custom_op_target", [[1]], [torch.int8])
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test_no_input,
-        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
-    )
-
-    def test_output_properties_mismatch():
-      input = torch.rand(10, device=torch_xla.device())
-      stablehlo_custom_call((input,), "custom_op_target", [[1], [1]], [torch.int8])
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test_output_properties_mismatch,
-        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
-    )
-
 
 if __name__ == "__main__":
test/test_ops_error_message.py

Lines changed: 51 additions & 0 deletions
@@ -1,3 +1,4 @@
+from typing import Callable
 import expecttest
 import os
 import torch
@@ -357,6 +358,56 @@ def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
         expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
     )
 
+  def _get_custom_call_properties(self, mode):
+    match mode:
+      case "tpu":
+        return (torch_xla._XLAC._xla_tpu_custom_call, "", [])
+      case "stablehlo":
+        return (torch_xla._XLAC._xla_custom_call, "custom_op_target",
+                [False, "", 0, {}])
+
+    self.fail(f"expected `mode` ({mode}) to be either of ['tpu', 'stablehlo'].")
+
+  def _gen_custom_call_no_input(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    return lambda: lib_custom_call([], payload, [[1]], [torch.int8], *args)
+
+  def _gen_custom_call_output_properties_size_mismatch(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    input = torch.rand(10, device=torch_xla.device())
+    return lambda: lib_custom_call(
+        (input,), payload, [[1], [1]], [torch.int8], *args)
+
+  def test_stablehlo_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("stablehlo"),
+        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch(
+            "stablehlo"),
+        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
+  def test_tpu_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("tpu"),
+        expect="""tpu_custom_call(): expected at least 1 input tensor.""")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch("tpu"),
+        expect="""tpu_custom_call(): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()

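For reference, here is a minimal standalone sketch (not part of the commit) of the "no input" failure path that the new tests exercise. It assumes an installed torch_xla with an attached XLA device, and it uses the private torch_xla._XLAC._xla_custom_call binding exactly as the test helper does, with the trailing [False, "", 0, {}] arguments mirrored from _get_custom_call_properties("stablehlo"):

import torch
import torch_xla

# Mirror _gen_custom_call_no_input("stablehlo") above: an empty input list and
# one declared output of shape [1] and dtype int8.
try:
  torch_xla._XLAC._xla_custom_call([], "custom_op_target", [[1]], [torch.int8],
                                   False, "", 0, {})
except RuntimeError as e:
  # Expected message:
  #   custom_call(custom_op_target): expected at least 1 input tensor.
  print(e)
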
torch_xla/csrc/tensor_methods.cpp

Lines changed: 33 additions & 16 deletions
@@ -645,34 +645,51 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }
 
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
 absl::Status CheckCustomCallNonEmptyInputs(
     const std::vector<absl_nonnull XLATensorPtr>& inputs,
-    const std::string& target) {
+    const std::optional<std::string>& target) {
   if (inputs.empty()) {
-    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
-        "custom_call(", target, "): expected at least 1 input tensor.")));
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, ": expected at least 1 input tensor.")));
   }
   return absl::OkStatus();
 }
 
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
 absl::Status CheckCustomCallOutputPropertiesSize(
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes,
-    const std::string& target) {
+    const std::optional<std::string>& target) {
   if (output_shapes.size() != output_dtypes.size()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
     return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
-        "custom_call(", target,
-        "): expected the given output shapes (size=", output_shapes.size(),
+        op, ": expected the given output shapes (size=", output_shapes.size(),
         ") to be of the same size as the given output dtypes (size=",
         output_dtypes.size(), ").")));
   }
   return absl::OkStatus();
 }
 
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
 template <class F>
 absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
     const std::vector<absl_nonnull XLATensorPtr>& inputs,
-    const std::string& target,
+    const std::optional<std::string>& target,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
   XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
@@ -995,15 +1012,15 @@ absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
     const std::vector<at::ScalarType>& output_dtypes) {
   XLA_ASSIGN_OR_RETURN(
       std::vector<absl_nonnull XLATensorPtr> outputs,
-      CustomCallImpl(inputs, payload, output_shapes, output_dtypes,
-                     /* make_node= */
-                     [&](const std::vector<torch::lazy::Value>& values,
-                         const std::vector<xla::Shape>& output_xla_shapes) {
-                       return torch_xla::MakeNode<TpuCustomCall>(
-                           values,
-                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-                           payload);
-                     }));
+      CustomCallImpl(
+          inputs, /* target= */ std::nullopt, output_shapes, output_dtypes,
+          /* make_node= */
+          [&](const std::vector<torch::lazy::Value>& values,
+              const std::vector<xla::Shape>& output_xla_shapes) {
+            return torch_xla::MakeNode<TpuCustomCall>(
+                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                payload);
+          }));
 
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {

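As a quick illustration of the message-prefix logic these checks now share, here is a hypothetical Python mirror (not part of the commit): a present target attributes the error to custom_call(<target>), while std::nullopt (None below) attributes it to tpu_custom_call():

# Hypothetical mirror of the C++ op-name selection in the checks above.
def op_name(target):
  # `target` is None when the check is reached from tpu_custom_call().
  if target is not None:
    return f"custom_call({target})"
  return "tpu_custom_call()"


assert op_name("custom_op_target") == "custom_call(custom_op_target)"
assert op_name(None) == "tpu_custom_call()"
print(op_name(None) + ": expected at least 1 input tensor.")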