diff --git a/examples/lm1b/configs/default.py b/examples/lm1b/configs/default.py
index 9405f76bb..8abaabd6d 100644
--- a/examples/lm1b/configs/default.py
+++ b/examples/lm1b/configs/default.py
@@ -30,11 +30,11 @@ def get_config():
   config.max_corpus_chars = 10**7
 
   # Name of TFDS translation dataset to use.
-  config.dataset_name = "lm1b"
+  config.dataset_name = 'lm1b'
 
   # Optional name of TFDS translation dataset to use for evaluation.
-  config.eval_dataset_name = "lm1b"
-  config.eval_split = "test"
+  config.eval_dataset_name = 'lm1b'
+  config.eval_split = 'test'
 
   # Per device batch size for training.
   config.per_device_batch_size = 32
@@ -114,7 +114,44 @@ def get_config():
   # Integer for PRNG random seed.
   config.seed = 0
 
-  # Prompt for language model sampling.
-  config.prompts = "I love to "
+  # Prompt for language model sampling,
+  # taken from MaxText (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml).
+  config.prompts = 'I love to '
+
+  # Parallelism
+  config.mesh_axes = ['data', 'fsdp', 'tensor']
+  config.logical_axis_rules = [
+      ['activation_batch', ['data', 'fsdp']],
+      ['activation_length', ['data', 'fsdp']],
+      ['activation_embed', 'tensor'],
+      ['activation_mlp', 'tensor'],
+      ['activation_heads', 'tensor'],
+      ['activation_kv', 'tensor'],
+      ['activation_vocab', 'tensor'],
+      ['mlp', 'tensor'],
+      ['vocab', 'tensor'],
+      ['embed', 'fsdp'],
+      ['heads', 'tensor'],
+  ]
+  config.full_sharding = ['data', 'fsdp', 'tensor']
+  config.data_sharding = ['data']
+
+  # One axis for each parallelism type may hold a placeholder (-1)
+  # value to auto-shard based on available slices and devices.
+  # By default, product of the DCN axes should equal number of slices
+  # and product of the ICI axes should equal number of devices per slice.
+  # ICI (Inter-Chip Interconnection): A high-speed connection between
+  # sets of TPU chips, which form the TPU network.
+  # DCN (Data Center Network): A connection between the TPU networks;
+  # not as fast as ICI.
+  # ICI has around 100x the bandwidth of DCN, but it is not a general
+  # purpose connection, which is why DCN is necessary for scaling to
+  # extremely large ML models.
+  config.dcn_data_parallelism = -1  # recommended DCN axis to be auto-sharded
+  config.dcn_fsdp_parallelism = 1
+  config.dcn_tensor_parallelism = 1
+  config.ici_data_parallelism = 1
+  config.ici_fsdp_parallelism = -1  # recommended ICI axis to be auto-sharded
+  config.ici_tensor_parallelism = 1
 
   return config
diff --git a/examples/lm1b/models.py b/examples/lm1b/models.py
index 39842f2e7..94ce043bc 100644
--- a/examples/lm1b/models.py
+++ b/examples/lm1b/models.py
@@ -186,8 +186,10 @@ def __call__(self, inputs):
     x = nn.Dense(
         config.mlp_dim,
         dtype=config.dtype,
-        kernel_init=config.kernel_init,
-        bias_init=config.bias_init,
+        kernel_init=nn.with_logical_partitioning(
+            config.kernel_init, ('embed', 'mlp')
+        ),
+        bias_init=nn.with_logical_partitioning(config.bias_init, ('mlp',)),
     )(inputs)
     x = nn.relu(x)
     x = nn.Dropout(rate=config.dropout_rate)(
@@ -196,8 +198,10 @@ def __call__(self, inputs):
     output = nn.Dense(
         actual_out_dim,
         dtype=config.dtype,
-        kernel_init=config.kernel_init,
-        bias_init=config.bias_init,
+        kernel_init=nn.with_logical_partitioning(
+            config.kernel_init, ('mlp', 'embed')
+        ),
+        bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)),
     )(x)
     output = nn.Dropout(rate=config.dropout_rate)(
         output, deterministic=config.deterministic
@@ -230,13 +234,23 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
 
     # Decoder block.
     assert inputs.ndim == 3
-    x = nn.LayerNorm(dtype=config.dtype)(inputs)
+    x = nn.LayerNorm(
+        dtype=config.dtype,
+        bias_init=nn.with_logical_partitioning(
+            nn.initializers.zeros, ('embed',)
+        ),
+        scale_init=nn.with_logical_partitioning(
+            nn.initializers.ones, ('embed',)
+        ),
+    )(inputs)
     x = nn.SelfAttention(
         num_heads=config.num_heads,
         dtype=config.dtype,
         qkv_features=config.qkv_dim,
-        kernel_init=config.kernel_init,
-        bias_init=config.bias_init,
+        kernel_init=nn.with_logical_partitioning(
+            config.kernel_init, ('embed', 'kv')
+        ),
+        bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)),
         use_bias=False,
         broadcast_dropout=False,
         dropout_rate=config.attention_dropout_rate,
@@ -249,7 +263,15 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
     x = x + inputs
 
     # MLP block.
-    z = nn.LayerNorm(dtype=config.dtype)(x)
+    z = nn.LayerNorm(
+        dtype=config.dtype,
+        bias_init=nn.with_logical_partitioning(
+            nn.initializers.zeros, ('embed',)
+        ),
+        scale_init=nn.with_logical_partitioning(
+            nn.initializers.ones, ('embed',)
+        ),
+    )(x)
     z = MlpBlock(config=config)(z)
 
     return x + z
@@ -296,7 +318,9 @@ def __call__(
       output_embed = nn.Embed(
           num_embeddings=config.output_vocab_size,
           features=config.emb_dim,
-          embedding_init=nn.initializers.normal(stddev=1.0),
+          embedding_init=nn.with_logical_partitioning(
+              nn.initializers.normal(stddev=1.0), ('vocab', 'embed')
+          ),
       )
     else:
       output_embed = self.shared_embedding
@@ -319,7 +343,16 @@ def __call__(
       y = EncoderDecoder1DBlock(
          config=config, name=f'encoderdecoderblock_{lyr}'
      )(y, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask)
-    y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y)
+    y = nn.LayerNorm(
+        dtype=config.dtype,
+        name='encoderdecoder_norm',
+        bias_init=nn.with_logical_partitioning(
+            nn.initializers.zeros, ('embed',)
+        ),
+        scale_init=nn.with_logical_partitioning(
+            nn.initializers.ones, ('embed',)
+        ),
+    )(y)
 
     # Decoded Logits
     if config.logits_via_embedding:
@@ -331,8 +364,10 @@ def __call__(
       logits = nn.Dense(
           config.output_vocab_size,
           dtype=config.dtype,
-          kernel_init=config.kernel_init,
-          bias_init=config.bias_init,
+          kernel_init=nn.with_logical_partitioning(
+              config.kernel_init, ('embed', 'vocab')
+          ),
+          bias_init=nn.with_logical_partitioning(config.bias_init, ('vocab',)),
           name='logitdense',
       )(y)
     return logits
diff --git a/examples/lm1b/requirements.txt b/examples/lm1b/requirements.txt
index 91a8469fa..8ebc88977 100644
--- a/examples/lm1b/requirements.txt
+++ b/examples/lm1b/requirements.txt
@@ -1,13 +1,13 @@
-absl-py==1.0.0
-clu==0.0.6
-flax==0.4.1
-jax==0.3.4
+absl-py==1.4.0
+clu==0.0.9
+flax==0.6.11
+jax==0.4.13
 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
-jaxlib==0.3.2+cuda11.cudnn82  # Make sure CUDA version matches the base image.
-ml-collections==0.1.0
-numpy==1.22.0
-optax==0.1.0
-sentencepiece==0.1.96
-tensorflow==2.11.1
-tensorflow-datasets==4.4.0
-tensorflow-text==2.8.1
+jaxlib==0.4.13+cuda11.cudnn82  # Make sure CUDA version matches the base image.
+ml-collections==0.1.1
+numpy==1.24.3
+optax==0.1.5
+sentencepiece==0.1.99
+tensorflow==2.13.0
+tensorflow-datasets==4.9.2
+tensorflow-text==2.13.0
diff --git a/examples/lm1b/temperature_sampler_test.py b/examples/lm1b/temperature_sampler_test.py
index 6ccb079c7..a0c7f46ec 100644
--- a/examples/lm1b/temperature_sampler_test.py
+++ b/examples/lm1b/temperature_sampler_test.py
@@ -28,7 +28,7 @@ class TestTemperatureSampler(absltest.TestCase):
   def test_temperature_sampler(self):
     tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32)
     cache = None
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
 
     def tokens_to_logits(tokens, cache):
       jax.debug.print('tokens: {}', tokens)
diff --git a/examples/lm1b/train.py b/examples/lm1b/train.py
index ffd8060a9..33507e43c 100644
--- a/examples/lm1b/train.py
+++ b/examples/lm1b/train.py
@@ -21,20 +21,18 @@
 # pytype: disable=attribute-error
 
 import collections
-import functools
 import os
 
 from absl import logging
 from clu import metric_writers
 from clu import periodic_actions
-from flax import jax_utils
 from flax import linen as nn
 from flax.training import checkpoints
 from flax.training import common_utils
-from flax.training import train_state
 import jax
 from jax import random
 import jax.numpy as jnp
+from jax.sharding import PartitionSpec as P, Mesh, NamedSharding
 import ml_collections
 import numpy as np
 import optax
@@ -43,6 +41,7 @@
 import input_pipeline
 import models
 import temperature_sampler
+import utils
 
 
 def rsqrt_schedule(
@@ -161,7 +160,6 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0):
       "accuracy": acc,
       "denominator": weight_sum,
   }
-  metrics = jax.lax.psum(metrics, axis_name="batch")
   return metrics
 
 
@@ -212,7 +210,6 @@ def loss_fn(params):
   lr = learning_rate_fn(step)
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
   (_, logits), grads = grad_fn(state.params)
-  grads = jax.lax.pmean(grads, "batch")
   new_state = state.apply_gradients(grads=grads)
   metrics = compute_metrics(logits, inputs, weights)
   metrics["learning_rate"] = lr
@@ -235,7 +232,7 @@ def predict_step(
   """Predict language model on a batch."""
   target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
   initial_variables = models.TransformerLM(config).init(
-      jax.random.key(0), jnp.ones(target_shape, config.dtype)
+      jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype)
   )
   cache = initial_variables["cache"]
 
@@ -302,7 +299,12 @@ def tohost(x):
 
 
 def evaluate(
-    *, p_eval_step, params, eval_ds: tf.data.Dataset, num_eval_steps: int
+    *,
+    jit_eval_step,
+    params,
+    eval_ds: tf.data.Dataset,
+    num_eval_steps: int,
+    config,
 ):
   """Evaluate the target an return a dictionary with the metrics."""
   logging.info("Gathering evaluation metrics.")
@@ -310,10 +312,9 @@ def evaluate(
   eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
   for _, eval_batch in zip(range(num_eval_steps), eval_iter):
     eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
-    eval_batch = common_utils.shard(eval_batch)
-    metrics = p_eval_step(params, eval_batch)
+    metrics = jit_eval_step(params, eval_batch, config)
     eval_metrics.append(metrics)
-  eval_metrics = common_utils.get_metrics(eval_metrics)
+  eval_metrics = common_utils.stack_forest(eval_metrics)
   eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics)
   eval_denominator = eval_metrics_sums.pop("denominator")
   eval_summary = jax.tree_util.tree_map(
@@ -325,13 +326,14 @@ def evaluate(
 
 def generate_prediction(
     *,
-    p_pred_step,
+    jit_pred_step,
     params,
     tokenized_prompts,
     eos_id,
     inference_rng,
     decode_tokens,
-    max_predict_length: int,
+    config,
+    predict_config,
 ):
   """Generate text from the prompt."""
   n_devices = jax.local_device_count()
@@ -352,8 +354,15 @@ def generate_prediction(
     inference_rng, sub_rng = random.split(inference_rng)
     inference_rngs = random.split(sub_rng, n_devices)
 
-    predicted = p_pred_step(
-        pred_batch, params, inference_rngs, eos_id, max_predict_length
+    predicted = jit_pred_step(
+        pred_batch,
+        params,
+        inference_rngs,
+        eos_id,
+        config.max_predict_length,
+        predict_config,
+        config.sampling_temperature,
+        config.sampling_top_k,
     )
     predicted = tohost(predicted)
     # Iterate through non-padding examples of batch.
@@ -436,16 +445,16 @@ def encode_strings(strs, max_len):
   eval_config = train_config.replace(deterministic=True)
   predict_config = train_config.replace(deterministic=True, decode=True)
 
+  # Mesh definition
+  devices_array = utils.create_device_mesh(config)
+  mesh = Mesh(devices_array, config.mesh_axes)
+
   start_step = 0
-  rng = jax.random.key(config.seed)
+  rng = jax.random.PRNGKey(config.seed)
   rng, init_rng = jax.random.split(rng)
   rng, inference_rng = random.split(rng)
-  input_shape = (config.per_device_batch_size, config.max_target_length)
 
   m = models.TransformerLM(eval_config)
-  initial_variables = jax.jit(m.init)(
-      init_rng, jnp.ones(input_shape, jnp.float32)
-  )
 
   learning_rate_fn = create_learning_rate_schedule(
       learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
@@ -458,11 +467,12 @@ def encode_strings(strs, max_len):
       eps=1e-9,
       weight_decay=config.weight_decay,
   )
-  state = train_state.TrainState.create(
-      apply_fn=m.apply, params=initial_variables["params"], tx=optimizer
+
+  state, state_mesh_annotations = utils.setup_initial_state(
+      m, optimizer, config, init_rng, mesh
   )
-  # We access model params only from optimizer below.
-  del initial_variables
+  data_sharding = NamedSharding(mesh, P(config.data_sharding))
+  full_sharding = NamedSharding(mesh, P(config.full_sharding))
 
   if config.restore_checkpoints:
     # Restore unreplicated optimizer + model state from last checkpoint.
@@ -476,39 +486,60 @@ def encode_strings(strs, max_len):
   if start_step == 0:
     writer.write_hparams(dict(config))
 
-  # Replicate optimizer.
-  state = jax_utils.replicate(state)
-
   # compile multidevice versions of train/eval/predict step fn.
-  p_train_step = jax.pmap(
-      functools.partial(
-          train_step, config=train_config, learning_rate_fn=learning_rate_fn
-      ),
-      axis_name="batch",
-      donate_argnums=(0,),
-  )  # pytype: disable=wrong-arg-types
-  p_eval_step = jax.pmap(
-      functools.partial(eval_step, config=eval_config), axis_name="batch"
+  jit_train_step = jax.jit(
+      train_step,
+      in_shardings=(
+          state_mesh_annotations,
+          full_sharding,
+          None,
+      ),  # type: ignore
+      out_shardings=(state_mesh_annotations, None),  # type: ignore
+      static_argnums=(2, 3, 4),
+      donate_argnums=0,
+  )
+
+  jit_eval_step = jax.jit(
+      eval_step,
+      in_shardings=(
+          state_mesh_annotations.params,
+          full_sharding,
+      ),  # type: ignore
+      out_shardings=None,  # type: ignore
+      static_argnums=(2, 3),
   )
 
-  p_pred_step = jax.pmap(
-      functools.partial(
+  # Since the inputs and rngkey args for predict_step will be batched,
+  # we must vmap them, otherwise the global arrays will be seen in each device
+  jit_pred_step = jax.jit(
+      jax.vmap(
           predict_step,
-          config=predict_config,
-          temperature=config.sampling_temperature,
-          top_k=config.sampling_top_k,
+          in_axes=(
+              0,
+              jax.tree_map(lambda x: None, state.params),
+              0,
+              None,
+              None,
+              jax.tree_map(lambda x: None, predict_config),
+              None,
+              None,
+          ),
       ),
-      axis_name="batch",
-      static_broadcasted_argnums=(3, 4),
-  )  # eos token, max_length are constant
+      in_shardings=(
+          data_sharding,
+          state_mesh_annotations.params,
+          data_sharding,
+      ),  # type: ignore
+      out_shardings=data_sharding,  # type: ignore
+      static_argnums=(3, 4, 5, 6, 7),
+  )
 
   # Main Train Loop
   # ---------------------------------------------------------------------------
 
   # We init the first set of dropout PRNG keys, but update it afterwards inside
   # the main pmap'd training update for performance.
-  dropout_rngs = jax.random.split(rng, jax.local_device_count())
-  del rng
+  dropout_rngs = rng
 
   logging.info("Starting training loop.")
   hooks = []
@@ -527,10 +558,11 @@ def encode_strings(strs, max_len):
 
     # Shard data to devices and do a training step.
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
-      batch = common_utils.shard(
-          jax.tree_util.tree_map(np.asarray, next(train_iter))
+      batch = next(train_iter)
+      batch = jax.tree_map(lambda x: jnp.array(x), batch)
+      state, metrics = jit_train_step(
+          state, batch, train_config, learning_rate_fn, 0.0, dropout_rngs
      )
-      state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
      train_metrics.append(metrics)
 
    # Quick indication that training is happening.
@@ -542,7 +574,7 @@ def encode_strings(strs, max_len):
     if step % config.eval_every_steps == 0 or is_last_step:
       with report_progress.timed("training_metrics"):
         logging.info("Gathering training metrics.")
-        train_metrics = common_utils.get_metrics(train_metrics)
+        train_metrics = common_utils.stack_forest(train_metrics)
         lr = train_metrics.pop("learning_rate").mean()
         metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
         denominator = metrics_sums.pop("denominator")
@@ -559,10 +591,11 @@ def encode_strings(strs, max_len):
 
       with report_progress.timed("eval"):
         eval_results = evaluate(
-            p_eval_step=p_eval_step,
+            jit_eval_step=jit_eval_step,
             params=state.params,
             eval_ds=eval_ds,
             num_eval_steps=config.num_eval_steps,
+            config=eval_config,
         )
         # (clipped) perplexity after averaging log-perplexitie
         eval_results["perplexity"] = jnp.clip(
@@ -574,13 +607,14 @@ def encode_strings(strs, max_len):
 
       with report_progress.timed("generate_text"):
         exemplars = generate_prediction(
-            p_pred_step=p_pred_step,
+            jit_pred_step=jit_pred_step,
             params=state.params,
             tokenized_prompts=tokenized_prompts,
             eos_id=eos_id,
             inference_rng=inference_rng,
             decode_tokens=decode_tokens,
-            max_predict_length=config.max_predict_length,
+            config=config,
+            predict_config=predict_config,
         )
         writer.write_texts(step, {"samples": exemplars})
 
@@ -591,6 +625,4 @@ def encode_strings(strs, max_len):
     if config.save_checkpoints and save_checkpoint:
       logging.info("Saving checkpoint step %d.", step)
       with report_progress.timed("checkpoint"):
-        checkpoints.save_checkpoint_multiprocess(
-            workdir, jax_utils.unreplicate(state), step
-        )
+        checkpoints.save_checkpoint_multiprocess(workdir, state, step)
diff --git a/examples/lm1b/utils.py b/examples/lm1b/utils.py
new file mode 100644
index 000000000..af3eea128
--- /dev/null
+++ b/examples/lm1b/utils.py
@@ -0,0 +1,168 @@
+# Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py).
+
+import functools
+import logging
+
+import numpy as np
+import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
+from flax.training import train_state
+import jax
+import jax.numpy as jnp
+from jax.experimental import mesh_utils
+
+
+# Mesh utils.
+# -----------------------------------------------------------------------------
+
+
+def create_device_mesh(config):
+  """Creates a device mesh with each slice in its own data parallel group.
+  If there is only one slice, uses two replicas."""
+  devices = jax.devices()
+  num_devices = len(devices)
+  try:
+    num_slices = 1 + max([d.slice_index for d in devices])
+  except:
+    num_slices = 1
+  num_devices_per_slice = num_devices // num_slices
+  logging.info(f"Devices: {devices}")
+  logging.info(f"Number of devices: {num_devices}")
+
+  multi_slice_env = hasattr(jax.devices()[0], "slice_index")
+
+  dcn_parallelism = [
+      config.dcn_data_parallelism,
+      config.dcn_fsdp_parallelism,
+      config.dcn_tensor_parallelism,
+  ]
+  ici_parallelism = [
+      config.ici_data_parallelism,
+      config.ici_fsdp_parallelism,
+      config.ici_tensor_parallelism,
+  ]
+
+  # Find possible unspecified parallelisms
+  dcn_parallelism = fill_unspecified_mesh_axes(
+      dcn_parallelism, num_slices, "DCN"
+  )
+  ici_parallelism = fill_unspecified_mesh_axes(
+      ici_parallelism, num_devices_per_slice, "ICI"
+  )
+
+  if multi_slice_env:
+    mesh = mesh_utils.create_hybrid_device_mesh(
+        ici_parallelism, dcn_parallelism
+    )
+  else:
+    mesh = mesh_utils.create_device_mesh(ici_parallelism)
+
+  logging.info(f"Decided on mesh: {mesh}")
+  logging.info(f"Mesh shape: {mesh.shape}")
+
+  return mesh
+
+
+def fill_unspecified_mesh_axes(
+    parallelism_vals, target_product, parallelism_type
+):
+  """Evaluates unspecified DCN/ICI parallelism values."""
+  if -1 in parallelism_vals:
+    assert parallelism_vals.count(-1) == 1, (
+        f"Found unspecified values (-1) for more than one {parallelism_type} "
+        " parallelism axis. At most one axis can be unspecified."
+    )
+
+    determined_val = target_product / np.product(parallelism_vals) * -1
+
+    assert determined_val >= 1 and determined_val.is_integer(), (
+        "Unspecified value unable to be determined with the given "
+        f" {parallelism_type} parallelism values"
+    )
+
+    parallelism_vals[parallelism_vals.index(-1)] = int(determined_val)
+
+  target_type = "slices" if parallelism_type == "DCN" else "devices per slice"
+
+  assert np.product(parallelism_vals) == target_product, (
+      f"Number of {target_type} {target_product} does not match the product"
+      f" of the {parallelism_type} parallelism {np.product(parallelism_vals)}"
+  )
+
+  return parallelism_vals
+
+
+# State initialization utils.
+# -----------------------------------------------------------------------------
+
+
+def unbox_logicallypartioned_trainstate(
+    boxed_train_state: train_state.TrainState,
+):
+  """Unboxes the flax.LogicallyPartitioned pieces in a train state.
+
+  Args:
+    boxed_train_state: a train state that includes LogicallyPartitioned
+      leaves.
+  Returns:
+    a TrainState where all LogicallyPartitioned leaves have been unboxed.
+  """
+  return jax.tree_util.tree_map(
+      lambda x: x.unbox() if isinstance(x, nn.spmd.LogicallyPartitioned) else x,
+      boxed_train_state,
+      is_leaf=lambda k: isinstance(k, nn.spmd.LogicallyPartitioned),
+  )
+
+
+def init_train_state(model, tx, config, key):
+  """
+  We pass in "static" objects like model, tx, config as JAX compares them by
+  object hash, and instantiating them inside causes pjit top-level annotations
+  to fail to match as pytree prefixes if we re-instantiate.
+
+  Args: model, tx, config, key
+  """
+  input_shape = (config.per_device_batch_size, config.max_target_length)
+  initial_variables = jax.jit(model.init)(
+      key, jnp.ones(input_shape, jnp.float32)
+  )
+
+  state = train_state.TrainState.create(
+      apply_fn=model.apply, params=initial_variables["params"], tx=tx
+  )
+  return state
+
+
+def setup_initial_state(model, tx, config, rng, mesh):
+  """We initialize the model and optimizer state, and optionally load from a
+  checkpoint as necessary.
+
+  Args:
+    model: the flax model to initialize
+    tx: the optax.GradientTransformation
+    config: config object
+    rng: jax.prng key
+    mesh: jax.devices() mesh
+
+  Returns:
+    state: the initialized train state
+    state_mesh_annotations: the mesh annotations for the train state
+  """
+  init_train_state_partial = functools.partial(
+      init_train_state, model, tx, config
+  )
+  abstract_state = jax.eval_shape(init_train_state_partial, rng)
+  state_logical_annotations = nn.get_partition_spec(abstract_state)
+
+  # Initialization
+  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    state_mesh_annotations = nn.logical_to_mesh_sharding(
+        state_logical_annotations, mesh, config.logical_axis_rules
+    )
+    state = jax.jit(
+        init_train_state_partial,
+        in_shardings=None,  # type: ignore
+        out_shardings=state_mesh_annotations,
+    )(rng)
+
+  state = unbox_logicallypartioned_trainstate(state)
+  return state, state_mesh_annotations
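
For readers trying the patch out, here is a small, self-contained sketch (not part of the diff) of how the parallelism settings in configs/default.py turn into a jax.sharding.Mesh and a NamedSharding, and how a logical axis pair such as ('embed', 'mlp') maps onto mesh axes. The axis names and the two rules shown are taken from the config above; the device count is whatever the local host reports (an assumption for the example), so the -1 placeholder simply resolves to 1 on a single-device CPU.

# Illustrative sketch only -- not part of the patch. It mirrors, in simplified
# form, what utils.create_device_mesh() and the logical_axis_rules do.
import numpy as np
import jax
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# ICI parallelism per the config: [data, fsdp, tensor], with fsdp left as the
# -1 placeholder so it absorbs all remaining devices on the slice.
ici = [1, -1, 1]
num_devices = len(jax.devices())
ici[ici.index(-1)] = num_devices // int(np.prod([v for v in ici if v != -1]))

mesh = Mesh(mesh_utils.create_device_mesh(ici), ('data', 'fsdp', 'tensor'))

# Batches are placed along the 'data' axis only, as in config.data_sharding.
data_sharding = NamedSharding(mesh, P('data'))

# A Dense kernel annotated with logical axes ('embed', 'mlp') resolves to
# PartitionSpec('fsdp', 'tensor') under this subset of the config's rules.
rules = [('embed', 'fsdp'), ('mlp', 'tensor')]
print(nn.logical_to_mesh_axes(('embed', 'mlp'), rules))
print(mesh.shape, data_sharding)

On a hypothetical single slice with 8 devices this yields a (1, 8, 1) mesh, which matches what fill_unspecified_mesh_axes computes for the ICI axes.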