
Migrate runtime.xla_device in favor of core.xla_model.xla_device #9200


Draft: wants to merge 3 commits into base: master
24 changes: 12 additions & 12 deletions API_GUIDE.md
@@ -15,14 +15,14 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
t = torch.randn(2, 2, device=torch_xla.device())
Collaborator: can this be changed to `device="xla"`?

print(t.device)
print(t)
```
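
A rough sketch of what the reviewer's suggestion above could look like, assuming the `"xla"` device string is accepted once `torch_xla` is imported (this PR does not itself confirm that):

```python
import torch
import torch_xla  # importing torch_xla initializes PyTorch/XLA

# Hypothetical form from the review comment: pass the device as a string
# instead of calling torch_xla.device().
t = torch.randn(2, 2, device="xla")
print(t.device)
```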

This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
`xm.xla_device()` returns the current XLA device. This may be a CPU or TPU
`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU
Collaborator: can this be further simplified to simply describe that users get a device via `torch.device("xla")`?

depending on your environment.

## XLA Tensors are PyTorch Tensors
@@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors
For example, XLA tensors can be added together:

```python
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
t0 = torch.randn(2, 2, device=torch_xla.device())
t1 = torch.randn(2, 2, device=torch_xla.device())
Collaborator: can this be `device="xla"`?

print(t0 + t1)
```

@@ -46,8 +46,8 @@ print(t0.mm(t1))
Or used with neural network modules:

```python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_in = torch.randn(10, device=torch_xla.device())
linear = torch.nn.Linear(10, 20).to(torch_xla.device())
l_out = linear(l_in)
print(l_out)
```
@@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the
same device. So code like

```python
l_in = torch.randn(10, device=xm.xla_device())
l_in = torch.randn(10, device=torch_xla.device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
@@ -109,10 +109,10 @@ class MNIST(nn.Module):
batch_size = 128
train_loader = xu.SampleGenerator(
data=(torch.zeros(batch_size, 1, 28, 28),
torch.zeros(batch_size, dtype=torch.int64)),
torch.zeros(batch_size, dtype=torch.int64)),
sample_count=60000 // batch_size // xr.world_size())

device = xm.xla_device() # Get the XLA device (TPU).
device = torch_xla.device() # Get the XLA device (TPU).
model = MNIST().train().to(device) # Create a model and move it to the device.
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -169,7 +169,7 @@ def _mp_fn(index):
index: Index of the process.
"""

device = xm.xla_device() # Get the device assigned to this process.
device = torch_xla.device() # Get the device assigned to this process.
# Wrap the loader for multi-device.
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

@@ -197,7 +197,7 @@ single device snippet. Let's go over them one by one.
- `torch_xla.launch()`
- Creates the processes that each run an XLA device.
- This function is a wrapper around multiprocessing spawn that also allows the user to run the script with the torchrun command line. Each process will only be able to access the device assigned to the current process. For example, on a TPU v4-8 there will be 4 processes spawned and each process will own a TPU device.
- Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is with the PJRT runtime on TPU v2 and TPU v3, since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
- Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is with the PJRT runtime on TPU v2 and TPU v3, since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
- `MpDeviceLoader`
- Loads the training data onto each device.
- `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance.
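
A minimal sketch of how the pieces described above might fit together; the toy dataset and the exact `launch` arguments are illustrative, not taken from this diff:

```python
import torch
import torch_xla
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset


def _mp_fn(index):
    # Each spawned process sees exactly one XLA device, reported as xla:0.
    device = torch_xla.device()

    # A toy loader stands in for a real training dataset here.
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    train_loader = DataLoader(dataset, batch_size=8)

    # MpDeviceLoader moves batches to the XLA device and overlaps data
    # loading with device execution.
    mp_device_loader = pl.MpDeviceLoader(train_loader, device)
    for data, target in mp_device_loader:
        print(index, data.device, data.shape)


if __name__ == '__main__':
    # One process per local XLA device; the script can also be run via torchrun.
    torch_xla.launch(_mp_fn, args=())
```
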
@@ -290,7 +290,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
device = torch_xla.device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
2 changes: 1 addition & 1 deletion README.md
@@ -190,7 +190,7 @@ If you're using `DistributedDataParallel`, make the following changes:
+ # Rank and world size are inferred from the XLA device runtime
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ model.to(torch_xla.device())
+ ddp_model = DDP(model, gradient_as_bucket_view=True)

- model = model.to(rank)
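
Assembled into a runnable sketch, the migrated DDP pattern from this README hunk might look like the following (the linear model and the use of `torch_xla.launch` for spawning are placeholders, not part of this diff):

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' backend for torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
    # Rank and world size are inferred from the XLA device runtime.
    dist.init_process_group("xla", init_method="xla://")

    model = torch.nn.Linear(10, 10)  # placeholder model
    model = model.to(torch_xla.device())
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    # ... optimizer setup and training loop would go here ...


if __name__ == "__main__":
    torch_xla.launch(_mp_fn, args=())
```
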
2 changes: 1 addition & 1 deletion benchmarks/benchmark_experiment.py
@@ -208,7 +208,7 @@ def update_process_env(self, process_env: Dict[str, str]):
def get_device(self):
if self.torch_xla2:
# Initiate the model in CPU first for xla2. We will move the model to jax device later.
# This is because we don't have xm.xla_device() function in torch_xla2.
# This is because we don't have torch_xla.device() function in torch_xla2.
return torch.device("cpu")
if self.xla:
return xm.xla_device(devkind=self.accelerator.upper())
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
@@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment,

def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
benchmark_model: BenchmarkModel, input_tensor):
device = xm.xla_device() if benchmark_experiment.xla else 'cuda'
device = torch_xla.device() if benchmark_experiment.xla else 'cuda'
sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize
timing, output = bench.do_bench(
lambda: benchmark_model.model_iter_fn(
4 changes: 2 additions & 2 deletions benchmarks/matmul_bench.py
@@ -42,7 +42,7 @@ def main():
fn,
return_mode='min',
sync_fn=lambda: xm.wait_device_ops(),
device=xm.xla_device())
device=torch_xla.device())
ind_bench_fn = lambda fn: do_bench(
fn,
return_mode='min',
@@ -53,7 +53,7 @@ for dtype in dtypes:
for dtype in dtypes:
for inductor_matmul, xla_matmul in zip(
get_matmuls(device='cuda', dtype=dtype, backend='inductor'),
get_matmuls(device=xm.xla_device(), dtype=dtype, backend='openxla')):
get_matmuls(device=torch_xla.device(), dtype=dtype, backend='openxla')):
ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul
xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul
assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}"
10 changes: 5 additions & 5 deletions contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
@@ -188,7 +188,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`."
"To get the current process/thread's default XLA device, use `torch_xla.device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`."
]
},
{
@@ -210,7 +210,7 @@
"lock = mp.Manager().Lock()\n",
"\n",
"def print_device(i, lock):\n",
" device = xm.xla_device()\n",
" device = torch_xla.device()\n",
" with lock:\n",
" print('process', i, device)"
]
@@ -318,7 +318,7 @@
],
"source": [
"def add_ones(i, lock):\n",
" x = torch.ones((3, 3), device=xm.xla_device())\n",
" x = torch.ones((3, 3), device=torch_xla.device())\n",
" y = x + x\n",
" \n",
" # Run graph to compute `y` before printing\n",
@@ -378,7 +378,7 @@
"source": [
"def gather_ids(i, lock):\n",
" # Create a tensor on each device with the device ID\n",
" t = torch.tensor([i], device=xm.xla_device())\n",
" t = torch.tensor([i], device=torch_xla.device())\n",
" with lock:\n",
" print(i, t)\n",
" \n",
@@ -454,7 +454,7 @@
"import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
"\n",
"def toy_model(index, lock):\n",
" device = xm.xla_device()\n",
" device = torch_xla.device()\n",
" dist.init_process_group('xla', init_method='xla://')\n",
"\n",
" # Initialize a basic toy model\n",
2 changes: 1 addition & 1 deletion contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb
@@ -172,7 +172,7 @@
"\n",
"pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n",
"# Move the model to the first TPU core\n",
"pipeline = pipeline.to(xm.xla_device())"
"pipeline = pipeline.to(torch_xla.device())"
]
},
{
4 changes: 2 additions & 2 deletions docs/source/learn/_pjrt.md
@@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend


def _mp_fn(index):
device = xm.xla_device()
device = torch_xla.device()
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')

@@ -377,7 +377,7 @@ def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='xla://')

t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
t = torch.tensor([index], dtype=torch.int32, device=torch_xla.device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)

22 changes: 11 additions & 11 deletions docs/source/learn/pytorch-on-xla-devices.md
@@ -14,14 +14,14 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
t = torch.randn(2, 2, device=torch_xla.device())
print(t.device)
print(t)
```

This code should look familiar. PyTorch/XLA uses the same interface as
regular PyTorch with a few additions. Importing `torch_xla` initializes
PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This
PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This
may be a CPU or TPU depending on your environment.

## XLA Tensors are PyTorch Tensors
@@ -32,8 +32,8 @@ tensors.
For example, XLA tensors can be added together:

``` python
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
t0 = torch.randn(2, 2, device=torch_xla.device())
t1 = torch.randn(2, 2, device=torch_xla.device())
print(t0 + t1)
```

@@ -46,8 +46,8 @@ print(t0.mm(t1))
Or used with neural network modules:

``` python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_in = torch.randn(10, device=torch_xla.device())
linear = torch.nn.Linear(10, 20).to(torch_xla.device())
l_out = linear(l_in)
print(l_out)
```
@@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on
the same device. So code like

``` python
l_in = torch.randn(10, device=xm.xla_device())
l_in = torch.randn(10, device=torch_xla.device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
@@ -79,7 +79,7 @@ The following snippet shows a network training on a single XLA device:
``` python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
device = torch_xla.device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -116,7 +116,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
device = xm.xla_device()
device = torch_xla.device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

model = MNIST().train().to(device)
@@ -144,7 +144,7 @@ previous single device snippet. Let's go over them one by one.
will only be able to access the device assigned to the current
process. For example, on a TPU v4-8 there will be 4 processes
spawned and each process will own a TPU device.
- Note that if you print the `xm.xla_device()` on each process you
- Note that if you print the `torch_xla.device()` on each process you
will see `xla:0` on all devices. This is because each process
can only see one device. This does not mean multi-process is not
functioning. The only exception is with PJRT runtime on TPU v2
@@ -279,7 +279,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
device = torch_xla.device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
4 changes: 2 additions & 2 deletions docs/source/learn/troubleshoot.md
@@ -32,8 +32,8 @@ vm:~$ export PJRT_DEVICE=TPU
vm:~$ python3
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> t1 = torch.tensor(100, device=torch_xla.device())
>>> t2 = torch.tensor(200, device=torch_xla.device())
>>> print(t1 + t2)
tensor(300, device='xla:0')
```
8 changes: 4 additions & 4 deletions docs/source/learn/xla-overview.md
@@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models.

General guidelines to modify your code:

- Replace `cuda` with `xm.xla_device()`
- Replace `cuda` with `torch_xla.device()`
- Remove progress bar, printing that would access the XLA tensor
values
- Reduce logging and callbacks that would access the XLA tensor values
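
As an illustration of the first guideline above, a before/after sketch (the linear model is a stand-in for real training code):

```python
import torch
import torch_xla

model = torch.nn.Linear(10, 10)  # stand-in for a real model

# Before: CUDA-specific placement.
#   model = model.to("cuda")
#   inputs = torch.randn(4, 10, device="cuda")

# After: place the model and its inputs on the XLA device instead.
device = torch_xla.device()
model = model.to(device)
inputs = torch.randn(4, 10, device=device)
outputs = model(inputs)
```
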
@@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well.

``` python
import torch_xla.core.xla_model as xm
self.device = xm.xla_device()
self.device = torch_xla.device()
```

Another place in the code that has cuda specific code is DDIM scheduler.
@@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"):
with

``` python
device = xm.xla_device()
device = torch_xla.device()
attr = attr.to(torch.device(device))
```

@@ -339,7 +339,7 @@ with the following lines:

``` python
import torch_xla.core.xla_model as xm
device = xm.xla_device()
device = torch_xla.device()
pipe.to(device)
```

14 changes: 7 additions & 7 deletions docs/source/perf/amp.md
@@ -16,15 +16,15 @@ example is below:

``` python
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
model = Net().to(torch_xla.device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)

for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
with autocast(torch_xla.device()):
output = model(input)
loss = loss_fn(output, target)

@@ -33,7 +33,7 @@ for input, target in data:
xm.optimizer_step(optimizer)
```

`autocast(xm.xla_device())` aliases `torch.autocast('xla')` when the XLA
`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA
Device is a TPU. Alternatively, if a script is only used with TPUs, then
`torch.autocast('xla', dtype=torch.bfloat16)` can be directly used.
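
A minimal sketch of that TPU-only form (assuming `torch_xla` has been imported so the `xla` device type is registered):

```python
import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(8, 8).to(device)
data = torch.randn(4, 8, device=device)

# TPU-only scripts can call torch.autocast directly with the 'xla' device type.
with torch.autocast('xla', dtype=torch.bfloat16):
    output = model(data)
```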

@@ -100,7 +100,7 @@ specific behavior. A simple CUDA AMP example is below:

``` python
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
model = Net().to(torch_xla.device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()
@@ -109,7 +109,7 @@ for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
with autocast(torch_xla.device()):
output = model(input)
loss = loss_fn(output, target)

@@ -121,12 +121,12 @@ for input, target in data:
scaler.update()
```

`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the
`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the
XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is
only used with CUDA devices, then `torch.cuda.amp.autocast` can be
used directly, but this requires that `torch` is compiled with `cuda`
support for the `torch.bfloat16` datatype. We recommend using
`autocast(xm.xla_device())` on XLA:GPU as it does not require
`autocast(torch_xla.device())` on XLA:GPU as it does not require
`torch.cuda` support for any datatypes, including `torch.bfloat16`.

### AMP for XLA:GPU Best Practices