Skip to content

Commit

Permalink
Merge branch 'main' into lang_paraphrases
Browse files Browse the repository at this point in the history
  • Loading branch information
mees committed May 8, 2024
2 parents 35cefca + 89045cc commit 1ceea1f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 4 deletions.
20 changes: 17 additions & 3 deletions octo/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def apply_frame_transforms(
image_augment_kwargs: Union[dict, Mapping[str, dict]] = {},
resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]] = {},
depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]] = {},
image_dropout_prob: float = 0.0,
image_dropout_keep_key: Optional[str] = None,
num_parallel_calls: int = tf.data.AUTOTUNE,
) -> dl.DLataset:
"""Applies common transforms that happen at a frame level. These transforms are usually more
Expand All @@ -159,6 +161,10 @@ def apply_frame_transforms(
keys (so pass an empty dict to skip resizing for all images).
depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth
images.
image_dropout_prob (float): Probability of dropping out images, applied to each image key
independently. At least one image will always be present.
image_dropout_keep_key (str, optional): Optionally provide a key to always keep during image dropout,
for example, for image observations that are essential for action prediction.
num_parallel_calls (int): number of parallel calls for frame_map operations. Default to AUTOTUNE.
"""

Expand Down Expand Up @@ -186,14 +192,22 @@ def apply_obs_transform(fn: Callable[[dict], dict], frame: dict) -> dict:

if train:
# augment all images with the same seed, skipping padding images
def aug(frame: dict):
def aug_and_dropout(frame: dict):
seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
dropout_fn = partial(
obs_transforms.image_dropout,
seed=seed,
dropout_prob=image_dropout_prob,
always_keep_key=image_dropout_keep_key,
)
aug_fn = partial(
obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs
)
return apply_obs_transform(aug_fn, frame)
frame = apply_obs_transform(dropout_fn, frame)
frame = apply_obs_transform(aug_fn, frame)
return frame

dataset = dataset.frame_map(aug, num_parallel_calls)
dataset = dataset.frame_map(aug_and_dropout, num_parallel_calls)

return dataset

Expand Down
57 changes: 56 additions & 1 deletion octo/data/obs_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Contains observation-level transforms used in the octo data pipeline. These transforms operate on the
"observation" dictionary, and are applied at a per-frame level.
"""
from typing import Mapping, Tuple, Union
from typing import Mapping, Optional, Tuple, Union

from absl import logging
import dlimp as dl
Expand Down Expand Up @@ -39,6 +39,61 @@ def augment(
return obs


def image_dropout(
    obs: dict,
    seed: tf.Tensor,
    dropout_prob: float,
    always_keep_key: Optional[str] = None,
) -> dict:
    """Independently drops out image keys, each with probability `dropout_prob`, while exempting one
    image from the random dropout.

    A dropped image is replaced by zeros and its entry in `obs["pad_mask_dict"]` is set to False.
    Images that are already marked as padding stay padding. `obs` is modified in place and returned.

    Args:
        obs (dict): Observation dictionary. Every key starting with "image_" is treated as an image,
            and `obs["pad_mask_dict"][key]` is read as that image's boolean "is present" mask.
        seed (tf.Tensor): Stateless RNG seed; it is split into one seed for choosing the exempt image
            and one for the dropout draws.
        dropout_prob (float): Probability of dropping each non-exempt image.
        always_keep_key (str, optional): Image key that is never randomly dropped. If not given, a
            random non-padding image is exempted instead (index 0 if all images are padding).

    Returns:
        dict: The same `obs` object with updated images and `pad_mask_dict` entries.

    Raises:
        AssertionError: If `always_keep_key` is given but is not one of the image keys in `obs`.
    """
    image_keys = [key for key in obs if key.startswith("image_")]
    if not image_keys:
        return obs
    pad_mask = tf.stack([obs["pad_mask_dict"][key] for key in image_keys])
    # use separate seeds for picking the exempt image and for the per-image dropout draws
    shuffle_seed, seed = tf.unstack(tf.random.split(seed))

    if always_keep_key:
        assert (
            always_keep_key in image_keys
        ), f"Specified always_keep_key {always_keep_key} not present in image_keys: {image_keys} during dropout."
        always_keep_index = tf.constant(
            image_keys.index(always_keep_key), dtype=tf.int64
        )
    else:
        # if any non-padding images exist, pick one of them to keep no matter what
        always_keep_index = tf.cond(
            tf.reduce_any(pad_mask),
            # pick a random index from the non-padding images
            lambda: tf.random.experimental.stateless_shuffle(
                tf.where(pad_mask)[:, 0], seed=shuffle_seed
            )[0],
            # all images are padding, so it doesn't matter
            lambda: tf.constant(0, dtype=tf.int64),
        )

    # drop images independently, except for the one at always_keep_index
    rands = tf.random.stateless_uniform([len(image_keys)], seed=seed)
    pad_mask = tf.logical_and(
        pad_mask,
        tf.logical_or(
            tf.range(len(image_keys), dtype=tf.int64) == always_keep_index,
            # `rands` lies in [0, 1), so `>=` keeps an image with probability exactly
            # 1 - dropout_prob; in particular nothing is ever dropped when dropout_prob == 0.0
            rands >= dropout_prob,
        ),
    )

    # perform the dropout and update pad_mask_dict
    for i, key in enumerate(image_keys):
        image = obs[key]
        obs["pad_mask_dict"][key] = pad_mask[i]
        # bind `image` as a default argument so each branch refers to this iteration's tensor
        obs[key] = tf.cond(
            pad_mask[i],
            lambda image=image: image,
            lambda image=image: tf.zeros_like(image),
        )
    return obs


def decode_and_resize(
obs: dict,
resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]],
Expand Down
1 change: 1 addition & 0 deletions scripts/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def get_dataset_config(window_size=1):
),
"frame_transform_kwargs": dict(
resize_size=(256, 256),
image_dropout_prob=0.0,
image_augment_kwargs=dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.2],
Expand Down
3 changes: 3 additions & 0 deletions scripts/configs/octo_pretrain_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def get_config(config_string=None):
rephrase_prob=0.5,
),
),
frame_transform_kwargs=dict(
image_dropout_prob=0.5,
),
batch_size=128,
shuffle_buffer_size=500000,
balance_weights=True,
Expand Down

0 comments on commit 1ceea1f

Please sign in to comment.