diff --git a/docs/source/learn/eager.md b/docs/source/learn/eager.md
index 4835f86e3e4..2174b2f9b50 100644
--- a/docs/source/learn/eager.md
+++ b/docs/source/learn/eager.md
@@ -1,14 +1,14 @@
-# Eager Mode + Compile API
+# PyTorch/XLA Compilation: Eager Mode vs. Compile API

-In this doc we will go over how to use PyTorch/XLA's new experimental
-`eager` mode with the `compile` API. The goal is to make PyTorch/XLA
-experience more aligned with the native PyTorch and make development
-process easier.
+## Overview

-Currently PyTorch/XLA runs on the LazyTensor tracing mode by default. In
-the following code
+PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
+workloads across various hardware accelerators. Currently, PyTorch/XLA uses the
+LazyTensor tracing mode by default, where operations are recorded into a
+computation graph for deferred compilation and execution (triggered by
+`torch_xla.sync()`), as shown in the following code:

-``` python
+```python
 import torch
 import torch_xla
 import torchvision
@@ -24,24 +24,40 @@ res = model(input)
 torch_xla.sync()
 ```

-The actual model compilation and device execution happens when
-`torch_xla.sync` is called. There are multiple drawback of this
-approach.
+While this approach enables performance optimizations, it introduces significant
+usability challenges.

-1. Users are often confused about when the framework is tracing and
-   when the framework is executing.
-2. Non-core model code(data preprocessing for example) often generates
-   some small pending execution that gets leaked into the main
-   graph(step function) and causes recompilation. The recompilation of
-   the whole graph is usually very expensive.
-3. It is hard to debug when/why recompilation happens.
+## Challenges with LazyTensor Mode

-To mitigate above issues we want to introduce the new UX with eager and
-compile.
+- **Ambiguity**: Developers struggle to distinguish between tracing and
+execution phases, complicating development and debugging.

-## Basic Usage
+- **Recompilation Overhead**: Non-core operations (e.g., data preprocessing) can
+leak into the graph, triggering expensive recompilations.

-``` python
+- **Debugging Difficulty**: Identifying the cause of recompilations is
+challenging due to the opaque nature of graph-building processes.
+
+## Eager Mode and `torch_xla.compile`
+
+To address these issues, PyTorch/XLA introduces an experimental eager mode
+(enabled via `torch_xla.experimental.eager_mode(True)`) and the
+`torch_xla.compile` API. **This shift aligns PyTorch/XLA more closely with
+native PyTorch, prioritizing developer experience while preserving
+performance**. Eager mode is likely to become the default in future releases.
+
+- **Eager Mode**: Executes operations immediately, enhancing flexibility and
+debugging but at a performance cost for core tasks.
+
+- **`torch_xla.compile`**: A decorator or wrapper that explicitly marks code
+(e.g., a model or function) for XLA compilation within an eager context,
+providing clear boundaries and immediate feedback.
+
+## How `torch_xla.compile` works
+
+Let's look at a basic usage of `torch_xla.compile`:
+
+```python
 import torch
 import torch_xla
 import torchvision
@@ -60,31 +76,33 @@ input = torch.randn(64, 3, 224, 224).to(device)
 res = compiled_model(input)
 ```

-Note that
+where the implementation of `torch_xla.compile` can be summarized as follows:

-1. Currently user has to manually enable the eager mode by
-   `torch_xla.experimental.eager_mode(True)`.
-2. The region of the code that wants to be compiled should be wrapped
-   by `torch_xla.compile`.
+1. **Disables Eager Mode**: Temporarily switches to tracing to build a
+   computation graph.

-The implementation of the `torch_xla.compile` is actually pretty
-straight forward, it disables the eager mode when entering the target
-function and start tracing. It will call the `torch_xla.sync()` when
-target function returns and reenable the eager mode. You can expect the
-same perfomrance by using the `eager` + `compile` API compared to the
-existing `mark_step/sync` approach.
+2. **Traces Operations**: Records operations for XLA optimization.

-### Inference
+3. **Compiles and Executes**: Triggers compilation and execution via an
+   internal `torch_xla.sync()` call.

-``` python
-torch_xla.experimental.eager_mode(True)
-compiled_model = torch.compile(model, backend="openxla")
-```
+4. **Re-enables Eager Mode**: Resumes eager execution after compilation.
+
+This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
+balancing flexibility and performance, as sketched below.
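+
+Conceptually, the wrapper behaves like the illustrative sketch below. This is
+not the actual implementation of `torch_xla.compile`; it is a minimal sketch
+that assumes `torch_xla.experimental.eager_mode(False)` switches tracing back on
+and that `torch_xla.sync()` compiles and runs the recorded graph, as described
+above.
+
+```python
+import torch
+import torch_xla
+
+torch_xla.experimental.eager_mode(True)
+
+def sketch_compile(fn):
+    """Illustrative stand-in for torch_xla.compile (not the real implementation)."""
+    def wrapper(*args, **kwargs):
+        torch_xla.experimental.eager_mode(False)     # 1. switch back to tracing
+        try:
+            result = fn(*args, **kwargs)             # 2. operations are recorded, not run
+            torch_xla.sync()                         # 3. compile and execute the traced graph
+        finally:
+            torch_xla.experimental.eager_mode(True)  # 4. resume eager execution
+        return result
+    return wrapper
+
+# Usage mirrors torch_xla.compile: wrap the hot function, then call it normally.
+device = torch_xla.device()
+linear = torch.nn.Linear(8, 2).to(device)
+step = sketch_compile(lambda x: linear(x).sum())
+out = step(torch.randn(4, 8).to(device))
+```
+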
-It is recommened to use the `torch.compile` instead of
-`torch_xla.compile` for inference to reduce the tracing overhad.
+## `torch_xla.compile` vs. `torch.compile`

-### Training
+The PyTorch ecosystem offers multiple compilation APIs, and understanding their
+distinct roles, especially within PyTorch/XLA, is crucial for optimal
+performance and development.
+
+- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
+to work efficiently with the XLA backend for iterative training, it's the
+recommended API for compiling training loops due to its observed performance
+advantages. The best practice is to enclose the complete training step, e.g.,
+forward pass, loss calculation, backward pass, and optimizer step, within a
+`step_fn` and then compile this function.

 ``` python
 torch_xla.experimental.eager_mode(True)
@@ -100,33 +118,45 @@ def step_fn(model, data, target, loss_fn, optimizer):
 step_fn = torch_xla.compile(step_fn)
 ```

-In training we asked user to refactor the `step_fn` out because it is
-usually better to compile the model's forward, backward and optimizer
-together. The long term goal is to also use `torch.compile` for training
-but right now we recommend user to use `torch_xla.compile`(for
-perfomrance reason).
+- `torch.compile` is PyTorch's general-purpose compilation API designed to
+accelerate PyTorch models across various backends. For PyTorch/XLA, it uses the
+`openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
+because it lowers tracing overhead, leading to more efficient static inference
+graphs. To use it with XLA, simply specify `backend="openxla"`.
+
+``` python
+torch_xla.experimental.eager_mode(True)
+compiled_model = torch.compile(model, backend="openxla")
+```
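+
+As a usage sketch (mirroring the `model` and `input` shapes from the earlier
+examples), inference with the `openxla` backend then looks like regular
+PyTorch; the first call triggers compilation, and subsequent calls with the
+same input shapes reuse the compiled graph:
+
+```python
+import torch
+import torch_xla
+import torchvision
+
+torch_xla.experimental.eager_mode(True)
+device = torch_xla.device()
+
+model = torchvision.models.resnet18().to(device).eval()
+compiled_model = torch.compile(model, backend="openxla")
+
+input = torch.randn(64, 3, 224, 224).to(device)
+with torch.no_grad():
+    res = compiled_model(input)  # compiled inference step
+```
+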
+The long-term aim is for `torch.compile` to be the single compilation API for
+both training and inference on XLA.

-## Benchmark
+## Performance Benchmarks

-I run a 2 layer decoder only model training(it is pretty much just a
-llama2) with fake data on a single chip of v4-8 for 300 steps. Below is
-the number I observed.
+To quantify the performance impact of `torch_xla.compile` and eager mode,
+benchmarks were conducted using a 2-layer decoder-only model, similar to Llama2,
+trained with fake data for 300 steps on a single chip of a v4-8 TPU. The
+observed performance, measured in tokens per second, illustrates the impact of
+the different execution modes:

-  Mode                        token/s
-  --------------------------- ---------
-  Tracing mode (base line)    147
-  Eager mode                  65
-  Eager + torch_xla compile   147
+| Mode                        | token/s |
+|-----------------------------|---------|
+| Tracing mode (baseline)     | 147     |
+| Eager mode                  | 65      |
+| Eager + torch_xla compile   | 147     |

-  : Eager mode benchmarks
+Eager mode with `torch_xla.compile` matches the performance of traditional
+LazyTensor tracing mode at 147 tokens/s, demonstrating a better user
+experience without performance loss.

-Eager mode can achieve ~45% performance of the fully compiled model for
-the decoder only model. For more information, see
-[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
+Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
+compiled model's performance for decoder-only models. However, for ResNet50,
+pure eager mode was significantly slower (about 1% of compiled-mode performance).
+For more information, see
+[train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
 and [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).

-Note that perfomrane of the eager mode is very model dependent. When I
-tried to run the resnet50, the eager mode perfomrance is \~1% of the
-compiled mode. We don't exepct user to use eager mode to execute the
-main training loop. Eager mode is meant to be used to handle non-core
-part of the training/inference logic(Data preprocessing, random number
-generations etc) or debug.
+This varying overhead means pure eager mode is not intended for main training or
+inference loops. Its utility lies in non-core tasks like data preprocessing,
+random number generation, custom utilities, or debugging, where immediate
+execution is prioritized over throughput.
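+
+For example, a common pattern keeps lightweight, non-core work eager while only
+the core step is compiled. The sketch below is illustrative rather than
+canonical: the normalization in `preprocess`, the batch shapes, and the small
+ResNet-18 setup are arbitrary choices, but the split between eager helpers and a
+`torch_xla.compile`-wrapped `step_fn` follows the training example above.
+
+```python
+import torch
+import torch_xla
+import torchvision
+
+torch_xla.experimental.eager_mode(True)
+device = torch_xla.device()
+
+model = torchvision.models.resnet18().to(device)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+loss_fn = torch.nn.CrossEntropyLoss()
+
+def preprocess(images):
+    # Non-core work runs eagerly: it executes immediately, is easy to debug, and
+    # never leaks small pending operations into the compiled step graph.
+    return (images - images.mean()) / (images.std() + 1e-6)
+
+def step_fn(data, target):
+    optimizer.zero_grad()
+    loss = loss_fn(model(data), target)
+    loss.backward()
+    optimizer.step()
+    return loss
+
+step_fn = torch_xla.compile(step_fn)  # only the core step is traced and compiled
+
+for _ in range(3):
+    data = preprocess(torch.randn(8, 3, 224, 224).to(device))  # eager
+    target = torch.randint(0, 1000, (8,)).to(device)           # eager
+    loss = step_fn(data, target)                                # compiled step
+```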