Update README, fix pretrain config
kvablack committed Dec 13, 2023
1 parent 116cf70 commit 4981e35
Showing 4 changed files with 35 additions and 44 deletions.
70 changes: 29 additions & 41 deletions README.md
@@ -3,14 +3,14 @@
![](https://github.com/rail-berkeley/octo/workflows/run-debug/badge.svg)
![](https://github.com/rail-berkeley/octo/workflows/pre-commit/badge.svg)

This repo contains code for training and finetuning Octo generalist robotic models (GRMs).
Octo models are transformer-based diffusion policies, trained on a diverse mix of >1M robot trajectories.
This repo contains code for training and finetuning Octo generalist robotic policies (GRPs).
Octo models are transformer-based diffusion policies, trained on a diverse mix of 800k robot trajectories.

![Octo model](docs/assets/teaser.png)

Out of the box, Octo supports multiple RGB camera inputs, can control various robot arms,
and can be instructed via language commands or goal images.
Octo uses a modular attention structure in its transformer backbone, allowing it to be effectively fine-tuned
Octo uses a modular attention structure in its transformer backbone, allowing it to be effectively finetuned
to robot setups with new sensory inputs, action spaces, and morphologies, using only a small target domain
dataset and accessible compute budgets.

@@ -28,19 +28,19 @@ pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.co
```

For TPU
```
```bash
pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
See the [JAX GitHub page](https://github.com/google/jax) for more details on installing JAX.
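
As a quick sanity check that JAX can see your accelerator, you can run a one-liner (an illustrative snippet, not from the repo):
```python
# List the devices JAX detected; expect a GPU or TPU entry, not just CPU.
import jax

print(jax.devices())
```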

Test the installation by fine-tuning on the debug dataset:
Test the installation by finetuning on the debug dataset:
```bash
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
```

## Checkpoints

You can find pre-trained Octo checkpoints [here](https://huggingface.co/rail-berkeley).
You can find pretrained Octo checkpoints [here](https://huggingface.co/rail-berkeley).
At the moment we provide the following model versions:

| Model | Inference on 1x NVIDIA 4090 | Size |
@@ -51,54 +51,53 @@ At the moment we provide the following model versions:

## Examples

We provide simple [example scripts](examples) that demonstrate how to inference and finetune Octo models,
We provide simple [example scripts](examples) that demonstrate how to use and finetune Octo models,
as well as how to use our data loader independently. We provide the following examples:

| | |
|-------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------|
| [Octo Inference](examples/01_inference_pretrained.ipynb) | Minimal example for loading and inferencing a pre-trained Octo model |
| [Octo Finetuning](examples/02_finetune_new_observation_action.py) | Minimal example for finetuning a pre-trained Octo models on a small dataset with new observation + action space |
| [Octo Rollout](examples/03_eval_finetuned.py) | Run a rollout of a pre-trained Octo policy in a Gym environment |
| [Octo Robot Eval](examples/04_eval_finetuned_on_robot.py) | Evaluate a pre-trained Octo model on a real WidowX robot |
| [Octo Inference](examples/01_inference_pretrained.ipynb) | Minimal example for loading and running a pretrained Octo model |
| [Octo Finetuning](examples/02_finetune_new_observation_action.py) | Minimal example for finetuning a pretrained Octo model on a small dataset with a new observation and action space |
| [Octo Rollout](examples/03_eval_finetuned.py) | Run a rollout of a pretrained Octo policy in a Gym environment |
| [Octo Robot Eval](examples/04_eval_finetuned_on_robot.py) | Evaluate a pretrained Octo model on a real WidowX robot |
| [OpenX Dataloader Intro](examples/05_dataloading.ipynb) | Walkthrough of the features of our Open X-Embodiment data loader |


## Octo Pre-Training
## Octo Pretraining

To reproduce our Octo pre-training on >1M robot trajectories, run:
```
python scripts/train.py --config scripts/configs/config.py:vit_s --name=octo --config.dataset_kwargs.oxe_kwargs.data_dir=... --config.dataset_kwargs.oxe_kwargs.data_mix=oxe_magic_soup ...
To reproduce our Octo pretraining on 800k robot trajectories, run:
```bash
python scripts/train.py --config scripts/configs/octo_pretrain_config.py:<size> --name=octo --config.dataset_kwargs.oxe_kwargs.data_dir=... --config.dataset_kwargs.oxe_kwargs.data_mix=oxe_magic_soup ...
```
You can modify hyperparameters like dataset, batch size, etc. in [config.py](scripts/configs/config.py).

To download the pre-training dataset from the [Open X-Embodiment Dataset](https://robotics-transformer-x.github.io/),
To download the pretraining dataset from the [Open X-Embodiment Dataset](https://robotics-transformer-x.github.io/),
install the [rlds_dataset_mod package](https://github.com/kpertsch/rlds_dataset_mod)
and run the [prepare_open_x.sh script](https://github.com/kpertsch/rlds_dataset_mod/blob/main/prepare_open_x.sh).
The total size of the pre-processed dataset is ~1.2TB.

We run pre-training using a TPUv4-128 pod in 8 hours for the Octo-S model and in 14 hours for Octo-B.
We run pretraining using a TPUv4-128 pod in 8 hours for the Octo-S model and in 14 hours for Octo-B.


## Octo Finetuning

We provide a [minimal example](examples/02_finetune_new_observation_action.py) for finetuning with new observations and action space.
We provide a [minimal example](examples/02_finetune_new_observation_action.py) for finetuning with a new observation and action space.

We also provide a more advanced finetuning script that allows to change hyperparameters via a config and logs finetuning
We also provide a more advanced finetuning script that allows you to change hyperparameters via a config file and logs finetuning
metrics. To run advanced finetuning, use:
```
```bash
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
```

We offer three finetuning modes depending on the parts of the model that are kept frozen: ```head_only```, ```head_mlp_only``` and ```full``` to finetune the full model.
Besides, one can specify the task type to finetune with ```image_conditioned```, ```language_conditioned``` or ```multimodal``` for both.
We offer three finetuning modes depending on the parts of the model that are kept frozen: ```head_only```, ```head_mlp_only```, and ```full``` to finetune the full model.
Additionally, one can specify the task type to finetune with: ```image_conditioned```, ```language_conditioned```, or ```multimodal``` for both.
For example, to finetune the full transformer with image inputs only, use:
```--config=finetune_config.py:full,image_conditioned```
```--config=finetune_config.py:full,image_conditioned```.
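
As a rough mental model, the config string appears to be a comma-separated `mode,task` pair, judging from the `get_config(config_string="full,multimodal")` signature visible in the `finetune_config.py` diff below; a hypothetical sketch of the split:
```python
# Hypothetical parsing of the finetuning config string; the real logic
# lives in scripts/configs/finetune_config.py:get_config.
mode, task = "full,image_conditioned".split(",")
assert mode in ("head_only", "head_mlp_only", "full")
assert task in ("image_conditioned", "language_conditioned", "multimodal")
```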


## Octo Evaluation

Loading and inferencing a trained Octo model is as easy as:
```
Loading and running a trained Octo model is as easy as:
```python
from octo.model import OctoModel

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")
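# Illustrative continuation, not part of this diff: the loaded model exposes
# task creation and action sampling. Treat the calls below as an assumption
# about the API and check octo/model/octo_model.py for the real signatures.
# task = model.create_tasks(texts=["pick up the spoon"])
# action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))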
@@ -118,21 +117,10 @@ To evaluate on your own environment, simply wrap it in a Gym interface and follo
| | File | Description |
|---------------------|---------------------------------------------------------|-------------------------------------------------------------------------------|
| Hyperparameters | [config.py](scripts/configs/config.py) | Defines all hyperparameters for the training run. |
| Training Loop | [train.py](scripts/train.py) | Main training script. |
| Finetuning Script | [finetune.py](scripts/finetune.py) | Main finetuning script. |
| Pretraining Loop | [train.py](scripts/train.py) | Main pretraining script. |
| Finetuning Loop | [finetune.py](scripts/finetune.py) | Main finetuning script. |
| Datasets | [dataset.py](octo/data/dataset.py) | Functions for creating single / interleaved datasets + data augmentation. |
| Tokenizers | [tokenizers.py](octo/model/components/tokenizers.py) | Tokenizers that encode image / text inputs into tokens. |
| Octo Model | [octo_model.py](octo/model/octo_model.py) | Main entrypoint for interacting with Octo models, loading, saving, inference. |
| Octo Model | [octo_model.py](octo/model/octo_model.py) | Main entrypoint for interacting with Octo models: loading, saving, and inference. |
| Model Architecture | [octo_module.py](octo/model/octo_module.py) | Combines token sequencing, transformer backbone and readout heads. |
| Visualization | [visualization_lib.py](octo/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. |


## Contributing
Experimental things and training/eval scripts should go in `experiments/<your_name>`. To make any changes to files outside of your experiments directory, please open a pull request.

Steps to contribute:
1. Fork the repo and create your branch from `master`.
2. Use `pre-commit` to enable code checks and auto-formatting.
3. Test that a basic training run starts with the debug dataset:
```bash
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
```
1 change: 0 additions & 1 deletion scripts/configs/config.py
@@ -122,7 +122,6 @@ def get_dataset_config(window_size=1):
        # oxe_kwargs will generate dataset_kwargs_list and sampling weights
        "oxe_kwargs": dict(
            data_mix=placeholder(str),
            # for v4 TPUs: "gs://rail-octo-central2/resize_336_336"
            data_dir=placeholder(str),
            load_camera_views=("primary", "wrist"),
            load_depth=False,
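
For context, `placeholder(str)` is the `ml_collections` idiom for a typed field that must be filled in later, e.g. by the `--config.dataset_kwargs.oxe_kwargs.data_dir=...` override in the pretraining command above. A minimal sketch (the path is hypothetical):
```python
from ml_collections import config_dict

cfg = config_dict.ConfigDict({"data_dir": config_dict.placeholder(str)})
print(cfg.data_dir)         # None until the placeholder is overridden
cfg.data_dir = "/data/oxe"  # hypothetical value; normally set via the CLI flag
```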
1 change: 0 additions & 1 deletion scripts/configs/finetune_config.py
@@ -15,7 +15,6 @@ def get_config(config_string="full,multimodal"):

    FINETUNING_KWARGS = {
        "name": "bridge_dataset",
        # On v4, this might be "gs://rail-octo-central2/resize_256_256"
        "data_dir": "./tests/debug_dataset",
        "image_obs_keys": {"primary": "image_0", "wrist": None},
        "state_obs_keys": ["state", None],
7 changes: 6 additions & 1 deletion scripts/configs/octo_pretrain_config.py
@@ -1,7 +1,12 @@
from copy import deepcopy
import imp
import os

from ml_collections import ConfigDict
from scripts.configs.config import get_config as get_base_config

get_base_config = imp.load_source(
    "config", os.path.join(os.path.dirname(__file__), "config.py")
).get_config

from octo.data.utils.text_processing import HFTokenizer
from octo.model.components.action_heads import DiffusionActionHead
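
One aside on the new import pattern (not part of the commit): `imp` has been deprecated since Python 3.4 and was removed in 3.12, so an `importlib`-based equivalent of the `imp.load_source` call above may eventually be needed; a sketch:
```python
import importlib.util
import os

# Load scripts/configs/config.py by file path, mirroring imp.load_source.
spec = importlib.util.spec_from_file_location(
    "config", os.path.join(os.path.dirname(__file__), "config.py")
)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
get_base_config = config_module.get_config
```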