# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently, PyTorch/XLA runs in
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:
```python
import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# The forward pass below is only traced into a pending graph.
res = model(input)

# Compilation and device execution happen here.
torch_xla.sync()
```
While this approach enables performance optimizations, it introduces
significant usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between the tracing and
  execution phases, which complicates development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing)
  can leak into the main graph, triggering expensive recompilations (see the
  sketch after this list).

- **Debugging Difficulty**: Identifying when and why recompilation happens is
  challenging due to the opaque nature of the graph-building process.
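For illustration, here is a minimal sketch of how such a recompilation can
sneak in. It assumes a preprocessing step whose tensor shapes vary from step to
step (the shapes and loop below are made up for the example); in LazyTensor
mode those operations are traced into the same pending graph as the model, so
every new shape produces a new graph that has to be compiled again:

```python
import torch
import torch_xla

device = torch_xla.device()

for seq_len in (128, 131, 127):                 # shapes vary between steps
    data = torch.randn(8, seq_len).to(device)   # traced, not yet executed
    data = (data - data.mean()) / data.std()    # "non-core" preprocessing, also traced
    # ... the model's forward and backward passes would be traced here ...
    torch_xla.sync()  # each new shape yields a different graph -> recompilation
```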
## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance.** Eager mode is likely to become the default in future releases.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
  debugging, but at a performance cost for core tasks (see the standalone
  example after this list).

- **`torch_xla.compile`**: A decorator or wrapper that explicitly marks code
  (e.g., a model or function) for XLA compilation within an eager context,
  providing clear boundaries and immediate feedback.
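As a quick illustration of eager mode on its own (a minimal sketch; the tensor
shapes are arbitrary), every operation runs on the device as soon as it is
issued, and no `torch_xla.sync()` call is needed to materialize the result:

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
a = torch.randn(3, 3, device=device)
b = a @ a    # executed on the device immediately, no pending graph
print(b)     # the value is already materialized
```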
## How `torch_xla.compile` Works

Let's have a look at a basic usage of `torch_xla.compile`:
```python
import torch
import torch_xla
import torchvision

# Run operations eagerly by default.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the model for compilation.
compiled_model = torch_xla.compile(model)

input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```
The implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches back to tracing mode to build
   a computation graph.

2. **Traces Operations**: Records the operations in the wrapped region for XLA
   optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
   internal `torch_xla.sync()` call when the wrapped region returns.

4. **Re-enables Eager Mode**: Resumes eager execution after the compiled
   region finishes.

This "eager-to-lazy-to-eager" transition abstracts away the synchronization
complexity, balancing flexibility and performance.
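To make these four steps concrete, here is a simplified sketch of such a
wrapper. It is not the actual `torch_xla.compile` implementation (which handles
additional details), just an illustration of the transition described above:

```python
import functools

import torch_xla

def compile_like(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        torch_xla.experimental.eager_mode(False)     # 1. fall back to tracing mode
        try:
            result = fn(*args, **kwargs)             # 2. operations are traced, not executed
            torch_xla.sync()                         # 3. compile and execute the traced graph
        finally:
            torch_xla.experimental.eager_mode(True)  # 4. resume eager execution
        return result
    return wrapper
```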
## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
  to work efficiently with the XLA backend for iterative training, it is the
  recommended API for compiling training loops due to its observed performance
  advantages. Best practice is to enclose the complete training step (forward
  pass, loss calculation, backward pass, and optimizer step) in a `step_fn` and
  then compile that function:
```python
torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)
```
- `torch.compile` is PyTorch's general-purpose compilation API, designed to
  accelerate PyTorch models across various backends. For PyTorch/XLA it uses
  the `openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
  because it lowers tracing overhead, leading to more efficient static
  inference graphs. To use it with XLA, simply specify `backend="openxla"`:
```python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```
The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were run on a 2-layer decoder-only model (essentially a small
Llama2) trained with fake data for 300 steps on a single chip of a v4-8 TPU.
The observed throughput, measured in tokens per second, illustrates the impact
of the different execution modes:

| Mode                        | Tokens/s |
|-----------------------------|----------|
| Tracing mode (baseline)     | 147      |
| Eager mode                  | 65       |
| Eager + `torch_xla.compile` | 147      |

The results show that eager mode combined with `torch_xla.compile` reaches
performance parity with the traditional LazyTensor tracing mode, both yielding
147 tokens/s. The new API therefore provides a better user experience without a
performance penalty for compiled regions, making it a "no-regret" upgrade for
performance-conscious users.

Pure eager mode achieves roughly 45% of the fully compiled model's throughput
for this decoder-only model. For more information, see
[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and the [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).
It is crucial to understand that the performance of pure eager mode is highly
model-dependent. For ResNet50, for example, eager mode achieved only about 1%
of compiled-mode performance, showing that the overhead of eager execution
varies widely with model architecture. Given this, pure eager mode is not
intended for executing the main training or inference loop. Its value lies in
handling non-core parts of the logic, such as data preprocessing, random number
generation, or custom utility functions, and in debugging, where immediate
execution and inspectability matter more than raw throughput.
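Putting this guidance together, a typical setup (a hedged sketch with made-up
model, data, and hyperparameters) keeps eager mode enabled globally so that
non-core logic such as data generation runs immediately, while the core
training step is wrapped in `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision

torch_xla.experimental.eager_mode(True)   # non-core code runs eagerly

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

@torch_xla.compile                        # only the core training step is compiled
def train_step(data, target):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    return loss

for _ in range(3):
    # Non-core logic (fake data generation here) executes eagerly, outside the
    # compiled region, so it cannot leak into the graph and force recompilation.
    data = torch.randn(8, 3, 224, 224).to(device)
    target = torch.randint(0, 1000, (8,), device=device)
    loss = train_step(data, target)
```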