From c6621b1080829ddff0b49f25adb933f27bcef002 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Wed, 11 Jun 2025 23:23:10 +0000
Subject: [PATCH] Clarify `xr.device_type()` API and use it

---
 test/ds/test_dynamic_shape_models.py       | 2 +-
 test/pjrt/test_dynamic_plugin_tpu.py       | 2 +-
 test/test_autocast.py                      | 5 ++---
 torch_xla/distributed/spmd/xla_sharding.py | 2 +-
 torch_xla/runtime.py                       | 4 ++--
 5 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py
index 114c41e5c829..36bb87b3876d 100644
--- a/test/ds/test_dynamic_shape_models.py
+++ b/test/ds/test_dynamic_shape_models.py
@@ -44,7 +44,7 @@ def forward(self, x):
 
 
 @unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU',
+    xr.device_type() != 'TPU',
     f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail."
 )
 class TestDynamicShapeModels(unittest.TestCase):
diff --git a/test/pjrt/test_dynamic_plugin_tpu.py b/test/pjrt/test_dynamic_plugin_tpu.py
index f199797afc24..151207f6bc1a 100644
--- a/test/pjrt/test_dynamic_plugin_tpu.py
+++ b/test/pjrt/test_dynamic_plugin_tpu.py
@@ -20,7 +20,7 @@ def setUpClass(cls):
   @staticmethod
   def _assert_tpus_exist(index=0):
     del index
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
 
   def test_single_process(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
diff --git a/test/test_autocast.py b/test/test_autocast.py
index ca1f26c05ec1..19101e276594 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -348,8 +348,7 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU', f"TPU autocast test.")
+@unittest.skipIf(xr.device_type() != 'TPU', f"TPU autocast test.")
 class TestAutocastTPU(TestAutocastBase):
 
   @classmethod
@@ -405,7 +404,7 @@ class TestOtherOps(unittest.TestCase):
 
   # On TPU, the input of batch norm is casted into fp32, see torch_xla/csrc/autocast_mode.cpp
   @unittest.skipIf(
-      xm.xla_device_hw(torch_xla.device()) != 'TPU',
+      xr.device_type() != 'TPU',
       "the behavior of batch_norm autocast on TPU is different from others")
   def test_batch_norm_tpu(self):
     device = torch_xla.device()
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 49229b17cffe..d6c2ef57c3e6 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray:
       A np.ndarray of device logical ordinals with shape [global_x, global_y,
       global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2).
     """
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
     # coords is a 3-dims tuple representing the device in physical mesh
     device_coords = [self.device_attributes[d]['coords'] for d in devices]
     dims = tuple(d + 1 for d in max(device_coords))
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 2e274190db75..31465888098d 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -82,12 +82,12 @@ def _maybe_select_default_device():
 
 
 def device_type() -> Optional[str]:
-  """Returns the current PjRt device type.
+  """Returns the current PJRT device type.
 
   Selects a default device if none has been configured
 
   Returns:
-    A string representation of the device.
+    A string representation of the PJRT device: "CPU", "TPU", etc.
""" pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str) return pjrt_device.split('_')[0] if pjrt_device else pjrt_device