add adam training (#32)
* add adam training
* fix tests for new decoding api
hamishivi authored Aug 5, 2022
1 parent b963f7d commit f599682
Showing 14 changed files with 149 additions and 33 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/ci.yml
@@ -79,10 +79,7 @@ jobs:
test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
. .venv/bin/activate
pip install -e .[dev]
cd ..
git clone --branch=main https://github.com/google-research/t5x # TODO: pin to specific commit.
cd t5x
python3 -m pip install -e . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install git+https://github.com/google-research/t5x.git@main#egg=t5x
# install promptsource, and fix the seqio dependency
pip install promptsource
pip uninstall -y t5
24 changes: 24 additions & 0 deletions gins/partial_train_adafactor.gin
@@ -0,0 +1,24 @@
# T5.1.1 Base model.
from __gin__ import dynamic_registration

from t5x import adafactor
from t5x import optimizers
from hyper_task_descriptions import utils as hyper_utils

# gin that allows partial training based on regex matching.

# ------------------- Partial loading ------------------------------------------------
OPTIMIZER = @optimizers.MultiOptimizer()
# note you can add more traversals if you want different optimizer settings
# for different parts of the model.
# See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping
# for how to create multiple specialised instances of the same class.
optimizers.MultiOptimizer:
traversals_and_optimizers = ((@optim.ModelParamTraversal(),
@adafactor.Adafactor()),)
optim.ModelParamTraversal:
filter_fn = @hyper_utils.match_any()
# MultiOptimizer will match any parameter with a flattened name that
# matches *any* of the regular expressions in the list.
PROMPT_REGEX = [".*/hyper/.*"]
hyper_utils.match_any.regexes = %PROMPT_REGEX
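
# For reference, a minimal Python sketch (not part of this diff, assuming the legacy
# flax.optim traversal API that this gin depends on) of what the config above builds:
# only parameters whose flattened name matches the regex are handed to Adafactor, so
# everything else receives no updates.
from flax import optim  # legacy optimizer traversal API
from t5x import adafactor, optimizers
from hyper_task_descriptions import utils as hyper_utils

# Traverse only parameters whose flattened path matches ".*/hyper/.*".
traversal = optim.ModelParamTraversal(filter_fn=hyper_utils.match_any([".*/hyper/.*"]))
# Pair that traversal with Adafactor; parameters outside it stay frozen.
optimizer_def = optimizers.MultiOptimizer(
    traversals_and_optimizers=((traversal, adafactor.Adafactor()),)
)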
38 changes: 38 additions & 0 deletions gins/partial_train_adam.gin
@@ -0,0 +1,38 @@
from __gin__ import dynamic_registration

import optax
from t5x import utils

from hyper_task_descriptions import utils as hyper_utils


# multi optimizer - anything matching param_labels is trained with Adam; other params don't train.
# note we use optax's way of doing things here - the t5x MultiOptimizer didn't work for some
# reason.
OPTIMIZER = @hyper_utils.multi_transform()
hyper_utils.multi_transform:
transforms = {"train": @optax.adam(), "freeze": @optax.set_to_zero()}
param_labels = @hyper_utils.match_any_optax()

# we only train params that match this regex
hyper_utils.match_any_optax.regexes = [".*hyper.*"]

optax.adam:
learning_rate = @utils.create_learning_rate_scheduler()
# adamw params below. See https://optax.readthedocs.io/en/latest/api.html#optax.adamw
# weight_decay = 0
# mask = @hyper_utils.match_any_optax_inverse()

# for adamw, a common case is not applying weight decay to layer norms and biases (but t5 has no biases)
#hyper_utils.match_any_optax_inverse.regexes = [".*/LayerNorm/.*"]


# WARNING: t5x will log starting from the pretrained model step,
# but optax calls this starting from 0. So ignore the tensorboard
# learning rate logging.
utils.create_learning_rate_scheduler:
factors = 'constant * linear_warmup'
base_learning_rate = 1e-5
warmup_steps = 1000
step_offset = 0 # our steps start at 0 no matter what with optax.
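
# For intuition, a rough sketch (not the t5x implementation) of the schedule those
# factors produce: a linear warmup to the base learning rate over the first 1000
# steps, then a constant 1e-5 thereafter.
def learning_rate(step, base_lr=1e-5, warmup_steps=1000):
    # 'constant * linear_warmup': ramp linearly up to base_lr, then hold it constant.
    return base_lr * min(1.0, step / warmup_steps)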

20 changes: 2 additions & 18 deletions gins/partial_train.gin → gins/restore_pretrained.gin
@@ -1,26 +1,8 @@
# T5.1.1 Base model.
from __gin__ import dynamic_registration

from flax import optim
from t5x import adafactor
from t5x import optimizers
from t5x import utils
from hyper_task_descriptions import utils as hyper_utils

# gin that allows partial training based on regex matching.

# ------------------- Partial loading ------------------------------------------------
# OPTIMIZER = @optimizers.MultiOptimizer()
# optimizers.MultiOptimizer:
# traversals_and_optimizers = ((@optim.ModelParamTraversal(),
# @adafactor.Adafactor()),)
# optim.ModelParamTraversal:
# filter_fn = @hyper_utils.match_any()
# # MultiOptimizer will match any parameter with a flattened name that
# # matches any of these regular expressions.
# PROMPT_REGEX = [".*/hyper/.*"]
# # PROMPT_REGEX = ["^((?!roberta).)*$"]
# hyper_utils.match_any.regexes = %PROMPT_REGEX

# These settings allow us to partially reload a checkpoint, that is, we can load
# most of the model weights from the checkpoint, without it complaining that we
@@ -39,6 +21,8 @@ utils.RestoreCheckpointConfig:
# the checkpoint.

# We skip hypernetwork parameters
# any parameter matching a regex will not be restored from the checkpoint.
# any non-matching parameter that is missing from the checkpoint will cause an error.
assignment_map = (
(r"^.*hyper.*$", None),
)
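
# As a quick illustration (with hypothetical parameter names, not taken from the model),
# the assignment map above skips any flattened parameter name containing "hyper" and
# restores everything else from the checkpoint.
import re

skip = re.compile(r"^.*hyper.*$")
for name in ["encoder/hyper/mlp/kernel", "decoder/layers_0/self_attention/query/kernel"]:
    print(name, "-> skipped" if skip.fullmatch(name) else "-> restored from checkpoint")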
1 change: 1 addition & 0 deletions gins/t0_train.gin
@@ -11,6 +11,7 @@ import __main__ as train_script

include "t5x/configs/runs/finetune.gin"
include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin`
include "gins/restore_pretrained.gin" # for loading from checkpoints

TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 512, "task_names": 1, "targets": 256}
MIXTURE_OR_TASK_NAME = "t0_train"
7 changes: 4 additions & 3 deletions hyper_task_descriptions/modeling/hyper_transformer.py
@@ -339,15 +339,16 @@ def _compute_logits(

def _compute_logits_from_slice(
self,
flat_ids: jnp.ndarray,
flat_cache: Mapping[str, jnp.ndarray],
decoding_state: decoding.DecodingState,
params: PyTreeDef,
encoded_inputs: jnp.ndarray,
adaptations: Tuple[jnp.ndarray, ...],
raw_inputs: jnp.ndarray,
max_decode_length: int,
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
"""Token slice to logits from decoder model."""
flat_ids = decoding_state.cur_token
flat_cache = decoding_state.cache
# flat_ids: [batch * beam, seq_len=1]
# cache is expanded inside beam_search to become flat_cache
# flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len]
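
# For context, the new t5x decoding API passes a single DecodingState to the per-step
# callback instead of separate token and cache arguments. A minimal sketch of that call
# pattern (placeholder body, not the actual method) looks roughly like:
from t5x import decoding

def tokens_to_logits(decoding_state: decoding.DecodingState):
    flat_ids = decoding_state.cur_token  # [batch * beam, 1], current token slice
    flat_cache = decoding_state.cache    # expanded/flattened autoregressive cache
    # ... the real method runs the decoder on flat_ids using flat_cache here ...
    return flat_ids, flat_cache  # placeholder; the real method returns (logits, new_cache)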
@@ -710,7 +711,7 @@ def loss_fn(
z_loss=self._z_loss,
loss_normalizing_factor=loss_normalizing_factor,
)
loss += cos_loss * cosine_loss_multiplier # upweight since otherwise ce loss dominates
# loss += cos_loss * cosine_loss_multiplier # upweight since otherwise ce loss dominates
metrics = self._compute_metrics(
logits=logits,
targets=batch["decoder_target_tokens"],
50 changes: 49 additions & 1 deletion hyper_task_descriptions/utils.py
@@ -19,7 +19,10 @@
import re
from typing import Any, Callable, Optional, Sequence, Tuple

from t5x import partitioning
import flax
import optax
from flax.core import frozen_dict
from t5x import optimizers, partitioning

PartitionRule = Tuple[str, Optional[partitioning.PartitionSpec]]

@@ -43,3 +46,48 @@ def _match_any(path, _):
return any(regex.fullmatch(path) for regex in regexes)

return _match_any


def inverse_match_any(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
"""Inverse of the above"""
regexes = tuple(re.compile(regex) for regex in regexes)

def _match_any(path, _):
"""False if path matches any regex in regexs, true otherwise."""
return not any(regex.fullmatch(path) for regex in regexes)

return _match_any


def flattened_traversal(fn):
"""Returns function that is called with `(path, param)` instead of pytree."""

def mask(tree):
flat = flax.traverse_util.flatten_dict(tree, sep="/")
masked_tree = flax.traverse_util.unflatten_dict(
{k: fn(k, v) for k, v in flat.items()}, sep="/"
)
return frozen_dict.freeze(masked_tree)

return mask


def match_any_optax(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
regexes = tuple(re.compile(regex) for regex in regexes)
label_fn = flattened_traversal(
lambda path, _: "train" if any(regex.fullmatch(path) for regex in regexes) else "freeze"
)
return label_fn


# inverse match, mainly for adamw weight decay - see partial_train_adam.gin for an example of how this is applied
def match_any_optax_inverse(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
regexes = tuple(re.compile(regex) for regex in regexes)
label_fn = flattened_traversal(
lambda path, _: "freeze" if any(regex.fullmatch(path) for regex in regexes) else "train"
)
return label_fn


# t5x doesn't wrap this, but we need it
multi_transform = optimizers.wrap_optax_optimizer(optax.multi_transform)
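
# For example (a minimal sketch with a hypothetical parameter tree), match_any_optax
# produces the label pytree that optax.multi_transform expects, so parameters matching
# ".*hyper.*" get Adam updates and everything else is frozen.
import optax
from flax.core import frozen_dict
from hyper_task_descriptions import utils as hyper_utils

params = frozen_dict.freeze(
    {"encoder": {"hyper": {"kernel": 1.0}, "attention": {"kernel": 1.0}}}
)
label_fn = hyper_utils.match_any_optax([".*hyper.*"])
# label_fn(params) labels encoder/hyper/kernel as "train" and encoder/attention/kernel as "freeze".
tx = optax.multi_transform(
    {"train": optax.adam(1e-5), "freeze": optax.set_to_zero()},
    param_labels=label_fn,
)
opt_state = tx.init(params)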
4 changes: 2 additions & 2 deletions scripts/t0_eval.sh
@@ -3,9 +3,9 @@ EXPERIMENT_NAME=$1
CHECKPOINT_NAME=$2

# model checkpoint location
MODEL_DIR="gs://hamishi-tpu-bucket/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}"
MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}"
# where to put eval results
EVAL_OUTPUT_DIR="gs://hamishi-tpu-bucket/${EXPERIMENT_NAME}/eval"
EVAL_OUTPUT_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/eval"

# we go offline to avoid constant calls to get basic info (happens even when cached)
# for your first run, you will probably need to run all these calls :(
2 changes: 1 addition & 1 deletion scripts/t0_reg_eval.sh
@@ -3,7 +3,7 @@ MODEL_DIR=$1
SAVE_DIR=$2

# where to put eval results
EVAL_OUTPUT_DIR="gs://hamishi-tpu-bucket/${SAVE_DIR}"
EVAL_OUTPUT_DIR="gs://hamishi-us-bucket/${SAVE_DIR}"

# we go offline to avoid constant calls to get basic info (happens even when cached)
# for your first run, you will probably need to run all these calls :(
2 changes: 1 addition & 1 deletion scripts/t0_reg_train.sh
@@ -15,4 +15,4 @@ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \
--gin.TRAIN_STEPS=1212200 \
--gin.partitioning.PjitPartitioner.num_partitions=8 \
--gin.INITIAL_CHECKPOINT_PATH=\"gs://hamishi-us-bucket/t0_3b_further_train/model/checkpoint_1126000\" \
--tfds_data_dir="gs://hamishi-tpu-bucket/t0_data/data"
--tfds_data_dir="gs://hamishi-us-bucket/t0_data/data"
3 changes: 3 additions & 0 deletions scripts/tpu_setup.sh
@@ -8,6 +8,9 @@ python3 -m pip install promptsource
# I use a new feature in t5.data
python3 -m pip uninstall -y t5
python3 -m pip install git+https://github.com/google-research/text-to-text-transfer-transformer.git
# use a compatible version of optax
python3 -m pip uninstall -y optax
python3 -m pip install optax==0.1.2
# custom fixed seqio
python3 -m pip uninstall -y seqio seqio-nightly
python3 -m pip install git+https://github.com/hamishivi/seqio.git
4 changes: 2 additions & 2 deletions scripts/t0_train.sh → scripts/train_from_t0.sh
@@ -9,10 +9,10 @@ MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model"
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \
--gin_search_paths=gins \
--gin_file="hyper_xl.gin" \
--gin_file="partial_train.gin" \
--gin_file="t0_train.gin" \
--gin_file="partial_train_adam.gin" \
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
--gin.TRAIN_STEPS=1212200 \
--gin.partitioning.PjitPartitioner.num_partitions=8 \
--gin.INITIAL_CHECKPOINT_PATH=\"gs://hamishi-us-bucket/t0_3b_further_train/model/checkpoint_1126000\" \
--tfds_data_dir="gs://hamishi-tpu-bucket/t0_data/data"
--tfds_data_dir="gs://hamishi-us-bucket/t0_data/data"
18 changes: 18 additions & 0 deletions scripts/train_from_t5.sh
@@ -0,0 +1,18 @@
# name of experiment folder
EXPERIMENT_NAME=$1

# where model will be saved
MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model"

# we go offline to avoid constant calls to get basic info (happens even when cached)
# for your first run, you will probably need to run all these calls :(
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \
--gin_search_paths=gins \
--gin_file="hyper_xl.gin" \
--gin_file="t0_train.gin" \
--gin_file="partial_train_adam.gin" \
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
--gin.TRAIN_STEPS=1212200 \
--gin.partitioning.PjitPartitioner.num_partitions=8 \
--gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" \
--tfds_data_dir="gs://hamishi-us-bucket/t0_data/data"
4 changes: 3 additions & 1 deletion tests/modeling/hyper_transformer_test.py
@@ -646,7 +646,9 @@ def mock_init(self):
self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len)
for tokens_call in tokens_to_logits_mock.call_args_list:
# Inputs: [B * Be, 1]
inputs, cache = tokens_call[0]
decoding_state = tokens_call[0][0]
inputs = decoding_state.cur_token
cache = decoding_state.cache
cache = flax.core.unfreeze(cache)
# Cache: [B * Be, 1] * #Layers
cache_keys = [
