Improve `torch_xla.compile` documentation (#9194)
# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision

# The middle of this example is not shown in the diff; a typical setup
# (assumed here) builds a torchvision model and input on the XLA device and
# runs a forward pass. All of these operations are traced, not yet executed.
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
res = model(input)

torch_xla.sync()
```
While this approach enables performance optimizations, it introduces significant
usability challenges.
## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between the tracing and
  execution phases, complicating development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
  leak into the main graph, triggering expensive recompilations of the whole
  graph (see the sketch after this list).

- **Debugging Difficulty**: Identifying when and why recompilation happens is
  challenging due to the opaque nature of the graph-building process.
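
For instance, here is a minimal, hypothetical sketch of how a small
preprocessing operation can end up in the same graph as the model step under
LazyTensor tracing; the shapes and the toy "model" are placeholders, not taken
from the original document:

```python
import torch
import torch_xla

device = torch_xla.device()

# "Non-core" preprocessing: under LazyTensor tracing this is only recorded.
batch = torch.randn(8, 3, 224, 224).to(device)
batch = (batch - batch.mean()) / batch.std()

# Core step: also recorded into the same pending graph.
weight = torch.randn(3 * 224 * 224, 10).to(device)
logits = batch.flatten(1) @ weight

# Everything above is compiled and executed together here; if the preprocessing
# changes shapes or ops between steps, the whole combined graph recompiles.
torch_xla.sync()
```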
## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance**. Eager mode is likely to become the default in future releases.

> **Reviewer comment:** Did you mean to emphasize this sentence? Seems a bit
> jarring to me, but just a minor opinion.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
  debugging, but at a performance cost for core tasks.

  > **Reviewer comment:** I think "core tasks" is vague. Could simply remove:
  > "[...] but at a performance cost."

- **torch_xla.compile**: A decorator or wrapper that explicitly marks code
  (e.g., a model or function) for XLA compilation within an eager context,
  providing clear boundaries and immediate feedback.
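
For a quick sense of what eager mode alone changes, here is a minimal sketch
(the tensor shapes are arbitrary placeholders): with `eager_mode(True)`,
operations on XLA tensors execute immediately instead of being queued for a
later `torch_xla.sync()`.

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
x = torch.randn(4, 4).to(device)
y = x @ x.T          # executed immediately on the device, no pending graph
print(y.sum())       # value is readable without an explicit torch_xla.sync()
```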

## How `torch_xla.compile` works

Let's have a look at a basic usage of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision

# Enable eager mode so that ops outside the compiled region run immediately.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
# A torchvision model and this input shape are assumed here for illustration;
# the middle of the original example is not shown in the diff.
model = torchvision.models.resnet18().to(device)

# Wrap the model so its forward pass is traced, compiled, and executed by XLA.
compiled_model = torch_xla.compile(model)

input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```

where the implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches back to tracing to build a
   computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
   internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance. A conceptual sketch of this wrapper is
shown below.
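
To make those four steps concrete, here is a minimal conceptual sketch of such
a wrapper. It is not the actual `torch_xla.compile` implementation, and the
helper name `eager_aware_compile` is invented for illustration:

```python
import functools

import torch_xla


def eager_aware_compile(fn):
    """Illustrative stand-in for what torch_xla.compile conceptually does."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # 1. Disable eager mode so subsequent ops are traced into a graph.
        torch_xla.experimental.eager_mode(False)
        try:
            # 2. Running the function records its operations (tracing).
            result = fn(*args, **kwargs)
            # 3. Compile and execute the recorded graph.
            torch_xla.sync()
        finally:
            # 4. Resume eager execution for code outside the compiled region.
            torch_xla.experimental.eager_mode(True)
        return result

    return wrapper
```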

## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
  to work efficiently with the XLA backend for iterative training, it is the
  recommended API for compiling training loops due to its observed performance
  advantages. The best practice is to enclose the complete training step, i.e.
  the forward pass, loss calculation, backward pass, and optimizer step, within
  a `step_fn` and then compile this function:

```python
torch_xla.experimental.eager_mode(True)

# The body of step_fn is not shown in the diff; a typical training-step
# implementation is assumed here for illustration.
def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)
```

- `torch.compile` is PyTorch's general-purpose compilation API, designed to
  accelerate PyTorch models across various backends. For PyTorch/XLA, it uses
  the `openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
  because it lowers tracing overhead, leading to more efficient static inference
  graphs. To use it with XLA, simply specify `backend="openxla"`:

```python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were conducted under specific conditions. The benchmarks used a
2-layer decoder-only model, similar to Llama2, trained with fake data. The
training process spanned 300 steps on a single chip of a v4-8 TPU. The observed
performance, measured in tokens per second, clearly illustrates the impact of
the different execution modes:
|
||
Mode token/s | ||
--------------------------- --------- | ||
Tracing mode (base line) 147 | ||
Eager mode 65 | ||
Eager + torch_xla compile 147 | ||
| Mode | token/s | | ||
|-----------------------------|---------| | ||
| Tracing mode (base line) | 147 | | ||
| Eager mode | 65 | | ||
| Eager + torch_xla compile | 147 | | ||
|
||
: Eager mode benchmarks | ||
Eager mode with `torch_xla.compile` matches the performance of traditional | ||
LazyTensor tracing mode at `147` tokens/s, demonstrating a better user | ||
experience without performance loss. | ||

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for the decoder-only model above. For ResNet50,
however, pure eager mode was significantly slower, at roughly 1% of
compiled-mode performance. For more information, see
[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and the [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).

This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks such as data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput.
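
Putting this together, a minimal sketch of the recommended split, eager
execution for non-core work and `torch_xla.compile` for the hot loop, might
look like the following; the model, optimizer, and fake data pipeline are
placeholders, not taken from the original document:

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()


def step_fn(data, target):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    return loss


step_fn = torch_xla.compile(step_fn)

for _ in range(3):
    # Non-core work (fake data generation / preprocessing) runs eagerly.
    data = torch.randn(32, 128).to(device)
    target = torch.randint(0, 10, (32,)).to(device)
    # The core training step runs as a single compiled XLA graph.
    loss = step_fn(data, target)
```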

> **Reviewer comment:** We also need to call out that wrapping the training loop in …

> **Reviewer comment:** Since …

> **Reviewer comment:** I think it could be clearer to explain that these non-core operations are all recorded. Maybe this phrasing: …