# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# Model tracing: the forward pass is recorded into a graph, not executed
res = model(input)

# Model compilation and execution happen here
torch_xla.sync()
```

While this approach enables performance optimizations, it introduces significant
usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between tracing and
  execution phases, complicating development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
  leak into the graph, triggering expensive recompilations, as illustrated in
  the sketch after this list.

- **Debugging Difficulty**: Identifying the cause of recompilations is
  challenging due to the opaque nature of graph-building processes.
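
The recompilation problem can be sketched concretely. The snippet below is a
hypothetical illustration (the shapes and the scaling op are made up) of how,
under the default LazyTensor mode, a preprocessing op left pending between
steps is fused into the step graph, so the second `torch_xla.sync()` compiles
a different graph than the first:

```python
import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(8, 8).to(device)

res = (x @ x).sum()   # traced, not executed
torch_xla.sync()      # graph A is compiled, cached, and executed

x = x * 0.5           # "non-core" op, still pending at the next sync
res = (x @ x).sum()
torch_xla.sync()      # graph B (mul + matmul + sum) != graph A: recompiles
```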

## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance.** Eager mode is likely to become the default in future releases.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
  debugging but at a performance cost for core tasks.

- **torch_xla.compile**: A decorator or wrapper that explicitly marks code
  (e.g., a model or function) for XLA compilation within an eager context,
  providing clear boundaries and immediate feedback.
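
For instance, with eager mode enabled, each operation is compiled and run as
soon as it is issued. A minimal sketch (the tensor shapes are arbitrary):

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

x = torch.randn(4, 4).to(device)
y = x @ x  # compiled and executed immediately, not deferred
print(y)   # the value is already materialized; no pending graph to sync
```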

## How `torch_xla.compile` works

Let's have a look at a basic usage of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the model for compilation; tracing happens only inside this call
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happen right away
res = compiled_model(input)
```

The implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches to tracing to build a
   computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
   internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance.
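
Conceptually, the wrapper could be sketched as below. This is an illustrative
re-implementation (`simple_compile` is a made-up name), not the actual
`torch_xla.compile` source, which also handles naming, full-graph enforcement,
and other details:

```python
import functools
import torch_xla

def simple_compile(fn):
    # Illustrative stand-in for torch_xla.compile
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        torch_xla.experimental.eager_mode(False)  # 1. switch back to tracing
        try:
            result = fn(*args, **kwargs)          # 2. trace the function body
            torch_xla.sync()                      # 3. compile and execute the graph
        finally:
            torch_xla.experimental.eager_mode(True)  # 4. resume eager execution
        return result
    return wrapper
```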

## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
  to work efficiently with the XLA backend for iterative training, it's the
  recommended API for compiling training loops due to its observed performance
  advantages. The best practice is to enclose the complete training step (e.g.,
  forward pass, loss calculation, backward pass, and optimizer step) within a
  `step_fn` and then compile this function:

```python
torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    # Full training step: forward, loss, backward, and optimizer update
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)
```
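
Once wrapped, `step_fn` is called like any other Python function. A
hypothetical loop follows; the `train_loader`, `model`, `loss_fn`, `optimizer`,
and `device` setup is assumed to exist elsewhere:

```python
for data, target in train_loader:
    # Move the batch to the XLA device, then run one compiled step
    data, target = data.to(device), target.to(device)
    loss = step_fn(model, data, target, loss_fn, optimizer)
```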

- `torch.compile` is PyTorch's general-purpose compilation API designed to
  accelerate PyTorch models across various backends. For PyTorch/XLA, it uses
  the `openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
  because it lowers tracing overhead, leading to more efficient static inference
  graphs. To use it with XLA, simply specify `backend="openxla"`:

```python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were conducted under specific conditions. The benchmarks utilized a
2-layer decoder-only model, similar to Llama2, trained with fake data. The
training process spanned 300 steps on a single chip of a v4-8 TPU. The observed
performance, measured in tokens per second, clearly illustrates the impact of
different execution modes:

| Mode                      | token/s |
| ------------------------- | ------- |
| Tracing mode (baseline)   | 147     |
| Eager mode                | 65      |
| Eager + torch_xla compile | 147     |

Eager mode with `torch_xla.compile` matches the performance of traditional
LazyTensor tracing mode at `147` tokens/s, demonstrating a better user
experience without performance loss.

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for decoder-only models. However, for ResNet50,
pure eager mode was significantly slower (about 1% of compiled mode). For more
information, see
[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and the [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).

This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks like data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput.
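
For example, a common pattern (the shapes and ops below are illustrative) keeps
preprocessing eager while compiling only the core step function:

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

def preprocess(batch):
    # Runs eagerly: immediate execution, easy to debug, and nothing is
    # left pending to leak into the compiled step's graph
    return (batch - batch.mean()) / (batch.std() + 1e-6)

def step_fn(x):
    return (x @ x.T).sum()

step_fn = torch_xla.compile(step_fn)  # only the core computation is compiled

x = preprocess(torch.randn(32, 128).to(device))
loss = step_fn(x)
```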