diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2ffb5c89..579016c7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -79,10 +79,7 @@ jobs:
         test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
         . .venv/bin/activate
         pip install -e .[dev]
-        cd ..
-        git clone --branch=main https://github.com/google-research/t5x # TODO: pin to specific commit.
-        cd t5x
-        python3 -m pip install -e . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+        pip install git+https://github.com/google-research/t5x.git@main#egg=t5x
         # install promptsource, and fix the seqio dependency
         pip install promptsource
         pip uninstall -y t5
diff --git a/gins/partial_train_adafactor.gin b/gins/partial_train_adafactor.gin
new file mode 100644
index 00000000..ff265fff
--- /dev/null
+++ b/gins/partial_train_adafactor.gin
@@ -0,0 +1,24 @@
+# T5.1.1 Base model.
+from __gin__ import dynamic_registration
+
+from t5x import adafactor
+from t5x import optimizers
+from hyper_task_descriptions import utils as hyper_utils
+
+# gin that allows partial training based on regex matching.
+
+# ------------------- Partial loading ------------------------------------------------
+OPTIMIZER = @optimizers.MultiOptimizer()
+# note you can add more traversals if you want different optimizer settings
+# for different parts of the model.
+# See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping
+# for how to create multiple specialised instances of the same class.
+optimizers.MultiOptimizer:
+  traversals_and_optimizers = ((@optim.ModelParamTraversal(),
+                                @adafactor.Adafactor()),)
+optim.ModelParamTraversal:
+  filter_fn = @hyper_utils.match_any()
+# MultiOptimizer will match any parameter with a flattened name that
+# matches *any* of the regular expressions in the list.
+PROMPT_REGEX = [".*/hyper/.*"]
+hyper_utils.match_any.regexes = %PROMPT_REGEX
diff --git a/gins/partial_train_adam.gin b/gins/partial_train_adam.gin
new file mode 100644
index 00000000..12f50f69
--- /dev/null
+++ b/gins/partial_train_adam.gin
@@ -0,0 +1,38 @@
+from __gin__ import dynamic_registration
+
+import optax
+from t5x import utils
+
+from hyper_task_descriptions import utils as hyper_utils
+
+
+# multi optimizer - we map anything that matches param_labels to adam; everything else doesn't train.
+# note we use optax's way of doing things here - the t5x MultiOptimizer didn't work for some
+# reason.
+OPTIMIZER = @hyper_utils.multi_transform()
+hyper_utils.multi_transform:
+  transforms = {"train": @optax.adam(), "freeze": @optax.set_to_zero()}
+  param_labels = @hyper_utils.match_any_optax()
+
+# we only train params that match this regex
+hyper_utils.match_any_optax.regexes = [".*hyper.*"]
+
+optax.adam:
+  learning_rate = @utils.create_learning_rate_scheduler()
+  # adamw params below. See https://optax.readthedocs.io/en/latest/api.html#optax.adamw
+  # weight_decay = 0
+  # mask = @hyper_utils.match_any_optax_inverse()
+
+# for adamw, a common case is not applying weight decay to layer norms and biases (but there are no biases in t5)
+#hyper_utils.match_any_optax_inverse.regexes = [".*/LayerNorm/.*"]
+
+
+# WARNING: t5x logs steps starting from the pretrained model's step count,
+# but optax counts from 0. So ignore the tensorboard
+# learning rate logging.
+utils.create_learning_rate_scheduler:
+  factors = 'constant * linear_warmup'
+  base_learning_rate = 1e-5
+  warmup_steps = 1000
+  step_offset = 0  # with optax, our steps always start at 0.
+
diff --git a/gins/partial_train.gin b/gins/restore_pretrained.gin
similarity index 59%
rename from gins/partial_train.gin
rename to gins/restore_pretrained.gin
index e76d2f28..fcfa4de2 100644
--- a/gins/partial_train.gin
+++ b/gins/restore_pretrained.gin
@@ -1,26 +1,8 @@
 # T5.1.1 Base model.
 from __gin__ import dynamic_registration
 
-from flax import optim
-from t5x import adafactor
-from t5x import optimizers
 from t5x import utils
-from hyper_task_descriptions import utils as hyper_utils
 
-# gin that allows partial training based on regex matching.
-
-# ------------------- Partial loading ------------------------------------------------
-# OPTIMIZER = @optimizers.MultiOptimizer()
-# optimizers.MultiOptimizer:
-#   traversals_and_optimizers = ((@optim.ModelParamTraversal(),
-#                                 @adafactor.Adafactor()),)
-# optim.ModelParamTraversal:
-#   filter_fn = @hyper_utils.match_any()
-# # MultiOptimizer will match any parameter with a flattened name that
-# # matches any of these regular expressions.
-# PROMPT_REGEX = [".*/hyper/.*"]
-# # PROMPT_REGEX = ["^((?!roberta).)*$"]
-# hyper_utils.match_any.regexes = %PROMPT_REGEX
 
 # These setting allow us to partially reload a checkpoint, that is, we can load
 # most of the model weights from the checkpoint, without it complaining that we
@@ -39,6 +21,8 @@ utils.RestoreCheckpointConfig:
   # the checkpoint.
 
   # We skip hypernetwork parameters
+  # Any parameter matching one of these regexes will not be restored from the checkpoint.
+  # Anything that does not match and is also missing from the checkpoint will cause an error.
   assignment_map = (
     (r"^.*hyper.*$", None),
   )
\ No newline at end of file
diff --git a/gins/t0_train.gin b/gins/t0_train.gin
index 0c22c25a..363d61e3 100644
--- a/gins/t0_train.gin
+++ b/gins/t0_train.gin
@@ -11,6 +11,7 @@ import __main__ as train_script
 include "t5x/configs/runs/finetune.gin"
 include "gins/t0.gin"
 
 # This overrides some default config in `t5x/configs/runs/finetune.gin`
+include "gins/restore_pretrained.gin"  # for loading from checkpoints
 TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 512, "task_names": 1, "targets": 256}
 MIXTURE_OR_TASK_NAME = "t0_train"
