
Improve torch_xla.compile documentation #9194


Open

wants to merge 1 commit into master
160 changes: 95 additions & 65 deletions docs/source/learn/eager.md
# PyTorch/XLA Compilation: Eager Mode vs. Compile API

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently, PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision
# ... (lines collapsed in the diff view) ...
res = model(input)
torch_xla.sync()
```
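Because the diff view collapses the middle of the snippet above, here is a
self-contained sketch of the same lazy-execution pattern. The specific model
(`resnet18`) and the device helper are assumptions for illustration; the input
shape matches the later examples in this document.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# In LazyTensor mode this call only records operations into a graph;
# nothing has executed on the device yet.
res = model(input)

# Compilation and device execution are triggered here.
torch_xla.sync()
```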

While this approach enables performance optimizations, it introduces significant
usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between tracing and
execution phases, complicating development and debugging.

- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
leak into the graph, triggering expensive recompilations.

- **Debugging Difficulty**: Identifying the cause of recompilations is
challenging due to the opaque nature of graph-building processes.

**Reviewer comment (Collaborator)** on the *Recompilation Overhead* bullet: I think
it could be clearer to explain that these non-core operations are all recorded.
Maybe this phrasing:

> **Recompilation Overhead**: Whenever any part of the captured graph changes,
> `torch_xla.sync()` will recompile the whole graph. Changes in non-core operations
> (e.g., data preprocessing) thus trigger expensive recompilations.

## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance**. Eager mode is likely to become the default in future releases.

**Reviewer comment (Collaborator):** Did you mean to emphasize this sentence? Seems
a bit jarring to me but just a minor opinion.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
debugging but at a performance cost for core tasks (see the short example below).

  **Reviewer comment (Collaborator):** I think "core tasks" is vague. Could simply
  remove: "[...] but at a performance cost."

- **torch_xla.compile**: A decorator or wrapper that explicitly marks code
(e.g., a model or function) for XLA compilation within an eager context,
providing clear boundaries and immediate feedback.

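As a quick illustration of eager mode, the minimal sketch below (the device
helper and tensor shapes are illustrative assumptions) executes an operation
immediately, with no pending graph and no explicit `torch_xla.sync()`:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Enable the experimental eager mode.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
a = torch.randn(2, 2, device=device)

# With eager mode on, this matrix multiply is compiled and executed right away
# instead of being recorded for a later torch_xla.sync().
b = a @ a
print(b)
```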
## How `torch_xla.compile` works

Let's have a look at the basic usage of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision
# ... (lines collapsed in the diff view) ...
input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```
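Since the middle of this snippet is collapsed in the diff, the following
self-contained sketch shows what a complete basic-usage example could look like;
the concrete model (`resnet18`) and the device helper are assumptions for
illustration:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torchvision

# Enable eager mode so that ordinary code executes immediately.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
model = torchvision.models.resnet18().to(device)

# Mark the model for XLA compilation: calls to compiled_model are traced,
# compiled, and executed as a single graph.
compiled_model = torch_xla.compile(model)

input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```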

The implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches to tracing to build a
computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance.

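The following is a rough sketch of that transition, built only from the public
calls shown in this document; it assumes `eager_mode(False)` disables eager
execution again, and it is not the actual `torch_xla.compile` implementation:

```python
import torch_xla

def compile_sketch(fn):
    """Illustrative eager-to-lazy-to-eager wrapper (not the real implementation)."""
    def wrapper(*args, **kwargs):
        # 1. Disable eager mode so the following operations are traced, not executed.
        torch_xla.experimental.eager_mode(False)
        try:
            # 2. Run the target function; its operations are recorded into a graph.
            result = fn(*args, **kwargs)
            # 3. Compile and execute the recorded graph.
            torch_xla.sync()
        finally:
            # 4. Re-enable eager mode for the surrounding code.
            torch_xla.experimental.eager_mode(True)
        return result
    return wrapper
```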
## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
to work efficiently with the XLA backend for iterative training, it's the
recommended API for compiling training loops due to its observed performance
advantages. The best practice is to enclose the complete training step (forward
pass, loss calculation, backward pass, and optimizer step) within a `step_fn` and
then compile this function.

``` python
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
    ...  # function body collapsed in the diff view

step_fn = torch_xla.compile(step_fn)
```
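Because the body of `step_fn` is collapsed in the diff, here is an illustrative,
self-contained version of such a training step; the stand-in model, optimizer,
loss function, and fake data are assumptions, not the original example's code:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)          # stand-in model
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def step_fn(model, data, target, loss_fn, optimizer):
    # Forward pass, loss computation, backward pass, and optimizer update are
    # traced and compiled together as one graph.
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)

# One training step on fake data.
data = torch.randn(64, 128).to(device)
target = torch.randint(0, 10, (64,)).to(device)
loss = step_fn(model, data, target, loss_fn, optimizer)
```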

- `torch.compile` is PyTorch's general-purpose compilation API designed to
accelerate PyTorch models across various backends. For PyTorch/XLA, it uses the
`openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
because it lowers tracing overhead, leading to more efficient static inference
graphs. To use it with XLA, simply specify `backend="openxla"`.

``` python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode,
benchmarks were conducted under specific conditions. The benchmarks utilized a
2-layer decoder-only model, similar to Llama2, trained with fake data. The
training process spanned 300 steps on a single chip of a v4-8 TPU. The observed
performance, measured in tokens per second, clearly illustrates the impact of
different execution modes:

| Mode | token/s |
|-----------------------------|---------|
| Tracing mode (base line) | 147 |
| Eager mode | 65 |
| Eager + torch_xla compile | 147 |

Eager mode with `torch_xla.compile` matches the performance of traditional
LazyTensor tracing mode at `147` tokens/s, demonstrating a better user
experience without performance loss.

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for decoder-only models. However, for ResNet50,
pure eager mode was significantly slower (about 1% of compiled mode). For more
information, see [train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).
This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks like data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput.
**Reviewer comment (Collaborator):** We also need to call out that torch_xla.compile
is independently useful, even not in eager mode. That's why torchprime does:
https://github.com/AI-Hypercomputer/torchprime/blob/31c450e82c6273f50f9815351f6fbebb42903f58/torchprime/torch_xla_models/train.py#L330

Wrapping the training loop in torch_xla.compile provides a few benefits:

- There's no `mark_step` anywhere
- Dataloading operations don't leak into the training loop graph. This benefit is
similar to if eager mode was turned on. The only difference is that the dataloading
operations are captured into a separate graph as opposed to running eagerly.
- `torch_xla.compile(full_graph=True)` will catch accidental graph breaks

**Reviewer comment (Collaborator):** Since torch_xla.compile isn't really tied to
eager mode, I think it could be clearer to rename the document heading as such.
Alternatively, we could move the contents about torch_xla.compile to a separate
compile.md, and then talk about its interaction with eager mode in this markdown.