
Commit 60b408d

Merge branch 'pytorch:master' into generator
2 parents 80b2078 + df6798d commit 60b408d

7 files changed: +204 additions, -114 deletions

.github/workflows/build_and_test.yml

Lines changed: 8 additions & 1 deletion
@@ -42,7 +42,14 @@ jobs:
     with:
       dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm
       timeout-minutes: 45 # Takes ~20m as of 2025/5/30.
-      has_code_changes: ${{ needs.check_code_changes.outputs.has_code_changes }}
+      # We should build PyTorch and PyTorch/XLA if:
+      # 1. There are code changes.
+      # 2. This is a `push` event to `master` or release branches.
+      #
+      # The reason for (2) is that the `push-docs` job below runs precisely under condition (2).
+      # For it (the `push-docs` job) to succeed, it needs to install the PyTorch
+      # and PyTorch/XLA wheels, so this job must build those wheels.
+      has_code_changes: ${{ (needs.check_code_changes.outputs.has_code_changes == 'true' || github.event_name == 'push') && 'true' || 'false' }}
       runner: linux.24xlarge
     secrets:
       gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
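
For reference, `cond && 'true' || 'false'` in the new expression is the GitHub Actions idiom for a ternary. A minimal Python model of the intended semantics (illustrative only; Actions evaluates the expression natively and the function below is not project code):

def should_build(has_code_changes: str, event_name: str) -> str:
  # Build the wheels when there are code changes, or on any `push` event
  # (so the `push-docs` job can install them).
  return "true" if has_code_changes == "true" or event_name == "push" else "false"

assert should_build("false", "push") == "true"
assert should_build("true", "pull_request") == "true"
assert should_build("false", "pull_request") == "false"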

README.md

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
 # PyTorch/XLA
 
+> [!NOTE]
+> <b>10/2025</b>: Based on community feedback, we have proposed a more native direction for PyTorch on TPU. Read the RFC and comment at [#9684](https://github.com/pytorch/xla/issues/9684).
+>
+
 <b>Current CI status:</b> ![GitHub Actions
 status](https://github.com/pytorch/xla/actions/workflows/build_and_test.yml/badge.svg)
 

test/stablehlo/test_stablehlo_custom_call.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import expecttest
 import sys
 import re
 import unittest
@@ -16,7 +17,7 @@
 m = Library("my_custom_library", "DEF")
 
 
-class StableHLOCustomCallExportTest(unittest.TestCase):
+class StableHLOCustomCallExportTest(expecttest.TestCase):
 
   def test_single_output(self):
 
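
Switching the base class from unittest.TestCase to expecttest.TestCase gives this suite access to expecttest's inline-expect assertions (the same ones used in test/test_ops_error_message.py below). A minimal sketch of that pattern, with illustrative values that are not part of this commit; as I understand it, expecttest.TestCase subclasses unittest.TestCase and can rewrite the inline expect strings when run with EXPECTTEST_ACCEPT=1:

import unittest

import expecttest


class ExampleTest(expecttest.TestCase):

  def test_inline_expect(self):
    # Compares the actual string against the inline expected literal.
    self.assertExpectedInline(str(1 + 1), """2""")


if __name__ == "__main__":
  unittest.main()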

test/test_ops_error_message.py

Lines changed: 51 additions & 0 deletions
@@ -1,3 +1,4 @@
+from typing import Callable
 import expecttest
 import os
 import torch
@@ -357,6 +358,56 @@ def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
         expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
     )
 
+  def _get_custom_call_properties(self, mode):
+    match mode:
+      case "tpu":
+        return (torch_xla._XLAC._xla_tpu_custom_call, "", [])
+      case "stablehlo":
+        return (torch_xla._XLAC._xla_custom_call, "custom_op_target",
+                [False, "", 0, {}])
+
+    self.fail(f"expected `mode` ({mode}) to be either of ['tpu', 'stablehlo'].")
+
+  def _gen_custom_call_no_input(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    return lambda: lib_custom_call([], payload, [[1]], [torch.int8], *args)
+
+  def _gen_custom_call_output_properties_size_mismatch(self, mode):
+    lib_custom_call, payload, args = self._get_custom_call_properties(
+        mode)  # type: ignore[attr-defined]
+    input = torch.rand(10, device=torch_xla.device())
+    return lambda: lib_custom_call(
+        (input,), payload, [[1], [1]], [torch.int8], *args)
+
+  def test_stablehlo_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("stablehlo"),
+        expect="""custom_call(custom_op_target): expected at least 1 input tensor."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch(
+            "stablehlo"),
+        expect="""custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
+  def test_tpu_custom_call(self):
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_no_input("tpu"),
+        expect="""tpu_custom_call(): expected at least 1 input tensor.""")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=self._gen_custom_call_output_properties_size_mismatch("tpu"),
+        expect="""tpu_custom_call(): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1)."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 18 additions & 30 deletions
@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }
 
-std::vector<at::Tensor> TpuCustomCall(
-    const std::vector<at::Tensor>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes) {
-  std::vector<at::ScalarType> dtypes;
-  dtypes.reserve(output_dtypes.size());
-  for (auto& dtype : output_dtypes) {
-    dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-  }
-  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
-                      bridge::GetXlaTensors(inputs));
-  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-      xla_inputs, payload, output_shapes, dtypes));
-}
-
 std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
     const py::tuple& dimension_numbers) {
   // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@ void InitXlaModuleBindings(py::module m) {
       "_xla_custom_call",
       [](const std::vector<at::Tensor>& inputs, const std::string& target,
          const std::vector<std::vector<int64_t>>& output_shapes,
-         const std::vector<py::object>& output_dtypes, bool has_side_effect,
+         const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
          const std::string& backend_config, const int api_version,
          const std::unordered_map<std::string, std::string>&
             frontend_attributes) -> std::vector<at::Tensor> {
-        std::vector<at::ScalarType> dtypes;
-        dtypes.reserve(output_dtypes.size());
-        for (auto& dtype : output_dtypes) {
-          dtypes.push_back(
-              reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
-        }
 
-        XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs, bridge::GetXlaTensors(inputs));
-        auto xtensors = tensor_methods::custom_call(
-            xla_inputs, target,
-            output_shapes, dtypes, has_side_effect, backend_config,
-            api_version, frontend_attributes);
-        return bridge::AtenFromXlaTensors(std::move(xtensors));
+        XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                            bridge::GetXlaTensors(inputs));
+        XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                            tensor_methods::custom_call(
+                                xla_inputs, target, output_shapes, output_dtypes,
+                                has_side_effect, backend_config, api_version,
+                                frontend_attributes));
+
+        return bridge::AtenFromXlaTensors(std::move(xla_outputs));
       })
      .def("_xla_tpu_custom_call",
           [](const std::vector<at::Tensor>& inputs, const std::string& payload,
              const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes)
+             const std::vector<at::ScalarType>& output_dtypes)
              -> std::vector<at::Tensor> {
-            return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
+
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
+                                bridge::GetXlaTensors(inputs));
+            XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_outputs,
+                                tensor_methods::tpu_custom_call(xla_inputs, payload, output_shapes, output_dtypes));
+
+            return bridge::AtenFromXlaTensors(std::move(xla_outputs));
           })
      .def("_xla_register_custom_call_target",
          [](const std::string& fn_name, const py::capsule& function_ptr,
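
With the dtype conversion moved into pybind11 (both bindings now take std::vector<at::ScalarType> directly instead of py::object), the Python side can pass torch dtypes straight through. A hedged usage sketch of the Python-facing signatures, inferred from the argument order exercised in test/test_ops_error_message.py above; the target name, payload, and output shapes here are illustrative, and actually executing the resulting tensors still requires a valid custom-call target or TPU payload:

import torch
import torch_xla

x = torch.rand(10, device=torch_xla.device())

# StableHLO custom call: (inputs, target, output_shapes, output_dtypes,
# has_side_effect, backend_config, api_version, frontend_attributes).
outs = torch_xla._XLAC._xla_custom_call((x,), "custom_op_target", [[10]],
                                        [torch.int8], False, "", 0, {})

# TPU custom call: (inputs, payload, output_shapes, output_dtypes).
outs = torch_xla._XLAC._xla_tpu_custom_call((x,), "", [[10]], [torch.int8])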
