# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# Model tracing: the operations are only recorded into a graph here.
res = model(input)

# Model compilation and device execution happen here.
torch_xla.sync()
```

While this approach enables performance optimizations, it introduces significant
usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between tracing and
  execution phases, complicating development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
  leak into the graph, triggering expensive recompilations (see the sketch
  after this list).

- **Debugging Difficulty**: Identifying the cause of recompilations is
  challenging due to the opaque nature of graph-building processes.
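
To make the recompilation pitfall concrete, here is a contrived sketch; the
changing input shape is hypothetical and chosen only to force a new graph on
each iteration:

```python
import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

for step in range(3):
    # A different input shape each iteration changes the traced graph, so
    # every torch_xla.sync() below pays a full, expensive recompilation.
    data = torch.randn(64, 3, 224 + step, 224 + step).to(device)
    res = model(data)
    torch_xla.sync()
```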

## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance**. Eager mode is likely to become the default in future releases.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
  debugging but at a performance cost for core tasks; a minimal example follows
  this list.

- **torch_xla.compile**: A decorator or wrapper that explicitly marks code
  (e.g., a model or function) for XLA compilation within an eager context,
  providing clear boundaries and immediate feedback.
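
Here is a minimal sketch of eager mode in action. The tensor shapes and names
are arbitrary placeholders; the only PyTorch/XLA calls assumed are
`torch_xla.experimental.eager_mode` and `torch_xla.device`, both shown
elsewhere in this doc:

```python
import torch
import torch_xla

# Run every op immediately on the XLA device instead of tracing it.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
a = torch.randn(4, 4, device=device)
b = a @ a   # executes right away; no torch_xla.sync() is needed
print(b)    # values are already materialized, which makes debugging easy
```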

## How `torch_xla.compile` works

Let's have a look at a basic usage of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision

# Run ops eagerly by default.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the model to be compiled.
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happen right away.
res = compiled_model(input)
```

where the implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches to tracing to build a
   computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
   internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance; a conceptual sketch of the wrapper
follows below.
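
The following is only a conceptual sketch of those four steps, not the actual
PyTorch/XLA implementation. In particular, calling `eager_mode(False)` to
re-enter tracing mode is an assumption for illustration, and the real
`torch_xla.compile` handles many more details (naming, full-graph checks,
in-place mutation guards, etc.):

```python
import functools

import torch_xla


def conceptual_compile(fn):
    """Illustrative wrapper mirroring the four steps described above."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # 1. Disable eager mode so subsequent ops are traced, not executed.
        torch_xla.experimental.eager_mode(False)
        try:
            # 2. Trace: calling fn records its ops into a lazy graph.
            result = fn(*args, **kwargs)
            # 3. Compile and execute the recorded graph.
            torch_xla.sync()
        finally:
            # 4. Re-enable eager mode for code outside the compiled region.
            torch_xla.experimental.eager_mode(True)
        return result

    return wrapper
```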

## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
  to work efficiently with the XLA backend for iterative training, it's the
  recommended API for compiling training loops due to its observed performance
  advantages. The best practice is to enclose the complete training step, i.e.
  forward pass, loss calculation, backward pass, and optimizer step, within a
  `step_fn`, and then compile this function:

```python
torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)
```
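
For context, a training loop would then call the compiled step function
directly. The `train_loader` below is a hypothetical `DataLoader`, and `model`,
`loss_fn`, `optimizer`, and `device` are assumed to be set up as in the
examples above:

```python
for data, target in train_loader:  # hypothetical DataLoader
    data, target = data.to(device), target.to(device)
    loss = step_fn(model, data, target, loss_fn, optimizer)
```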

- `torch.compile` is PyTorch's general-purpose compilation API, designed to
  accelerate PyTorch models across various backends. For PyTorch/XLA, it uses
  the `openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
  because it lowers tracing overhead, leading to more efficient static inference
  graphs. To use it with XLA, simply specify `backend="openxla"`:

```python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were conducted under specific conditions. The benchmarks utilized a
2-layer decoder-only model, similar to Llama2, trained with fake data. The
training process spanned 300 steps on a single chip of a v4-8 TPU. The observed
performance, measured in tokens per second, clearly illustrates the impact of
different execution modes:

| Mode                        | token/s |
|-----------------------------|---------|
| Tracing mode (baseline)     | 147     |
| Eager mode                  | 65      |
| Eager + `torch_xla.compile` | 147     |

Eager mode with `torch_xla.compile` matches the performance of traditional
LazyTensor tracing mode at 147 tokens/s, demonstrating a better user
experience without performance loss.

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for decoder-only models. However, for ResNet50,
pure eager mode was significantly slower (about 1% of compiled mode). For more
information, see
[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).
This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks like data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput.