Commit 8ca9770

custom_call: improve error handling and error messages. (#9680)
This PR refactors the `custom_call` implementation by improving its error messages and returning a status type value.

**Key Changes:**

- Make `tensor_methods::{tpu,}custom_call` return `StatusOr<vector<XLATensorPtr>>`
- Improve error messages and error handling
- Create a new `CheckCustomCallNonEmptyInputs` function for checking that there's at least 1 input tensor
- Create a new `CheckCustomCallOutputPropertiesSize` function for checking that the given shapes and dtypes match in size, i.e. they agree on the number of outputs
- Create a new `CustomCallImpl` function implementing both `custom_call` and `tpu_custom_call`, since they did mostly the same thing
- Delete the `CheckIntList` function (it should have been deleted in [#9648](#9648))
- Delete the `TpuCustomCall` function, inlining it into its corresponding Python binding, which is how the `custom_call` binding is currently implemented
1 parent 4e2515e commit 8ca9770
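For illustration, a minimal sketch of how the improved errors surface at the Python level (mirroring the new tests in `test/test_ops_error_message.py` below; `my_target` is a placeholder custom-call target name):

```python
import torch
import torch_xla

# Calling the custom-call binding with an empty input list now raises a
# RuntimeError carrying a descriptive message instead of a bare check failure.
try:
  torch_xla._XLAC._xla_custom_call(
      [],            # inputs: empty on purpose, to trigger the error
      "my_target",   # custom-call target (placeholder name)
      [[1]],         # output_shapes
      [torch.int8],  # output_dtypes
      False,         # has_side_effect
      "",            # backend_config
      0,             # api_version
      {})            # frontend_attributes
except RuntimeError as e:
  print(e)  # custom_call(my_target): expected at least 1 input tensor.
```

The TPU variant behaves the same way, but its messages are prefixed with `tpu_custom_call()` since it takes no target string.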

File tree: 5 files changed, +192 -113 lines changed


test/stablehlo/test_stablehlo_custom_call.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import expecttest
 import sys
 import re
 import unittest
@@ -16,7 +17,7 @@
 m = Library("my_custom_library", "DEF")


-class StableHLOCustomCallExportTest(unittest.TestCase):
+class StableHLOCustomCallExportTest(expecttest.TestCase):

   def test_single_output(self):

test/test_ops_error_message.py

Lines changed: 51 additions & 0 deletions
@@ -1,3 +1,4 @@
+from typing import Callable
 import expecttest
 import os
 import torch
@@ -357,6 +358,56 @@ def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
         expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
     )

+  def _get_custom_call_properties(self, mode):
+    match mode:
+      case "tpu":
+        return (torch_xla._XLAC._xla_tpu_custom_call, "", [])
+      case "stablehlo":
+        return (torch_xla._XLAC._xla_custom_call, "custom_op_target",
+                [False, "", 0, {}])
+
+    self.fail(f"expected `mode` ({mode}) to be either of ['tpu', 'stablehlo'].")
+
+  def _gen_custom_call_no_input(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    return lambda: lib_custom_call([], payload, [[1]], [torch.int8], *args)
+
+  def _gen_custom_call_output_properties_size_mismatch(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    input = torch.rand(10, device=torch_xla.device())
+    return lambda: lib_custom_call(
+        (input,), payload, [[1], [1]], [torch.int8], *args)
+
+  def test_stablehlo_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("stablehlo"),
+        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch(
+            "stablehlo"),
+        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
+  def test_tpu_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("tpu"),
+        expect="""tpu_custom_call(): expected at least 1 input tensor.""")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch("tpu"),
+        expect="""tpu_custom_call(): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+

 if __name__ == "__main__":
   unittest.main()

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 18 additions & 30 deletions
@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }

-std::vector<at::Tensor> TpuCustomCall(
-    const std::vector<at::Tensor>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes) {
-  std::vector<at::ScalarType> dtypes;
-  dtypes.reserve(output_dtypes.size());
-  for (auto& dtype : output_dtypes) {
-    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-  }
-  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
-                      bridge::GetXlaTensors(inputs));
-  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-      xla_inputs, payload, output_shapes, dtypes));
-}
-
 std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
     const py::tuple& dimension_numbers) {
   // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@ void InitXlaModuleBindings(py::module m) {
          "_xla_custom_call",
          [](const std::vector<at::Tensor>& inputs, const std::string& target,
             const std::vector<std::vector<int64_t>>& output_shapes,
-            const std::vector<py::object>& output_dtypes, bool has_side_effect,
+            const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
             const std::string& backend_config, const int api_version,
             const std::unordered_map<std::string, std::string>&
                 frontend_attributes) -> std::vector<at::Tensor> {
-           std::vector<at::ScalarType> dtypes;
-           dtypes.reserve(output_dtypes.size());
-           for (auto& dtype : output_dtypes) {
-             dtypes.push_back(
-                 reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-           }

-           XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs, bridge::GetXlaTensors(inputs));
-           auto xtensors = tensor_methods::custom_call(
-               xla_inputs, target,
-               output_shapes, dtypes, has_side_effect, backend_config,
-               api_version, frontend_attributes);
-           return bridge::AtenFromXlaTensors(std::move(xtensors));
+           XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                               bridge::GetXlaTensors(inputs));
+           XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                               tensor_methods::custom_call(
+                                   xla_inputs, target, output_shapes, output_dtypes,
+                                   has_side_effect, backend_config, api_version,
+                                   frontend_attributes));
+
+           return bridge::AtenFromXlaTensors(std::move(xla_outputs));
         })
    .def("_xla_tpu_custom_call",
         [](const std::vector<at::Tensor>& inputs, const std::string& payload,
            const std::vector<std::vector<int64_t>>& output_shapes,
-           const std::vector<py::object>& output_dtypes)
+           const std::vector<at::ScalarType>& output_dtypes)
            -> std::vector<at::Tensor> {
-          return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
+
+          XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                              bridge::GetXlaTensors(inputs));
+          XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                              tensor_methods::tpu_custom_call(xla_inputs, payload, output_shapes, output_dtypes));
+
+          return bridge::AtenFromXlaTensors(std::move(xla_outputs));
        })
    .def("_xla_register_custom_call_target",
        [](const std::string& fn_name, const py::capsule& function_ptr,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 115 additions & 78 deletions
@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }

-// Resizes and / or checks whether a list is of the given size. The list is only
-// resized if its size is 1. If it's empty, it's replaced with the provided
-// default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
@@ -666,6 +645,92 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }

+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target) {
+  if (inputs.empty()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, ": expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::optional<std::string>& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    std::string op = target.has_value()
+                         ? absl::StrCat("custom_call(", *target, ")")
+                         : "tpu_custom_call()";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        op, ": expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+// This check is used for both `custom_call()` and `tpu_custom_call()`.
+//
+// The `target` parameter is `std::nullopt` whenever it's being called from
+// a `tpu_custom_call()` context.
+template <class F>
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::optional<std::string>& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with Lazy IR of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` Lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace

 //////////////////////////////////////////////////////////////////////////////
@@ -886,40 +951,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
           torch::lazy::Value(node, 1)};
 }

-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
     const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /* make_node= */
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
@@ -954,37 +1005,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }

-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(
+          inputs, /* target= */ std::nullopt, output_shapes, output_dtypes,
+          /* make_node= */
+          [&](const std::vector<torch::lazy::Value>& values,
+              const std::vector<xla::Shape>& output_xla_shapes) {
+            return torch_xla::MakeNode<TpuCustomCall>(
+                values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                payload);
+          }));

-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo

torch_xla/csrc/tensor_methods.h

Lines changed: 6 additions & 4 deletions
@@ -92,8 +92,9 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
@@ -104,8 +105,9 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);

-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes);
