
Commit dfe31ee (1 parent: edc1a88)

Improve torch_xla.compile documentation

1 file changed: docs/source/learn/eager.md (+95, -65 lines)
# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently, PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision

# The middle of this block is elided in the diff; a typical
# reconstruction is shown here: a torchvision model on the XLA device.
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# Under LazyTensor mode, this call is only traced into the graph.
res = model(input)

# Compilation and device execution happen at the sync point.
torch_xla.sync()
```

While this approach enables performance optimizations, it introduces significant
usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between tracing and
  execution phases, complicating development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
  leak into the graph, triggering expensive recompilations.

- **Debugging Difficulty**: Identifying the cause of recompilations is
  challenging due to the opaque nature of graph-building processes (see the
  sketch after this list).
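
A minimal sketch of the leak problem, assuming default LazyTensor mode: any
tensor operation on an XLA device between sync points joins the pending graph,
so throwaway preprocessing gets baked into the next compiled graph. The metrics
report shown is PyTorch/XLA's standard counter dump for spotting unexpected
compilations.

```python
import torch
import torch_xla
import torch_xla.debug.metrics as met

device = torch_xla.device()
x = torch.randn(8, 8).to(device)

# Traced, not executed: this "preprocessing" is now pending and will be
# compiled into whatever graph the next sync materializes.
x = (x - x.mean()) / x.std()

# Compiles and executes a graph that includes the preprocessing above.
torch_xla.sync()

# Counters such as compilation counts help identify recompilation causes.
print(met.metrics_report())
```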

## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance**. Eager mode is likely to become the default in future releases.

- **Eager Mode**: Executes operations immediately (illustrated below),
  enhancing flexibility and debugging but at a performance cost for core tasks.

- **`torch_xla.compile`**: A decorator or wrapper that explicitly marks code
  (e.g., a model or function) for XLA compilation within an eager context,
  providing clear boundaries and immediate feedback.
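
A quick illustration of eager execution, assuming an XLA device is available;
with eager mode enabled, results materialize immediately and no explicit
`torch_xla.sync()` call is required:

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
x = torch.randn(4, 4).to(device)

# Executes on the device right away instead of being traced.
y = x @ x
print(y)  # The value is already available.
```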

## How `torch_xla.compile` works

Let's have a look at a basic usage of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision

# The middle of this block is elided in the diff; a typical
# reconstruction is shown here.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the model for XLA compilation.
compiled_model = torch_xla.compile(model)

input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```

where the implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches to tracing to build a
   computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
   internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance.
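
The transition can be pictured as a wrapper along the following lines. This is
a conceptual sketch of the four steps above, not the actual implementation;
`compile_like` is a hypothetical helper:

```python
import torch_xla

def compile_like(fn):
    # Hypothetical stand-in that mimics the documented behavior of
    # torch_xla.compile; not the real implementation.
    def wrapper(*args, **kwargs):
        torch_xla.experimental.eager_mode(False)  # 1. Switch back to tracing.
        try:
            result = fn(*args, **kwargs)          # 2. Ops are recorded, not run.
            torch_xla.sync()                      # 3. Compile and execute the graph.
        finally:
            torch_xla.experimental.eager_mode(True)  # 4. Resume eager execution.
        return result
    return wrapper
```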

## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
  to work efficiently with the XLA backend for iterative training, it is the
  recommended API for compiling training loops due to its observed performance
  advantages. The best practice is to enclose the complete training step (e.g.,
  forward pass, loss calculation, backward pass, and optimizer step) within a
  `step_fn` and then compile this function:

``` python
torch_xla.experimental.eager_mode(True)

# The body of step_fn is elided in the diff; a typical training step is
# reconstructed here.
def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)
```
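
A hypothetical training loop using the compiled step function; `train_loader`,
`model`, `loss_fn`, `optimizer`, and `device` are illustrative names assumed to
be set up as usual:

``` python
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    loss = step_fn(model, data, target, loss_fn, optimizer)
```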
102120

103-
In training we asked user to refactor the `step_fn` out because it is
104-
usually better to compile the model's forward, backward and optimizer
105-
together. The long term goal is to also use `torch.compile` for training
106-
but right now we recommend user to use `torch_xla.compile`(for
107-
perfomrance reason).
121+
- `torch.compile` is PyTorch's general-purpose compilation API designed to
122+
accelerate PyTorch models across various backends. For PyTorch/XLA, it uses the
123+
`openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
124+
because it lowers tracing overhead, leading to more efficient static inference
125+
graphs. To use it with XLA, simply specify `backend="openxla"`.
126+

``` python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```
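
A hypothetical inference call with the compiled model, assuming `model` and
`device` are set up as in the earlier examples:

``` python
input = torch.randn(64, 3, 224, 224).to(device)

# Inference only; no gradients are needed.
with torch.no_grad():
    res = compiled_model(input)
```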

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were conducted under specific conditions. The benchmarks utilized a
2-layer decoder-only model, similar to Llama2, trained with fake data. The
training process spanned 300 steps on a single chip of a v4-8 TPU. The observed
performance, measured in tokens per second, clearly illustrates the impact of
different execution modes:

| Mode                        | token/s |
|-----------------------------|---------|
| Tracing mode (baseline)     | 147     |
| Eager mode                  | 65      |
| Eager + `torch_xla.compile` | 147     |

Eager mode with `torch_xla.compile` matches the performance of traditional
LazyTensor tracing mode at `147` tokens/s, demonstrating a better user
experience without performance loss.

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for decoder-only models. However, for ResNet50,
pure eager mode was significantly slower (about 1% of compiled mode). For more
information, see
[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and the [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).

This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks like data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput, as sketched below.
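
A hedged sketch of that recommended split, keeping eager mode for non-core work
and reserving `torch_xla.compile` for the hot path; `augment` and the trivial
step body are illustrative stand-ins:

``` python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

def augment(batch):
    # Non-core preprocessing: runs eagerly and executes immediately.
    return batch + torch.randn_like(batch) * 0.01

def step_fn(model, batch):
    # Core compute: traced and compiled as a single graph.
    return model(batch)

step_fn = torch_xla.compile(step_fn)
```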
