diff --git a/.gitignore b/.gitignore index 7d35f0ccc..916a29ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algoperf/_version.py +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..af05111cd 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,14 +5,16 @@ """ import os -from typing import Sequence, Tuple +from typing import Optional, Sequence, Tuple import numpy as np +import orbax.checkpoint as ocp import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint +from orbax.checkpoint.type_handlers import NumpyHandler from tensorflow.io import gfile # pytype: disable=import-error from algoperf import spec @@ -30,6 +32,51 @@ ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. + """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. 
+ """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) + + def maybe_restore_checkpoint( framework: str, optimizer_state: spec.OptimizerState, diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 908ef0f27..26a351bb4 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -44,6 +44,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..e24b0f141 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init( + use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True +) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -39,7 +41,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8_915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. 
+ return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/finewebedu_lm/__init__.py b/algoperf/workloads/finewebedu_lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py new file mode 100644 index 000000000..d08e9b7bf --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -0,0 +1,397 @@ +""" +Originally based on code from the NanoDO repository under the Apache 2.0 license: +https://github.com/google-deepmind/nanodo +""" + +import dataclasses +from functools import partial + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +@dataclasses.dataclass +class ModelConfig: + """Hyper-parameters for Transformer decoder-only.""" + + model_dim: int # model/embed dim = qkv dim + num_heads: int # num attention heads + seq_len: int # max context/sequence length + num_layers: int # number of transformer block layers + vocab_size: int # vocab size + expanded_model_dim: int # FF inner dimension + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True # Whether to tie input and output embed + qknorm_epsilon: float = 1e-6 + + dtype: jnp.dtype = jnp.float32 + attention_init: nn.initializers.Initializer = nn.initializers.normal( + stddev=0.02 + ) + linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + + def __post_init__(self): + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.num_layers) + ) + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: ModelConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + linear = partial( + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype + ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.expanded_model_dim * 2 / 3 + hidden_dim = cfg.multiple_of * ( + (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = nn.Dense( + cfg.model_dim, + use_bias=False, + dtype=cfg.dtype, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + )(x_BxLxF) + return x_BxLxD + + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + + def 
precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack( + [jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]], + axis=3, + ) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack( + [ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1], + ], + axis=-1, + ) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: ModelConfig + + def setup(self): + cfg = self.cfg + assert cfg.model_dim % cfg.num_heads == 0, ( + f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + ) + self.Dh = cfg.model_dim // cfg.num_heads + self.eps = cfg.qknorm_epsilon + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.num_heads, self.Dh), + kernel_init=cfg.attention_init, + use_bias=False, + dtype=cfg.dtype, + ) + self.multilinear_query = self.multilinear(name='query') + self.multilinear_key = self.multilinear(name='key') + self.multilinear_value = self.multilinear(name='value') + # See Henry et al. 
(2020) "Query Key Normalization for Transformers" + seq_len = cfg.seq_len + attn_scale0 = jnp.log2(seq_len**2 - seq_len) + self.attn_scale = self.param( + 'attn_scale', nn.initializers.constant(attn_scale0), () + ) + self.output_projection = nn.DenseGeneral( + features=cfg.model_dim, + name='attn_out_proj', + # axis=(-2, -1), # + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Apply QK normalization + q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps + k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps + + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: ModelConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: ModelConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.model_dim, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.vocab_size, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name='output_proj', + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. 
+ + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.seq_len: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + last_token_id = y_BxL[:, -1] + # Prevent predicting the same token consecutively + next_token_logits = next_token_logits.at[ + jnp.arange(len(last_token_id)), last_token_id + ].set(float('-inf')) + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = ModelConfig( + model_dim=128, + num_heads=4, + seq_len=L, + num_layers=2, + vocab_size=256, + expanded_model_dim=4 * 128, + ) + model = TransformerDo(cfg) + + # Print model info + print('\nModel Configuration:') + print(f' - Model dimension (D): {cfg.model_dim}') + print(f' - Number of heads (H): {cfg.num_heads}') + print(f' - Max sequence length (L): {cfg.seq_len}') + print(f' - Number of layers (N): {cfg.num_layers}') + print(f' - Vocabulary size (V): {cfg.vocab_size}') + print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}') + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32 + ) + + # Initialize model parameters + print('\nInitializing model parameters...') + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f'Total parameters: {param_count:,}') + + # Make a prediction (forward pass) + print('\nRunning forward pass...') + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print( + f'\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)' + ) + print(f'Output data type: {logits.dtype}') + + # Print sample logits (first 5 positions of the first sequence) + print('\nSample logits (first sequence, first 5 positions, first 5 values):') + for position in range(min(5, L)): + print(f' Position {position}: {logits[0, position, :5]}') + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + # Test the predict function + print('\nTesting predict function...') + # Use a shorter + short_seq = x_BxL[:, :10] + print(f'Input sequence shape: {short_seq.shape}') + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, 
method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + print('\nDone!') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py new file mode 100644 index 000000000..ee4cffbbc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -0,0 +1,169 @@ +"""LM workload implemented in Jax.""" + +from typing import Any, Dict, Optional, Tuple + +import jax +import jax.numpy as jnp + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig, + TransformerDo, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + """LM JAX workload.""" + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): + """Build an input queue using pre-cached FineWeb dataset.""" + del cache, repeat_final_dataset + ds = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=global_batch_size, + num_batches=num_batches, + ) + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + # Initialize NanoDO transformer model + cfg = ModelConfig( + model_dim=self._emb_dim, # embedding dim + num_heads=self._n_heads, # num heads + seq_len=self._seq_len, + num_layers=self._n_layers, # num layers + vocab_size=self._vocab_size, + expanded_model_dim=self._mlp_dim, # feedforward dim + dtype=jnp.float32, + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(self._model.init)( + {'params': params_rng}, jnp.ones(input_shape, jnp.int32) + ) + params = variables['params'] + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + params = jax_sharding_utils.replicate(params) + model_state = None + return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate + inputs = batch['inputs'] + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy. + + Args: + label_batch: categorical targets [batch, length] int array. 
+ logits_batch: [batch, length, num_classes] float array. + mask_batch: weights array of shape [batch, length]. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 2d array of per-example losses} + """ + if logits_batch.ndim != label_batch.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits_batch.shape} logits and ' + f'{label_batch.shape} targets.' + ) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits_batch, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, label_batch[..., None], axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * ( + confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) + ) + if mask_batch is not None: + per_example_losses = mask_batch * per_example_losses + n_valid_examples = mask_batch.sum() + else: + n_valid_examples = label_batch.shape[0] * label_batch.shape[1] + summed_loss = per_example_losses.sum() + return { + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py new file mode 100644 index 000000000..edee8318c --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -0,0 +1,344 @@ +""" +Originally based on the plainLM codebase: +https://github.com/Niccolo-Ajroldi/plainLM +under the MIT license https://github.com/Niccolo-Ajroldi/plainLM/blob/main/LICENSE. 
+""" + +import math +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +@dataclass +class ModelConfig: + model_dim: int + num_heads: int + seq_len: int + num_layers: int + vocab_size: int + expanded_model_dim: int + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + qknorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True + + +class MLP(nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + nn.init.normal_(self.fc1.weight, std=0.02) + nn.init.normal_(self.fc2.weight, std=0.02) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1 +): + inv_freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.device('cpu')) + / dim + ) + ) + t = ( + torch.arange(end, dtype=torch.float32, device=inv_freqs.device) + / condense_ratio + ) + freqs = torch.outer(t, inv_freqs).float() + return torch.stack( + [torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], + dim=4, + ) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.model_dim % cfg.num_heads == 0 + self.dim = cfg.model_dim + self.n_heads = cfg.num_heads + self.head_dim = cfg.model_dim // cfg.num_heads + + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) + # Split into Q, K, V sections + wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) + for w in [wq, wk, wv]: + nn.init.normal_(w, std=0.02) + nn.init.normal_(self.w_out.weight, std=0.02) + + self.eps = cfg.qknorm_epsilon # e.g., 1e-6 + seq_len = cfg.seq_len + attn_scale0 = math.log2(seq_len**2 - seq_len) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + k = k.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + v = v.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis + ) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + # Apply QK normalization + q = q / torch.norm(q, dim=-1, keepdim=True) + self.eps + k = k / torch.norm(k, dim=-1, 
keepdim=True) + self.eps + q *= self.attn_scale + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True, scale=1.0 + ) # (bsz, nh, seqlen, h_dim) + out = ( + out.transpose(1, 2).contiguous().view(bsz, seqlen, d) + ) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.mlp = MLP( + dim=cfg.model_dim, + hidden_dim=cfg.expanded_model_dim, + multiple_of=cfg.multiple_of, + ) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.n_layers = cfg.num_layers + self.cfg = cfg + head_dim = cfg.model_dim // cfg.num_heads + assert cfg.model_dim % cfg.num_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.num_layers)] + ) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer( + 'freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0 : cfg.seq_len], + persistent=False, + ) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x, targets=None): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.model_dim // self.cfg.num_heads + new_freqs = precompute_freqs_cis( + head_dim, max(L, self.cfg.seq_len), 500000 + ) + self.register_buffer( + 'freqs_cis', new_freqs[0 : max(L, self.cfg.seq_len)], persistent=False + ) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 + ) + return out, loss + return out + + def predict(self, x, k=1): + """Generate k tokens autoregressively. 
+ + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith('fc2.weight') or n.endswith( + 'w_out.weight' + ): # mlp/glu output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if ( + self.lm_head.weight is not self.embed_tokens.weight + ): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print('Initializing transformer model and running forward pass...') + + seq_length = 1024 + + # Define model configuration + config = ModelConfig( + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + model_dim=1024, # Embedding dimension + expanded_model_dim=4.0, # MLP expansion factor + num_layers=12, # Number of transformer layers + num_heads=8, # Number of attention heads + rmsnorm_epsilon=1e-6, # RMSNorm epsilon + tie_embeddings=True, # Tie embedding and output weights + ) + + # Instantiate the model + model = Transformer(config) + print(f'Model has {model.count_params():,} parameters.') + + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f'Running forward pass with input shape: {input_ids.shape}') + logits = model(input_ids) + print(f'Output logits shape: {logits.shape}') + + # Run prediction + print('Running prediction...') + original_input, 
predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f'Original input shape for prediction: {original_input.shape}') + print(f'Predicted IDs shape: {predicted_ids.shape}') + print(f'Predicted IDs: {predicted_ids}') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py new file mode 100644 index 000000000..a25ca334a --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -0,0 +1,221 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple + +import jax +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig, + Transformer, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + if hasattr(self, '_model'): + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None + + torch.manual_seed(rng[0]) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + model_dim=self._emb_dim, # Model dimension + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + self._model = Transformer(cfg) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del model_state, rng, update_batch_norm, dropout_rate + model = params + + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + + logits = model(inputs) + + return logits, None + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + 
repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: + """Build an input queue for the given split.""" + del cache, repeat_final_dataset + local_batch_size = global_batch_size // N_GPUS + loader = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=local_batch_size, + num_batches=num_batches, + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) + dtype = torch.int32 + for batch in loader: + batch = { + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor( + batch['targets'], device=DEVICE, dtype=torch.int64 + ), + 'weights': torch.tensor( + batch['weights'], device=DEVICE, dtype=torch.float32 + ) + if batch['weights'] is not None + else None, + } + yield batch + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: spec.Tensor, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute weighted cross-entropy loss. + + Args: + label_batch: Target labels of shape [batch, length] (int). + logits_batch: Predicted logits of shape [batch, length, vocab_size] (float). + mask_batch: Optional weights of shape [batch, length] (float). Used to mask + out padding tokens or weight examples differently. If None, all examples + are weighted equally. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + Dictionary containing: + - 'summed': Scalar tensor with the sum of all weighted losses. + - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. + - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
+ """ + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), dtype=torch.float32, device=label_batch.device + ) + ) + + return { + 'summed': per_example_losses.sum(), + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/input_pipeline.py b/algoperf/workloads/finewebedu_lm/input_pipeline.py new file mode 100644 index 000000000..3007371fc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/input_pipeline.py @@ -0,0 +1,153 @@ +"""Input pipeline for a LM dataset.""" + +import functools +import os +from typing import Optional + +import jax +import tensorflow as tf + +from algoperf import data_utils + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = tf.constant(-1, dtype=tf.int64) + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 1024 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. + + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: + """ + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. 
+ padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size + ), + ds, + ) + + return iter(it) + + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + """Load preprocessed TF dataset.""" + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + if split == 'train': + ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=False) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + return ds diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py new file mode 100644 index 000000000..14b02e085 --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -0,0 +1,193 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Iterator, Optional + +import jax +import numpy as np +from absl import flags + +from algoperf import spec + +FLAGS = flags.FLAGS + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 1024 + _emb_dim: int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 + warmup_factor: float = 0.1 + + def __init__(self) -> None: + super().__init__() + self._param_shapes = None + self._param_types = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" 
+ return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] <= self.validation_target_value + + @property + def validation_target_value(self) -> float: + return 22.432 # Target perplexity + + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return True # No test targets + + @property + def test_target_value(self) -> float: + return None # No test targets + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + return 8_749_870 # sequences of 1024 tokens each + + @property + def num_eval_train_examples(self) -> int: + return 10_000 # Subset for evaluation. + + @property + def num_validation_examples(self) -> int: + return 100_000 # sequences + + @property + def num_test_examples(self) -> int: + return 0 + + @property + def eval_batch_size(self) -> int: + return 256 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + return 3600 * 14 # 14 hours TODO(kasimbeg): update + + @property + def eval_period_time_sec(self) -> int: + return 1200 # 20 minutes TODO(kasimbeg): update + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 72_000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, Any]]: + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + eval_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, float]: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + + if split not in self._eval_iters: + # These iterators will repeat indefinitely. 
+ self._eval_iters[split] = self._build_input_queue( + rng, split, data_dir, global_batch_size, num_batches=num_batches + ) + + eval_metrics = {} + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + metrics = self._eval_batch(params, eval_batch, model_state, rng) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']).item() + return eval_results + + @abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + + @abc.abstractmethod + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output' in param_name diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. + return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2_571 # approx 25 evals
def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43_680 # ~12.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16_114 # ~4.5 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4dd4717e9..e90300a36 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -113,6 +113,10 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'finewebedu_lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', @@ -152,6 +156,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'finewebedu_lm', 'ogbg', 'wmt', ] diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py
b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 761ce5cb1..7c50ff4ff 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,6 +189,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index e199fb2b9..061acc3de 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,6 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..323022598 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,6 +394,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'finewebedu_lm': + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..2abf74c73 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,6 +372,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..b7adf6cd6 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). 
+ Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
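+
+  Note: with gradient g_t, the update computed below is
+    m_t = b1 * m_{t-1} + (1 - b1) * g_t
+    v_t = b2 * v_{t-1} + (1 - b2) * g_t**2
+    m_hat = (b1 * m_t + (1 - b1) * g_t) / (1 - b1**t)   # Nesterov-style lookahead
+    v_hat = v_t / (1 - b2**t)
+    update = m_hat / ((v_hat + eps_root)**power + eps)
+  where the bias corrections are skipped when `debias=False`.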
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + step_hint = 0.75 * step_hint + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. 
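+  # Note: `jax_cosine_warmup` (defined above) scales the step hint by 0.75, so
+  # the warmup + cosine-decay schedule finishes after ~75% of the workload's
+  # step hint and the learning rate then stays at the schedule's final value.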
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + 
replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'finewebedu_lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..b881747d8 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -0,0 +1,403 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step'] + ) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
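+    # In effect, the next line turns `exp_avg` into the Nesterov-style
+    # estimate beta1 * m_t + (1 - beta1) * grad; the trailing
+    # `exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)` below restores m_t
+    # exactly (up to floating-point rounding).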
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + step_hint = step_hint * 0.75 + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
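+    # (torch.distributed.nn's all_reduce is autograd-aware, unlike
+    # torch.distributed.all_reduce, so every replica divides by the global
+    # n_valid_examples below.)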
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'finewebedu_lm': + return 64 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..ce0f75623 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 + } +] \ No newline at end of file diff --git a/datasets/README.md b/dataset/README.md similarity index 99% rename from datasets/README.md rename to dataset/README.md index 1aeb83239..1bfd9bf73 100644 --- a/datasets/README.md +++ b/dataset/README.md @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 dataset/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--finewebedu +``` \ No newline at end of file diff --git a/datasets/dataset_setup.py b/dataset/dataset_setup.py similarity index 89% rename from datasets/dataset_setup.py rename to dataset/dataset_setup.py index e110930cd..de5e9d271 100644 --- a/datasets/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -73,8 +73,11 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer + +import datasets as hf_datasets +from transformers import AutoTokenizer import functools import os @@ -82,6 +85,7 @@ import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -121,6 +125,9 @@ flags.DEFINE_boolean( 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' ) +flags.DEFINE_boolean( + 'finewebedu', False, 'If --all=false, whether or not to download FineWebEdu.' +) flags.DEFINE_boolean( 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' ) @@ -194,6 +201,9 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean( + 'skip_tokenization', False, 'Skip Fineweb-edu tokenization.'
+) FLAGS = flags.FLAGS @@ -767,6 +777,102 @@ def download_wmt(data_dir): ) +def download_finewebedu( + data_dir, tmp_dir=None, skip_download=False, skip_tokenization=False +): + """Download FineWebEdu-10B.""" + + if not skip_download: + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = ( + os.path.join(tmp_dir, 'lm') + if tmp_dir is not None + else os.path.expanduser('~/.cache/huggingface/datasets') + ) + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ['TMPDIR'] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir, + ) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f'Vocab size of lm_tokenizer = {len(lm_tokenizer)}') + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples['text']), + return_special_tokens_mask=False, + return_attention_mask=False, + ) + + lm_tokenizer.model_max_length = ( + 1e30 # prevent truncation during tokenization + ) + logging.info('Tokenizing...') + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score', + ], + batched=True, + batch_size=1024, + num_proc=8, + ) + + tokenized_dataset.save_to_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + else: + tokenized_dataset = hf_datasets.load_from_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + + # Convert to tensorflow_datasets.Dataset objects + tokenized_dataset = tokenized_dataset.to_tf_dataset() + + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. 
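+  # (The 90/10 split was produced above via shuffle/take/skip; the saves below
+  # materialize each tf.data split to disk.)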
+ train_dataset.save(os.path.join(data_dir, 'train')) + val_dataset.save(os.path.join(data_dir, 'val')) + + return + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -854,6 +960,12 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu( + data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization + ) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 99% rename from datasets/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index 1c216db46..878f10f2a 100644 --- a/datasets/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -14,7 +14,7 @@ from absl import logging from pydub import AudioSegment -from datasets import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..d92107e90 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..b93c9794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,9 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] +full = [ + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", +] # All workloads plus 
development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package @@ -88,6 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..4b7bed2b5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -123,6 +123,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): workload_df['accumulated_submission_time'] / workload_df['global_step'] ).iloc[-1][-1] + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( @@ -157,7 +159,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. 
+ workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..0ba0d99ee 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,11 +24,15 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { "max_steps": 80000, "dataset": "librispeech" + }, + "finewebedu_lm" : { + "max_steps": 55000, + "dataset":"fineweb_edu_10B" } } diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..01d9894d8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'finewebedu_lm', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -783,7 +784,10 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = base_workload != 'finewebedu_lm' + pytorch_init( + USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads + ) # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: @@ -799,6 +803,7 @@ def main(_): 'librispeech_deepspeech', 'imagenet_vit', 'criteo1tb', + 'finewebedu_lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py new file mode 100644 index 000000000..709e3125f --- /dev/null +++ b/tests/modeldiffs/lm/compare.py @@ -0,0 +1,892 @@ +""" +Test file to verify that JAX and PyTorch implementations produce identical outputs +when given the same weights and inputs. + +Tests are performed module-by-module: +1. RMSNorm +2. RoPE (Rotary Position Embeddings) +3. MLP +4. Attention +5. Transformer Block +6. Full Model +""" + +import os +import sys + +# Disable GPU access for both jax and pytorch. 
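+# (Running both frameworks on CPU keeps them on the same numerical path, which
+# helps the tight tolerances used in the comparisons below.)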
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags, logging +from absl.testing import absltest, parameterized + +# Import JAX implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + CausalAttn, + Mlp, + TBlock, + TransformerDo, + apply_rope, + init_rope, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig as JaxModelConfig, +) + +# Import PyTorch implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + MLP, + Attention, + Block, + Transformer, + apply_rotary_emb_complex_like, + precompute_freqs_cis, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig as PyTorchModelConfig, +) + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +FLAGS(sys.argv) + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def assert_close(jax_output, torch_output, rtol=1e-5, atol=1e-6, name=''): + """Assert that JAX and PyTorch outputs are close.""" + jax_np = np.array(jax_output) + torch_np = torch_output.detach().cpu().numpy() + + mse = np.mean((jax_np - torch_np) ** 2) + max_diff = np.max(np.abs(jax_np - torch_np)) + + logging.info(f'\n{name} Comparison:') + logging.info(f' MSE: {mse:.8e}') + logging.info(f' Max Difference: {max_diff:.8e}') + + np.testing.assert_allclose( + jax_np, + torch_np, + rtol=rtol, + atol=atol, + err_msg=f'{name} outputs do not match', + ) + + +# ============================================================================ +# Test Functions (unchanged) +# ============================================================================ + + +def test_rmsnorm(): + """Test that RMSNorm produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RMSNorm') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + eps = 1e-6 + + # Create random input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + + # Initialize PyTorch RMSNorm + torch_norm = torch.nn.RMSNorm(dim, eps=eps) + torch_input = torch.tensor(np_input) + + # Initialize JAX RMSNorm (using Flax's RMSNorm from nanodo) + from flax import linen as nn + + flax_norm = nn.RMSNorm(epsilon=eps) + jax_input = jnp.array(np_input) + flax_params = flax_norm.init(jax.random.PRNGKey(0), jax_input) + + # Copy weights from PyTorch to JAX + with torch.no_grad(): + flax_params['params']['scale'] = jnp.array(torch_norm.weight.numpy()) + + # Forward pass + with torch.no_grad(): + torch_output = torch_norm(torch_input) + + jax_output = flax_norm.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='RMSNorm') + logging.info('✓ RMSNorm test passed') + + +def test_rope(): + """Test that RoPE produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RoPE (Rotary Position Embeddings)') + logging.info('=' * 70) + + batch_size, seq_len, n_heads, dim = 2, 16, 4, 128 + head_dim = dim // n_heads + + # Initialize RoPE + torch_freqs = precompute_freqs_cis(head_dim, seq_len, theta=500000) + jax_freqs = init_rope(dim, seq_len, n_heads) + + # Create random Q and K + np_q = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + np_k = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + 
) + + # PyTorch forward + torch_q = torch.tensor(np_q) + torch_k = torch.tensor(np_k) + with torch.no_grad(): + torch_q_rot, torch_k_rot = apply_rotary_emb_complex_like( + torch_q, torch_k, freqs_cis=torch_freqs + ) + + # JAX forward + jax_q = jnp.array(np_q) + jax_k = jnp.array(np_k) + jax_q_rot, jax_k_rot = apply_rope(jax_q, jax_k, jax_freqs) + + # Compare + assert_close(jax_q_rot, torch_q_rot, name='RoPE Q') + assert_close(jax_k_rot, torch_k_rot, name='RoPE K') + logging.info('✓ RoPE test passed') + + +def copy_mlp_params(pytorch_mlp, flax_params): + """Copy MLP parameters from PyTorch to JAX.""" + new_params = flax_params.copy() + + # Handle compiled models + if hasattr(pytorch_mlp, '_orig_mod'): + pytorch_mlp = pytorch_mlp._orig_mod + + # Copy fc1 and fc2 weights (transposed for JAX) + new_params['params']['Dense_0']['kernel'] = ( + pytorch_mlp.fc1.weight.detach().numpy().T + ) + new_params['params']['Dense_1']['kernel'] = ( + pytorch_mlp.fc2.weight.detach().numpy().T + ) + + return new_params + + +def test_mlp(): + """Test that MLP produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing MLP') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + hidden_dim = 1024 + + # Initialize PyTorch MLP + pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) + + # Initialize JAX MLP + cfg = JaxModelConfig( + model_dim=dim, + num_heads=4, + seq_len=128, + num_layers=2, + vocab_size=1000, + expanded_model_dim=hidden_dim, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_mlp = Mlp(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_mlp.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_mlp_params(pytorch_mlp, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_mlp(torch_input) + + jax_output = flax_mlp.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='MLP') + logging.info('✓ MLP test passed') + + +def copy_attention_params(pytorch_attn, flax_params): + """Copy attention parameters from PyTorch to JAX.""" + # Handle compiled models + if hasattr(pytorch_attn, '_orig_mod'): + pytorch_attn = pytorch_attn._orig_mod + + n_heads = pytorch_attn.n_heads + head_dim = pytorch_attn.head_dim + dim = pytorch_attn.dim + + # Split PyTorch's combined qkv weights + w_qkv = pytorch_attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + # Reshape for Flax's DenseGeneral format [D, H, Dh] + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + new_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': {'kernel': pytorch_attn.w_out.weight.detach().numpy().T}, + } + + return {'params': new_params} + + +def test_attention(): + """Test that Attention produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Attention') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + + # Initialize PyTorch Attention + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + 
expanded_model_dim=1024, + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_attn = Attention(config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Attention + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=1024, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_attn = CausalAttn(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_attn.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_attention_params(pytorch_attn, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_attn(torch_input, freqs_cis) + + jax_output = flax_attn.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Attention') + logging.info('✓ Attention test passed') + + +def copy_block_params(pytorch_block, flax_params): + """Copy block parameters from PyTorch to JAX.""" + # Copy attention parameters + attn_params = copy_attention_params(pytorch_block.attn, {'params': {}})[ + 'params' + ] + + # Copy MLP parameters + pytorch_mlp = pytorch_block.mlp + mlp_params = { + 'Dense_0': {'kernel': pytorch_mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_mlp.fc2.weight.detach().numpy().T}, + } + + # Copy RMSNorm parameters + norm_params = { + 'attn_norm': {'scale': pytorch_block.attn_norm.weight.detach().numpy()}, + 'mlp_norm': {'scale': pytorch_block.mlp_norm.weight.detach().numpy()}, + } + + return { + 'params': { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': norm_params['attn_norm'], + 'RMSNorm_1': norm_params['mlp_norm'], + } + } + + +def test_block(): + """Test that Transformer Block produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Transformer Block') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + expand = 4.0 + + # Initialize PyTorch Block + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_block = Block(layer_id=0, cfg=config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Block + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_block = TBlock(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_block.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_block_params(pytorch_block, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_block(torch_input, freqs_cis) + + jax_output = flax_block.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Block') + logging.info('✓ Block test passed') + + +def copy_full_model_params(pytorch_model, 
flax_params, config): + """Copy all parameters from PyTorch model to JAX model.""" + # Handle tied embeddings case + if hasattr(pytorch_model, '_orig_mod'): + pytorch_model = pytorch_model._orig_mod + + n_layers = config.num_layers + n_heads = config.num_heads + dim = config.model_dim + head_dim = dim // n_heads + + new_params = {'params': {}} + + # Copy embedding weights + new_params['params']['embed'] = { + 'embedding': pytorch_model.embed_tokens.weight.detach().numpy() + } + + # Copy each transformer block + for i in range(n_layers): + pytorch_block = pytorch_model.layers[i] + + # Attention params + w_qkv = pytorch_block.attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + attn_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': { + 'kernel': pytorch_block.attn.w_out.weight.detach().numpy().T + }, + } + + # MLP params + mlp_params = { + 'Dense_0': {'kernel': pytorch_block.mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_block.mlp.fc2.weight.detach().numpy().T}, + } + + # Norm params + attn_norm = {'scale': pytorch_block.attn_norm.weight.detach().numpy()} + mlp_norm = {'scale': pytorch_block.mlp_norm.weight.detach().numpy()} + + # Assemble block params + block_key = f'blocks_{i}' + new_params['params'][block_key] = { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': attn_norm, + 'RMSNorm_1': mlp_norm, + } + + # Copy output norm + new_params['params']['out_ln'] = { + 'scale': pytorch_model.out_norm.weight.detach().numpy() + } + + # Handle output projection (tied or untied) + if not config.tie_embeddings: + new_params['params']['output_proj'] = { + 'kernel': pytorch_model.lm_head.weight.detach().numpy().T + } + + return new_params + + +def test_full_model(): + """Test that full Transformer model produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Full Transformer Model') + logging.info('=' * 70) + + batch_size, seq_len = 2, 32 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Forward pass + with torch.no_grad(): + torch_logits = pytorch_model(torch_tokens) + + 
jax_logits = jax_model.apply(jax_params, jax_tokens) + + # Compare + assert_close( + jax_logits, torch_logits, rtol=1e-4, atol=1e-5, name='Full Model' + ) + logging.info('✓ Full Model test passed') + + +def test_prediction(): + """Test that autoregressive generation produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Autoregressive Prediction') + logging.info('=' * 70) + + batch_size, seq_len = 1, 10 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + k = 5 # Number of tokens to predict + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len + k, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len + k, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Predict k tokens + with torch.no_grad(): + _, torch_predictions = pytorch_model.predict(torch_tokens, k=k) + + _, jax_predictions = jax_model.apply( + jax_params, jax_tokens, k, method=jax_model.predict + ) + + # Compare predictions + torch_pred_np = torch_predictions.cpu().numpy() + jax_pred_np = np.array(jax_predictions) + + logging.info(f'\nPyTorch predictions: {torch_pred_np[0]}') + logging.info(f'JAX predictions: {jax_pred_np[0]}') + + # Check if predictions match exactly + if np.array_equal(torch_pred_np, jax_pred_np): + logging.info('✓ Predictions match exactly!') + else: + matching = np.sum(torch_pred_np == jax_pred_np) + total = torch_pred_np.size + logging.info( + f'⚠ Predictions differ: {matching}/{total} tokens match ({matching / total * 100:.1f}%)' + ) + logging.info( + ' (Note: Small numerical differences can lead to different argmax results)' + ) + + +def test_initialization_statistics(): + """Verify initialization follows expected distributions.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Statistics') + logging.info('=' * 70) + + # Initialize models + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=1024, + num_layers=12, + vocab_size=50000, + expanded_model_dim=2048, + dtype=jnp.float32, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) + ) + + pytorch_cfg = PyTorchModelConfig( + vocab_size=50000, + seq_len=1024, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + logging.info('Initialization Statistics Check:') + + # Check embedding + jax_embed = jax_params['params']['embed']['embedding'] + torch_embed = pytorch_model.embed_tokens.weight.detach().numpy() + + logging.info('\nToken Embedding (should be ~0.02 std):') + logging.info( + f' 
JAX: mean={jax_embed.mean():.6f}, std={jax_embed.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_embed.mean():.6f}, std={torch_embed.std():.6f}' + ) + + # Assert embedding std is close to 0.02 + assert abs(jax_embed.std() - 0.02) < 0.005, ( + f'JAX embedding std {jax_embed.std():.6f} not close to 0.02' + ) + assert abs(torch_embed.std() - 0.02) < 0.005, ( + f'PyTorch embedding std {torch_embed.std():.6f} not close to 0.02' + ) + assert abs(jax_embed.mean()) < 0.01, ( + f'JAX embedding mean {jax_embed.mean():.6f} not close to 0' + ) + assert abs(torch_embed.mean()) < 0.01, ( + f'PyTorch embedding mean {torch_embed.mean():.6f} not close to 0' + ) + + # Check first layer attention Q + jax_q = jax_params['params']['blocks_0']['CausalAttn_0']['query']['kernel'] + torch_q_weight = ( + pytorch_model.layers[0].attn.w_qkv.weight[:512].detach().numpy() + ) + + logging.info('\nAttention Q:') + logging.info(f' JAX: mean={jax_q.mean():.6f}, std={jax_q.std():.6f}') + logging.info( + f' PyTorch: mean={torch_q_weight.mean():.6f}, std={torch_q_weight.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_q.mean()) < 0.01, ( + f'JAX Q mean {jax_q.mean():.6f} not close to 0' + ) + assert abs(torch_q_weight.mean()) < 0.01, ( + f'PyTorch Q mean {torch_q_weight.mean():.6f} not close to 0' + ) + + # Check stds are similar + # Allow 20% difference due to random initialization + assert abs(jax_q.std() - torch_q_weight.std()) / torch_q_weight.std() < 0.2, ( + f'Q std differs too much: JAX {jax_q.std():.6f} vs PyTorch {torch_q_weight.std():.6f}' + ) + + # Check first layer attention output (should be scaled) + jax_attn_out = jax_params['params']['blocks_0']['CausalAttn_0'][ + 'attn_out_proj' + ]['kernel'] + torch_attn_out = pytorch_model.layers[0].attn.w_out.weight.detach().numpy() + + logging.info('\nAttention Output:') + logging.info( + f' JAX: mean={jax_attn_out.mean():.6f}, std={jax_attn_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_attn_out.mean():.6f}, std={torch_attn_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_attn_out.mean()) < 0.01, ( + f'JAX attn out mean {jax_attn_out.mean():.6f} not close to 0' + ) + assert abs(torch_attn_out.mean()) < 0.01, ( + f'PyTorch attn out mean {torch_attn_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_attn_out.std() - torch_attn_out.std()) / torch_attn_out.std() < 0.2 + ), ( + f'Attention output std differs too much: JAX {jax_attn_out.std():.6f} vs PyTorch {torch_attn_out.std():.6f}' + ) + + # Check MLP fc2 (should be scaled) + jax_mlp_out = jax_params['params']['blocks_0']['Mlp_0']['Dense_1']['kernel'] + torch_mlp_out = pytorch_model.layers[0].mlp.fc2.weight.detach().numpy() + + logging.info('\nMLP Output:') + logging.info( + f' JAX: mean={jax_mlp_out.mean():.6f}, std={jax_mlp_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_mlp_out.mean():.6f}, std={torch_mlp_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_mlp_out.mean()) < 0.01, ( + f'JAX MLP out mean {jax_mlp_out.mean():.6f} not close to 0' + ) + assert abs(torch_mlp_out.mean()) < 0.01, ( + f'PyTorch MLP out mean {torch_mlp_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_mlp_out.std() - torch_mlp_out.std()) / torch_mlp_out.std() < 0.2 + ), ( + f'MLP output std differs too much: JAX {jax_mlp_out.std():.6f} vs PyTorch {torch_mlp_out.std():.6f}' + ) + + logging.info('\n✓ Initialization statistics test passed') + + +def test_initialization_impact(): + 
"""Test that initialization produces similar initial losses.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Impact') + logging.info('=' * 70) + + # Create identical inputs + batch_size, seq_len = 4, 128 + vocab_size = 50000 + + np.random.seed(42) + tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) + + # Initialize both models with same seed + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=seq_len, + num_layers=12, + vocab_size=vocab_size, + expanded_model_dim=2048, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) + ) + + torch.manual_seed(42) + pytorch_cfg = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + # Forward pass + jax_logits = jax_model.apply(jax_params, jnp.array(tokens, dtype=jnp.int32)) + + with torch.no_grad(): + torch_logits = pytorch_model(torch.tensor(tokens, dtype=torch.long)) + + # Compute losses + targets = tokens[:, 1:] + jax_loss = -jax.nn.log_softmax(jax_logits[:, :-1]).mean() + torch_loss = F.cross_entropy( + torch_logits[:, :-1].reshape(-1, vocab_size), + torch.tensor(targets.reshape(-1), dtype=torch.long), + ) + + logging.info('\nInitial Loss Comparison:') + logging.info(f' JAX: {jax_loss:.4f}') + logging.info(f' PyTorch: {torch_loss.item():.4f}') + logging.info(f' Difference: {abs(jax_loss - torch_loss.item()):.6f}') + + # Check that losses are in reasonable range for random init + # With vocab_size=50000, random init should give loss around log(50000) ≈ 10.82 + expected_loss = np.log(vocab_size) + + assert 8.0 < jax_loss < 13.0, ( + f'JAX loss {jax_loss:.4f} outside expected range [8.0, 13.0]' + ) + assert 8.0 < torch_loss.item() < 13.0, ( + f'PyTorch loss {torch_loss.item():.4f} outside expected range [8.0, 13.0]' + ) + + # Both losses should be within 10% of log(vocab_size) + assert abs(jax_loss - expected_loss) / expected_loss < 0.1, ( + f'JAX loss {jax_loss:.4f} too far from expected {expected_loss:.4f}' + ) + assert abs(torch_loss.item() - expected_loss) / expected_loss < 0.1, ( + f'PyTorch loss {torch_loss.item():.4f} too far from expected {expected_loss:.4f}' + ) + + logging.info( + '\nNote: Losses are in expected range for random initialization.' 
+  )
+  logging.info(f' Expected ~log(vocab_size) = {expected_loss:.4f}')
+  logging.info('\n✓ Initialization impact test passed')
+
+
+# ============================================================================
+# Test Class
+# ============================================================================
+
+named_parameters = [
+  dict(testcase_name='rmsnorm', test_fn=test_rmsnorm),
+  dict(testcase_name='rope', test_fn=test_rope),
+  dict(testcase_name='mlp', test_fn=test_mlp),
+  dict(testcase_name='attention', test_fn=test_attention),
+  dict(testcase_name='block', test_fn=test_block),
+  dict(testcase_name='full_model', test_fn=test_full_model),
+  dict(testcase_name='prediction', test_fn=test_prediction),
+  dict(
+    testcase_name='initialization_statistics',
+    test_fn=test_initialization_statistics,
+  ),
+  dict(
+    testcase_name='initialization_impact', test_fn=test_initialization_impact
+  ),
+]
+
+
+class ModelMatchingTest(parameterized.TestCase):
+  """Tests for JAX vs PyTorch model matching."""
+
+  @parameterized.named_parameters(*named_parameters)
+  def test_model_matching(self, test_fn):
+    """Run individual model matching test."""
+    test_fn()
+
+
+if __name__ == '__main__':
+  absltest.main()