diff --git a/hyper_task_descriptions/modeling/hyper_transformer.py b/hyper_task_descriptions/modeling/hyper_transformer.py
index 5bb3baa3..735d3305 100644
--- a/hyper_task_descriptions/modeling/hyper_transformer.py
+++ b/hyper_task_descriptions/modeling/hyper_transformer.py
@@ -339,8 +339,7 @@ def _compute_logits(
 
     def _compute_logits_from_slice(
         self,
-        flat_ids: jnp.ndarray,
-        flat_cache: Mapping[str, jnp.ndarray],
+        decoding_state: decoding.DecodingState,
         params: PyTreeDef,
         encoded_inputs: jnp.ndarray,
         adaptations: Tuple[jnp.ndarray, ...],
@@ -348,6 +347,8 @@
         max_decode_length: int,
     ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
         """Token slice to logits from decoder model."""
+        flat_ids = decoding_state.cur_token
+        flat_cache = decoding_state.cache
         # flat_ids: [batch * beam, seq_len=1]
         # cache is expanded inside beam_search to become flat_cache
         # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len]
@@ -710,7 +711,7 @@ def loss_fn(
             z_loss=self._z_loss,
             loss_normalizing_factor=loss_normalizing_factor,
         )
-        loss += cos_loss * cosine_loss_multiplier  # upweight since otherwise ce loss dominates
+        # loss += cos_loss * cosine_loss_multiplier  # upweight since otherwise ce loss dominates
         metrics = self._compute_metrics(
             logits=logits,
             targets=batch["decoder_target_tokens"],
diff --git a/hyper_task_descriptions/utils.py b/hyper_task_descriptions/utils.py
index 599927a8..b1a3b05a 100644
--- a/hyper_task_descriptions/utils.py
+++ b/hyper_task_descriptions/utils.py
@@ -19,7 +19,10 @@
 import re
 from typing import Any, Callable, Optional, Sequence, Tuple
 
-from t5x import partitioning
+import flax
+import optax
+from flax.core import frozen_dict
+from t5x import optimizers, partitioning
 
 PartitionRule = Tuple[str, Optional[partitioning.PartitionSpec]]
 
@@ -43,3 +46,48 @@ def _match_any(path, _):
         return any(regex.fullmatch(path) for regex in regexes)
 
     return _match_any
+
+
+def inverse_match_any(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
+    """Inverse of `match_any` above."""
+    regexes = tuple(re.compile(regex) for regex in regexes)
+
+    def _match_any(path, _):
+        """False if path matches any regex in regexes, true otherwise."""
+        return not any(regex.fullmatch(path) for regex in regexes)
+
+    return _match_any
+
+
+def flattened_traversal(fn):
+    """Returns function that is called with `(path, param)` instead of pytree."""
+
+    def mask(tree):
+        flat = flax.traverse_util.flatten_dict(tree, sep="/")
+        masked_tree = flax.traverse_util.unflatten_dict(
+            {k: fn(k, v) for k, v in flat.items()}, sep="/"
+        )
+        return frozen_dict.freeze(masked_tree)
+
+    return mask
+
+
+def match_any_optax(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
+    regexes = tuple(re.compile(regex) for regex in regexes)
+    label_fn = flattened_traversal(
+        lambda path, _: "train" if any(regex.fullmatch(path) for regex in regexes) else "freeze"
+    )
+    return label_fn
+
+
+# Inverse match, mainly for adamw weight decay - see partial_train_adam.gin for an example of how this is applied.
+def match_any_optax_inverse(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
+    regexes = tuple(re.compile(regex) for regex in regexes)
+    label_fn = flattened_traversal(
+        lambda path, _: "freeze" if any(regex.fullmatch(path) for regex in regexes) else "train"
+    )
+    return label_fn
+
+
+# t5x doesn't wrap this, but we need it.
+multi_transform = optimizers.wrap_optax_optimizer(optax.multi_transform)
diff --git a/scripts/t0_eval.sh b/scripts/t0_eval.sh
index 6eec02ac..fb49566e 100755
--- a/scripts/t0_eval.sh
+++ b/scripts/t0_eval.sh
@@ -3,9 +3,9 @@ EXPERIMENT_NAME=$1
 CHECKPOINT_NAME=$2
 
 # model checkpoint location
-MODEL_DIR="gs://hamishi-tpu-bucket/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}"
+MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}"
 # where to put eval results
-EVAL_OUTPUT_DIR="gs://hamishi-tpu-bucket/${EXPERIMENT_NAME}/eval"
+EVAL_OUTPUT_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/eval"
 
 # we go offline to avoid constant calls to get basic info (happens even when cached)
 # for your first run, you will probably need to run all these calls :(
diff --git a/scripts/t0_reg_eval.sh b/scripts/t0_reg_eval.sh
index 65acd6db..8b0b3ce6 100755
--- a/scripts/t0_reg_eval.sh
+++ b/scripts/t0_reg_eval.sh
@@ -3,7 +3,7 @@ MODEL_DIR=$1
 SAVE_DIR=$2
 
 # where to put eval results
-EVAL_OUTPUT_DIR="gs://hamishi-tpu-bucket/${SAVE_DIR}"
+EVAL_OUTPUT_DIR="gs://hamishi-us-bucket/${SAVE_DIR}"
 
 # we go offline to avoid constant calls to get basic info (happens even when cached)
 # for your first run, you will probably need to run all these calls :(
diff --git a/scripts/t0_reg_train.sh b/scripts/t0_reg_train.sh
index 3ffdbac5..259846ca 100755
--- a/scripts/t0_reg_train.sh
+++ b/scripts/t0_reg_train.sh
@@ -15,4 +15,4 @@ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \
   --gin.TRAIN_STEPS=1212200 \
   --gin.partitioning.PjitPartitioner.num_partitions=8 \
--gin.INITIAL_CHECKPOINT_PATH=\"gs://hamishi-us-bucket/t0_3b_further_train/model/checkpoint_1126000\" \ - --tfds_data_dir="gs://hamishi-tpu-bucket/t0_data/data" + --tfds_data_dir="gs://hamishi-us-bucket/t0_data/data" diff --git a/scripts/tpu_setup.sh b/scripts/tpu_setup.sh index df50ee43..30784772 100755 --- a/scripts/tpu_setup.sh +++ b/scripts/tpu_setup.sh @@ -8,6 +8,9 @@ python3 -m pip install promptsource # i use a new feature in t5.data python3 -m pip uninstall -y t5 python3 -m pip install git+https://github.com/google-research/text-to-text-transfer-transformer.git +# use a compatible version of optax +python3 -m pip uninstall -y optax +python3 -m pip install optax==0.1.2 # custom fixed seqio python3 -m pip uninstall -y seqio seqio-nightly python3 -m pip install git+https://github.com/hamishivi/seqio.git diff --git a/scripts/t0_train.sh b/scripts/train_from_t0.sh similarity index 87% rename from scripts/t0_train.sh rename to scripts/train_from_t0.sh index 258a0fec..5cd79b17 100755 --- a/scripts/t0_train.sh +++ b/scripts/train_from_t0.sh @@ -9,10 +9,10 @@ MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model" HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ --gin_search_paths=gins \ --gin_file="hyper_xl.gin" \ - --gin_file="partial_train.gin" \ --gin_file="t0_train.gin" \ + --gin_file="partial_train_adam.gin" \ --gin.MODEL_DIR=\"${MODEL_DIR}\" \ --gin.TRAIN_STEPS=1212200 \ --gin.partitioning.PjitPartitioner.num_partitions=8 \ --gin.INITIAL_CHECKPOINT_PATH=\"gs://hamishi-us-bucket/t0_3b_further_train/model/checkpoint_1126000\" \ - --tfds_data_dir="gs://hamishi-tpu-bucket/t0_data/data" + --tfds_data_dir="gs://hamishi-us-bucket/t0_data/data" diff --git a/scripts/train_from_t5.sh b/scripts/train_from_t5.sh new file mode 100755 index 00000000..b124ddfc --- /dev/null +++ b/scripts/train_from_t5.sh @@ -0,0 +1,18 @@ +# name of experiment folder +EXPERIMENT_NAME=$1 + +# where model will be saved +MODEL_DIR="gs://hamishi-us-bucket/${EXPERIMENT_NAME}/model" + +# we go offline to avoid constant calls to get basic info (happens even when cached) +# for your first run, you will probably need to run all these calls :( +HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ + --gin_search_paths=gins \ + --gin_file="hyper_xl.gin" \ + --gin_file="t0_train.gin" \ + --gin_file="partial_train_adam.gin" \ + --gin.MODEL_DIR=\"${MODEL_DIR}\" \ + --gin.TRAIN_STEPS=1212200 \ + --gin.partitioning.PjitPartitioner.num_partitions=8 \ + --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" \ + --tfds_data_dir="gs://hamishi-us-bucket/t0_data/data" diff --git a/tests/modeling/hyper_transformer_test.py b/tests/modeling/hyper_transformer_test.py index 6bdb81b9..e9c055d5 100644 --- a/tests/modeling/hyper_transformer_test.py +++ b/tests/modeling/hyper_transformer_test.py @@ -646,7 +646,9 @@ def mock_init(self): self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len) for tokens_call in tokens_to_logits_mock.call_args_list: # Inputs: [B * Be, 1] - inputs, cache = tokens_call[0] + decoding_state = tokens_call[0][0] + inputs = decoding_state.cur_token + cache = decoding_state.cache cache = flax.core.unfreeze(cache) # Cache: [B * Be, 1] * #Layers cache_keys = [