Commit 323e848

Deprecate runtime.xla_device in favor of xla_model.xla_device

1 parent 358c862
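
In practice the migration is a one-line import change. A minimal before/after sketch, assuming a working torch_xla installation (the tensor shape is illustrative):

# Before this commit (deprecated, now removed from torch_xla.runtime):
#   import torch_xla.runtime
#   device = torch_xla.runtime.xla_device()

# After this commit:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # default XLA device, e.g. xla:0
t = torch.zeros(2, 2, device=device)  # tensor placed on that device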

File tree: 5 files changed, +19 −34 lines


test/spmd/test_xla_sharding.py

Lines changed: 2 additions & 2 deletions
@@ -1484,10 +1484,10 @@ def test_xla_patched_linear(self):
     """
 
     from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear
-    import torch_xla.runtime
+    import torch_xla.core.xla_model as xm
     import torch.nn.functional as F
 
-    with torch_xla.runtime.xla_device():
+    with xm.xla_device():
       torch_xla.manual_seed(42)
       x0 = torch.randn(2, 3, requires_grad=True)
       w0 = torch.randn(4, 3, requires_grad=True)
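
A side note on why `with xm.xla_device():` works in the updated test: `xla_device()` returns a plain `torch.device`, and in PyTorch 2.x a `torch.device` acts as a context manager that sets the default device for tensor construction. A minimal sketch, assuming PyTorch 2.x and an available XLA device:

import torch
import torch_xla.core.xla_model as xm

with xm.xla_device():
  x = torch.randn(2, 3)  # created on the XLA device, no explicit device= needed
print(x.device)          # e.g. xla:0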

torch_xla/core/xla_model.py

Lines changed: 12 additions & 3 deletions
@@ -142,12 +142,12 @@ def xla_device(n: Optional[int] = None,
   Args:
     n (int, optional): The specific instance (ordinal) to be returned. If
       specified, the specific XLA device instance will be returned. Otherwise
-      the first device of `devkind` will be returned.
+      the first device (default 0) will be returned.
     devkind (string..., optional): If specified, device type such as `TPU`,
       `CUDA`, `CPU`, or custom PJRT device. Deprecated.
 
   Returns:
-    A `torch.device` with the requested instance.
+    A `torch.device` with the requested instance of an XLA device.
   """
   # When SPMD is enabled, we always return `xla:0` to the user, and
   # under the hood we use virtual device logic for every xla tensor
@@ -156,7 +156,16 @@ def xla_device(n: Optional[int] = None,
     torch_xla._XLAC._xla_set_default_device(device)
     return torch.device(device)
 
-  return runtime.xla_device(n, devkind)
+  if n is None:
+    return torch.device(torch_xla._XLAC._xla_get_default_device())
+
+  devices = get_xla_supported_devices(devkind=devkind)
+  if n > len(devices):
+    raise IndexError('Device index {} out of range in {}'.format(n, devices))
+
+  device = devices[n]
+  torch_xla._XLAC._xla_set_default_device(device)
+  return torch.device(device)
 
 
 def _xla_real_device(device: torch.device) -> Any:
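
A short usage sketch of the relocated `xla_device`: with no argument it returns the process default device, while an explicit ordinal indexes into the supported devices and raises `IndexError` when out of range. The ordinals below are hypothetical; what actually exists depends on your runtime:

import torch_xla.core.xla_model as xm

dev = xm.xla_device()    # default device for this process
dev1 = xm.xla_device(1)  # second visible device, if one exists

try:
  xm.xla_device(99)      # hypothetical out-of-range ordinal
except IndexError as err:
  print(err)             # Device index 99 out of range in [...]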

torch_xla/experimental/scan.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ def scan(fn, init, xs):
   Example:
 
     >>> # Example of using `scan` to implement `torch.cumsum`.
-    >>> import torch_xla.runtime
+    >>> import torch_xla.core.xla_model as xm
     >>> import torch
     >>> from torch_xla.experimental.scan import scan
     >>>
@@ -129,7 +129,7 @@ def scan(fn, init, xs):
     >>>   y = new_carry
     >>>   return new_carry, y
     >>>
-    >>> with torch_xla.runtime.xla_device():
+    >>> with xm.xla_device():
     >>>   init = torch.tensor([0.0, 0.0], requires_grad=True)
     >>>   xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     >>>                     requires_grad=True)

torch_xla/experimental/scan_layers.py

Lines changed: 2 additions & 2 deletions
@@ -47,11 +47,11 @@ def scan_layers(layers: Iterable[torch.nn.Module],
 
   Example:
 
-    >>> import torch_xla.runtime
+    >>> import torch_xla.core.xla_model as xm
     >>> import torch
     >>> import torch.nn as nn
     >>> from torch_xla.experimental.scan_layers import scan_layers
-    >>> with torch_xla.runtime.xla_device():
+    >>> with xm.xla_device():
     >>>   layers = [nn.Linear(16, 16) for i in range(10)]
     >>>   input = torch.randn(16)
     >>>   output = scan_layers(layers, input)

torch_xla/runtime.py

Lines changed: 1 addition & 25 deletions
@@ -103,30 +103,6 @@ def is_bf16_supported():
   return False
 
 
-def xla_device(n: Optional[int] = None,
-               devkind: Optional[str] = None) -> torch.device:
-  """Returns an XLA device.
-
-  Args:
-    n: Index of XLA device within visibible devices. If not set, use local
-      ordinal (default 0) to select an addressable device.
-    devkind: Type of device to return. Should match `device_type()`.
-
-  Returns:
-    A `torch.device` representing an XLA device.
-  """
-  if n is None:
-    return torch.device(torch_xla._XLAC._xla_get_default_device())
-
-  devices = xm.get_xla_supported_devices(devkind=devkind)
-  if n > len(devices):
-    raise IndexError('Device index {} out of range in {}'.format(n, devices))
-
-  device = devices[n]
-  torch_xla._XLAC._xla_set_default_device(device)
-  return torch.device(device)
-
-
 def local_process_count() -> int:
   """Returns the number of processes running on this host."""
   return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1)
@@ -180,7 +156,7 @@ def local_ordinal() -> int:
   Local ordinal is in range [0, local_device_count)."""
   local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
   devices_per_process = addressable_device_count()
-  return local_rank * devices_per_process + xla_device().index
+  return local_rank * devices_per_process + xm.xla_device().index
 
 
 def process_index() -> int:
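
To make the `local_ordinal` arithmetic concrete, a worked sketch with illustrative values (one addressable device per process, as is typical under PJRT):

local_rank = 2            # from the PJRT_LOCAL_PROCESS_RANK env var
devices_per_process = 1   # addressable_device_count()
device_index = 0          # xm.xla_device().index for the default device, e.g. xla:0

print(local_rank * devices_per_process + device_index)  # 2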
