Commit 7ce8493

Fix the examples in API_GUIDE (#9213)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent: 334b4d3


API_GUIDE.md

Lines changed: 69 additions & 13 deletions
@@ -77,13 +77,45 @@ multi-processing.
 The following snippet shows a network training on a single XLA device:
 
 ```python
-import torch_xla
 import torch_xla.core.xla_model as xm
-
-device = xm.xla_device()
-model = MNIST().train().to(device)
+from torch_xla import runtime as xr
+import torch
+import torch_xla.utils.utils as xu
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+class MNIST(nn.Module):
+  def __init__(self):
+    super(MNIST, self).__init__()
+    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+    self.bn1 = nn.BatchNorm2d(10)
+    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+    self.bn2 = nn.BatchNorm2d(20)
+    self.fc1 = nn.Linear(320, 50)
+    self.fc2 = nn.Linear(50, 10)
+
+  def forward(self, x):
+    x = F.relu(F.max_pool2d(self.conv1(x), 2))
+    x = self.bn1(x)
+    x = F.relu(F.max_pool2d(self.conv2(x), 2))
+    x = self.bn2(x)
+    x = torch.flatten(x, 1)
+    x = F.relu(self.fc1(x))
+    x = self.fc2(x)
+    return F.log_softmax(x, dim=1)
+
+# Create a synthetic dataset.
+batch_size = 128
+train_loader = xu.SampleGenerator(
+    data=(torch.zeros(batch_size, 1, 28, 28),
+          torch.zeros(batch_size, dtype=torch.int64)),
+    sample_count=60000 // batch_size // xr.world_size())
+
+device = xm.xla_device()  # Get the XLA device (TPU).
+model = MNIST().train().to(device)  # Create a model and move it to the device.
 loss_fn = nn.NLLLoss()
-optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
+optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
 for data, target in train_loader:
   optimizer.zero_grad()
@@ -92,16 +124,16 @@ for data, target in train_loader:
   output = model(data)
   loss = loss_fn(output, target)
   loss.backward()
-
   optimizer.step()
-  torch_xla.sync()
+  # Mark the end of a training step and trigger the execution of the accumulated
+  # operations on the TPU.
+  xm.mark_step()
 ```
 
 This snippet highlights how easy it is to switch your model to run on XLA. The
 model definition, dataloader, optimizer and training loop can work on any device.
 The only XLA-specific code is a couple lines that acquire the XLA device and
-materializing the tensors. Calling
-`torch_xla.sync()` at the end of each training
+materialize the tensors. Calling `xm.mark_step()` at the end of each training
 iteration causes XLA to execute its current graph and update the model's
 parameters. See [XLA Tensor Deep Dive](#xla-tensor-deep-dive) for more on
 how XLA creates graphs and runs operations.
@@ -112,26 +144,50 @@ PyTorch/XLA makes it easy to accelerate training by running on multiple XLA
 devices. The following snippet shows how:
 
 ```python
-import torch_xla
 import torch_xla.core.xla_model as xm
+from torch_xla import runtime as xr
+import torch
+import torch_xla.utils.utils as xu
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch_xla
 import torch_xla.distributed.parallel_loader as pl
 
+class MNIST(nn.Module):
+  # The same as in the previous example.
+  ...
+
+batch_size=128
+# The same as in the previous example.
+train_loader = ...
+
 def _mp_fn(index):
-  device = xm.xla_device()
+  """Called on each process/device.
+
+  Args:
+    index: Index of the process.
+  """
+
+  device = xm.xla_device()  # Get the device assigned to this process.
+  # Wrap the loader for multi-device.
   mp_device_loader = pl.MpDeviceLoader(train_loader, device)
 
   model = MNIST().train().to(device)
   loss_fn = nn.NLLLoss()
-  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
+  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
   for data, target in mp_device_loader:
     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()
+    # Perform the optimization step and trigger the execution of the
+    # accumulated XLA operations on the device for this process.
     xm.optimizer_step(optimizer)
 
 if __name__ == '__main__':
+  # Launch the multi-device training.
   torch_xla.launch(_mp_fn, args=())
 ```
 
@@ -141,7 +197,7 @@ single device snippet. Let's go over then one by one.
 - `torch_xla.launch()`
   - Creates the processes that each run an XLA device.
   - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device.
-  - Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only execution is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads(check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
+  - Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
 - `MpDeviceLoader`
   - Loads the training data onto each device.
   - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance.
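
To make the last bullet concrete, here is a minimal sketch of wrapping a standard `torch.utils.data.DataLoader` with `MpDeviceLoader`; the synthetic in-memory dataset below is an illustrative assumption, not taken from the guide:

```python
import torch
import torch.utils.data as data
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Synthetic stand-in dataset; any torch Dataset/DataLoader works the same way.
dataset = data.TensorDataset(
    torch.zeros(1024, 1, 28, 28),
    torch.zeros(1024, dtype=torch.int64))
host_loader = data.DataLoader(dataset, batch_size=128, shuffle=True)

device = xm.xla_device()
# MpDeviceLoader wraps the host loader and preloads each batch onto the XLA
# device, overlapping the data transfer with device execution.
mp_device_loader = pl.MpDeviceLoader(host_loader, device)

for batch, target in mp_device_loader:
  # batch and target already live on the XLA device here.
  print(batch.device, target.device)
  break
```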
