Commit 8e6eab1 (parent: d291621)

Improve error message and error handling for custom_call.

File tree (4 files changed: +145 additions, -113 deletions):

  test/stablehlo/test_stablehlo_custom_call.py
  torch_xla/csrc/init_python_bindings.cpp
  torch_xla/csrc/tensor_methods.cpp
  torch_xla/csrc/tensor_methods.h

test/stablehlo/test_stablehlo_custom_call.py

Lines changed: 23 additions & 1 deletion
@@ -1,3 +1,4 @@
+import expecttest
 import sys
 import re
 import unittest
@@ -16,7 +17,7 @@
 m = Library("my_custom_library", "DEF")
 
 
-class StableHLOCustomCallExportTest(unittest.TestCase):
+class StableHLOCustomCallExportTest(expecttest.TestCase):
 
   def test_single_output(self):
 
@@ -179,6 +180,27 @@ def get_graph(gm: torch.fx.GraphModule, _):
     self.assertNotIn("place_to_host", bw.code)
     self.assertNotIn("place_to_device", bw.code)
 
+  def test_stablehlo_custom_call_tracing_error(self):
+
+    def test_no_input():
+      stablehlo_custom_call(tuple(), "custom_op_target", [[1]], [torch.int8])
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test_no_input,
+        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
+    )
+
+    def test_output_properties_mismatch():
+      input = torch.rand(10, device=torch_xla.device())
+      stablehlo_custom_call((input,), "custom_op_target", [[1], [1]], [torch.int8])
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test_output_properties_mismatch,
+        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
 
 if __name__ == "__main__":
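The two `expect` strings above are produced verbatim by the new validation helpers added to torch_xla/csrc/tensor_methods.cpp (full diff further down). As a rough standalone illustration of those checks -- plain Abseil only, with simple size parameters standing in for the real tensor and dtype vectors, and without the XLA_ERROR_WITH_LOCATION wrapper the tree applies -- the following compiles on its own and prints the same two messages:

// Distilled sketch of CheckCustomCallNonEmptyInputs and
// CheckCustomCallOutputPropertiesSize; sizes stand in for the vectors.
#include <cstddef>
#include <iostream>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

absl::Status CheckNonEmptyInputs(std::size_t num_inputs,
                                 const std::string& target) {
  if (num_inputs == 0) {
    return absl::InvalidArgumentError(absl::StrCat(
        "custom_call(", target, "): expected at least 1 input tensor."));
  }
  return absl::OkStatus();
}

absl::Status CheckOutputPropertiesSize(std::size_t num_shapes,
                                       std::size_t num_dtypes,
                                       const std::string& target) {
  if (num_shapes != num_dtypes) {
    return absl::InvalidArgumentError(absl::StrCat(
        "custom_call(", target,
        "): expected the given output shapes (size=", num_shapes,
        ") to be of the same size as the given output dtypes (size=",
        num_dtypes, ")."));
  }
  return absl::OkStatus();
}

int main() {
  // Mirrors test_no_input: no input tensors at all.
  std::cout << CheckNonEmptyInputs(0, "custom_op_target").message() << "\n";
  // Mirrors test_output_properties_mismatch: 2 shapes vs. 1 dtype.
  std::cout << CheckOutputPropertiesSize(2, 1, "custom_op_target").message()
            << "\n";
}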
torch_xla/csrc/init_python_bindings.cpp

Lines changed: 18 additions & 30 deletions
@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }
 
-std::vector<at::Tensor> TpuCustomCall(
-    const std::vector<at::Tensor>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes) {
-  std::vector<at::ScalarType> dtypes;
-  dtypes.reserve(output_dtypes.size());
-  for (auto& dtype : output_dtypes) {
-    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-  }
-  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
-                      bridge::GetXlaTensors(inputs));
-  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-      xla_inputs, payload, output_shapes, dtypes));
-}
-
 std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
     const py::tuple& dimension_numbers) {
   // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@ void InitXlaModuleBindings(py::module m) {
           "_xla_custom_call",
           [](const std::vector<at::Tensor>& inputs, const std::string& target,
              const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes, bool has_side_effect,
+             const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
              const std::string& backend_config, const int api_version,
              const std::unordered_map<std::string, std::string>&
                  frontend_attributes) -> std::vector<at::Tensor> {
-            std::vector<at::ScalarType> dtypes;
-            dtypes.reserve(output_dtypes.size());
-            for (auto& dtype : output_dtypes) {
-              dtypes.push_back(
-                  reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-            }
 
-            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs, bridge::GetXlaTensors(inputs));
-            auto xtensors = tensor_methods::custom_call(
-                xla_inputs, target,
-                output_shapes, dtypes, has_side_effect, backend_config,
-                api_version, frontend_attributes);
-            return bridge::AtenFromXlaTensors(std::move(xtensors));
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                                bridge::GetXlaTensors(inputs));
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                                tensor_methods::custom_call(
+                                    xla_inputs, target, output_shapes, output_dtypes,
+                                    has_side_effect, backend_config, api_version,
+                                    frontend_attributes));
+
+            return bridge::AtenFromXlaTensors(std::move(xla_outputs));
          })
      .def("_xla_tpu_custom_call",
           [](const std::vector<at::Tensor>& inputs, const std::string& payload,
              const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes)
+             const std::vector<at::ScalarType>& output_dtypes)
               -> std::vector<at::Tensor> {
-            return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
+
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                                bridge::GetXlaTensors(inputs));
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                                tensor_methods::tpu_custom_call(
+                                    xla_inputs, payload, output_shapes, output_dtypes));
+
+            return bridge::AtenFromXlaTensors(std::move(xla_outputs));
          })
      .def("_xla_register_custom_call_target",
           [](const std::string& fn_name, const py::capsule& function_ptr,
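Two things changed at the binding layer. First, `output_dtypes` now arrives as `std::vector<at::ScalarType>`: PyTorch ships a pybind11 type caster for `at::ScalarType` (in torch/csrc/utils/pybind.h), so `torch.dtype` arguments convert directly and the manual `THPDtype` cast disappears. Second, both lambdas unwrap the new `absl::StatusOr` results with `XLA_ASSIGN_OR_THROW`, so a failed status reaches Python as the `RuntimeError` the tests assert. A minimal sketch of that unwrap-or-throw boundary, with a hypothetical `ValueOrThrow` helper standing in for the macro:

// Sketch of the boundary pattern; ValueOrThrow is a hypothetical stand-in
// for what XLA_ASSIGN_OR_THROW does in the tree. pybind11 translates an
// uncaught C++ exception like this into a Python RuntimeError.
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

template <typename T>
T ValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}

absl::StatusOr<std::vector<int>> FallibleOp(bool fail) {
  if (fail) return absl::InvalidArgumentError("custom_call(t): bad args.");
  return std::vector<int>{42};
}

int main() {
  std::vector<int> ok = ValueOrThrow(FallibleOp(/*fail=*/false));
  try {
    ValueOrThrow(FallibleOp(/*fail=*/true));
  } catch (const std::runtime_error& e) {
    // e.what() == "custom_call(t): bad args."
  }
  return ok.size() == 1 ? 0 : 1;
}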

torch_xla/csrc/tensor_methods.cpp

Lines changed: 98 additions & 78 deletions
@@ -299,27 +299,6 @@ absl::StatusOr<PoolNdInputsOwner> FillAndCheckPoolNdInputs(
   return PoolNdInputsOwner{kernel_size, stride, padding};
 }
 
-// Resizes and / or checks whether a list is of the given size. The list is only
-// resized if its size is 1. If it's empty, it's replaced with the provided
-// default first.
-std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
-                                  const std::string& name,
-                                  std::vector<int64_t> def = {}) {
-  std::vector<int64_t> result;
-  if (list.empty()) {
-    result = std::move(def);
-  } else {
-    result = torch::lazy::ToVector<int64_t>(list);
-  }
-  if (result.size() == 1 && length > 1) {
-    result.resize(length, result[0]);
-    return result;
-  }
-  XLA_CHECK_EQ(result.size(), length)
-      << "Invalid length for the '" << name << "' attribute";
-  return result;
-}
-
 // Returns a 1-D shape for batch norm weight or bias based on the input shape.
 xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
   xla::PrimitiveType input_element_type =
@@ -666,6 +645,75 @@ absl::Status CheckUniformRangeIsValid(double from, double to) {
   return absl::OkStatus();
 }
 
+absl::Status CheckCustomCallNonEmptyInputs(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target) {
+  if (inputs.empty()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "custom_call(", target, "): expected at least 1 input tensor.")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckCustomCallOutputPropertiesSize(
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes,
+    const std::string& target) {
+  if (output_shapes.size() != output_dtypes.size()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "custom_call(", target,
+        "): expected the given output shapes (size=", output_shapes.size(),
+        ") to be of the same size as the given output dtypes (size=",
+        output_dtypes.size(), ").")));
+  }
+  return absl::OkStatus();
+}
+
+template <class F>
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
+    const std::vector<std::vector<int64_t>>& output_shapes,
+    const std::vector<at::ScalarType>& output_dtypes, F&& make_node) {
+  XLA_RETURN_IF_ERROR(CheckCustomCallNonEmptyInputs(inputs, target));
+  XLA_RETURN_IF_ERROR(CheckCustomCallOutputPropertiesSize(
+      output_shapes, output_dtypes, target));
+
+  const auto& first = inputs.front();
+  auto device = first->GetDevice();
+  auto output_range = c10::irange(output_shapes.size());
+
+  // `values`: vector with Lazy IR of `inputs`.
+  std::vector<torch::lazy::Value> values(inputs.size());
+  std::transform(
+      inputs.begin(), inputs.end(), values.begin(),
+      [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });
+
+  // `output_xla_shapes`: `xla::Shape` instances created from `output_shapes`
+  // and `output_dtypes`.
+  std::vector<xla::Shape> output_xla_shapes(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(),
+                 output_xla_shapes.begin(), [&](std::size_t i) {
+                   return xla::ShapeUtil::MakeShape(
+                       MakeXlaPrimitiveType(output_dtypes[i], &device),
+                       output_shapes[i]);
+                 });
+
+  auto node = make_node(values, output_xla_shapes);
+
+  // `outputs`: `XLATensorPtr` instances created from the `i`-th output of
+  // the `node` Lazy IR `Node`.
+  std::vector<XLATensorPtr> outputs(output_shapes.size());
+  std::transform(output_range.begin(), output_range.end(), outputs.begin(),
+                 [&](std::size_t i) {
+                   return first->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_execution=*/true);
+                 });
+
+  return outputs;
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -886,40 +934,26 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
           torch::lazy::Value(node, 1)};
 }
 
-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
    const std::string& backend_config, const int api_version,
    const std::unordered_map<std::string, std::string>& frontend_attributes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<CustomCall>(
-      values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
-      has_side_effect, backend_config, api_version, frontend_attributes);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, target, output_shapes, output_dtypes,
+                     /* make_node= */
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<CustomCall>(
+                           values, target,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           has_side_effect, backend_config, api_version,
+                           frontend_attributes);
+                     }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `customcall` and in one graph
@@ -954,37 +988,23 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }
 
-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<TpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
+  XLA_ASSIGN_OR_RETURN(
+      std::vector<absl_nonnull XLATensorPtr> outputs,
+      CustomCallImpl(inputs, payload, output_shapes, output_dtypes,
+                     /* make_node= */
+                     [&](const std::vector<torch::lazy::Value>& values,
+                         const std::vector<xla::Shape>& output_xla_shapes) {
+                       return torch_xla::MakeNode<TpuCustomCall>(
+                           values,
+                           xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
+                           payload);
+                     }));
 
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
     // Execute the HLO that will run the `custom` and in one hlo
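The heart of this file's change is the new `CustomCallImpl` template: the duplicated validation, IR-value collection, shape construction, and output assembly from `custom_call` and `tpu_custom_call` now live in one place, and each op injects only its node construction via a callable. A compilable sketch of the factoring, with placeholder types standing in for `XLATensorPtr` and the lazy-IR node:

// Shape of the CustomCallImpl factoring. Node and Tensor are placeholders,
// not the tree's types; only the structure mirrors the diff above.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

struct Node {};      // placeholder for the lazy-IR node type
using Tensor = int;  // placeholder for XLATensorPtr

template <class F>
absl::StatusOr<std::vector<Tensor>> CustomCallImpl(
    const std::vector<Tensor>& inputs, const std::string& target,
    F&& make_node) {
  if (inputs.empty()) {  // shared validation, written exactly once
    return absl::InvalidArgumentError(absl::StrCat(
        "custom_call(", target, "): expected at least 1 input tensor."));
  }
  Node node = make_node(inputs);  // op-specific part
  (void)node;     // the real code wraps each node output in a tensor
  return inputs;  // placeholder for the assembled output tensors
}

// Each public entry point shrinks to a thin wrapper.
absl::StatusOr<std::vector<Tensor>> custom_call(
    const std::vector<Tensor>& inputs, const std::string& target) {
  return CustomCallImpl(inputs, target,
                        [&](const std::vector<Tensor>&) { return Node{}; });
}

absl::StatusOr<std::vector<Tensor>> tpu_custom_call(
    const std::vector<Tensor>& inputs, const std::string& payload) {
  return CustomCallImpl(inputs, payload,
                        [&](const std::vector<Tensor>&) { return Node{}; });
}

int main() {
  bool ok = custom_call({1, 2, 3}, "my_target").ok();  // true
  bool fails = !tpu_custom_call({}, "payload").ok();   // true: no inputs
  return (ok && fails) ? 0 : 1;
}

Taking the callable as a deduced template parameter (`F&&`) rather than a `std::function` lets each wrapper's lambda capture its op-specific arguments (`has_side_effect`, `backend_config`, `payload`, ...) without type-erasure overhead.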

torch_xla/csrc/tensor_methods.h

Lines changed: 6 additions & 4 deletions
@@ -92,8 +92,9 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
 
-std::vector<XLATensorPtr> custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& target,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
     const std::string& backend_config, const int api_version,
@@ -104,8 +105,9 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);
 
-std::vector<XLATensorPtr> tpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> tpu_custom_call(
+    const std::vector<absl_nonnull XLATensorPtr>& inputs,
+    const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
     const std::vector<at::ScalarType>& output_dtypes);
