Skip to content

Commit

Permalink
Merge pull request #3302 from chiamp:lm1b
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577295967
  • Loading branch information
Flax Authors committed Oct 27, 2023
2 parents 416b3e2 + 558be01 commit 8db32ae
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 85 deletions.
47 changes: 42 additions & 5 deletions examples/lm1b/configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def get_config():
config.max_corpus_chars = 10**7

# Name of TFDS translation dataset to use.
config.dataset_name = "lm1b"
config.dataset_name = 'lm1b'

# Optional name of TFDS translation dataset to use for evaluation.
config.eval_dataset_name = "lm1b"
config.eval_split = "test"
config.eval_dataset_name = 'lm1b'
config.eval_split = 'test'

# Per device batch size for training.
config.per_device_batch_size = 32
Expand Down Expand Up @@ -114,7 +114,44 @@ def get_config():
# Integer for PRNG random seed.
config.seed = 0

# Prompt for language model sampling.
config.prompts = "I love to "
# Prompt for language model sampling,
# taken from MaxText (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml).
config.prompts = 'I love to '

# Parallelism
config.mesh_axes = ['data', 'fsdp', 'tensor']
config.logical_axis_rules = [
['activation_batch', ['data', 'fsdp']],
['activation_length', ['data', 'fsdp']],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['activation_vocab', 'tensor'],
['mlp', 'tensor'],
['vocab', 'tensor'],
['embed', 'fsdp'],
['heads', 'tensor'],
]
config.full_sharding = ['data', 'fsdp', 'tensor']
config.data_sharding = ['data']

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
# ICI (Inter-Chip Interconnection): A high-speed connection between
# sets of TPU chips, which form the TPU network.
# DCN (Data Center Network): A connection between the TPU networks;
# not as fast as ICI.
# ICI has around 100x the bandwidth of DCN, but it is not a general
# purpose connection, which is why DCN is necessary for scaling to
# extremely large ML models.
config.dcn_data_parallelism = -1 # recommended DCN axis to be auto-sharded
config.dcn_fsdp_parallelism = 1
config.dcn_tensor_parallelism = 1
config.ici_data_parallelism = 1
config.ici_fsdp_parallelism = -1 # recommended ICI axis to be auto-sharded
config.ici_tensor_parallelism = 1

return config
59 changes: 47 additions & 12 deletions examples/lm1b/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ def __call__(self, inputs):
x = nn.Dense(
config.mlp_dim,
dtype=config.dtype,
kernel_init=config.kernel_init,
bias_init=config.bias_init,
kernel_init=nn.with_logical_partitioning(
config.kernel_init, ('embed', 'mlp')
),
bias_init=nn.with_logical_partitioning(config.bias_init, ('mlp',)),
)(inputs)
x = nn.relu(x)
x = nn.Dropout(rate=config.dropout_rate)(
Expand All @@ -196,8 +198,10 @@ def __call__(self, inputs):
output = nn.Dense(
actual_out_dim,
dtype=config.dtype,
kernel_init=config.kernel_init,
bias_init=config.bias_init,
kernel_init=nn.with_logical_partitioning(
config.kernel_init, ('mlp', 'embed')
),
bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)),
)(x)
output = nn.Dropout(rate=config.dropout_rate)(
output, deterministic=config.deterministic
Expand Down Expand Up @@ -230,13 +234,23 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):

# Decoder block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=config.dtype)(inputs)
x = nn.LayerNorm(
dtype=config.dtype,
bias_init=nn.with_logical_partitioning(
nn.initializers.zeros, ('embed',)
),
scale_init=nn.with_logical_partitioning(
nn.initializers.ones, ('embed',)
),
)(inputs)
x = nn.SelfAttention(
num_heads=config.num_heads,
dtype=config.dtype,
qkv_features=config.qkv_dim,
kernel_init=config.kernel_init,
bias_init=config.bias_init,
kernel_init=nn.with_logical_partitioning(
config.kernel_init, ('embed', 'kv')
),
bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)),
use_bias=False,
broadcast_dropout=False,
dropout_rate=config.attention_dropout_rate,
Expand All @@ -249,7 +263,15 @@ def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
x = x + inputs

# MLP block.
z = nn.LayerNorm(dtype=config.dtype)(x)
z = nn.LayerNorm(
dtype=config.dtype,
bias_init=nn.with_logical_partitioning(
nn.initializers.zeros, ('embed',)
),
scale_init=nn.with_logical_partitioning(
nn.initializers.ones, ('embed',)
),
)(x)
z = MlpBlock(config=config)(z)

return x + z
Expand Down Expand Up @@ -296,7 +318,9 @@ def __call__(
output_embed = nn.Embed(
num_embeddings=config.output_vocab_size,
features=config.emb_dim,
embedding_init=nn.initializers.normal(stddev=1.0),
embedding_init=nn.with_logical_partitioning(
nn.initializers.normal(stddev=1.0), ('vocab', 'embed')
),
)
else:
output_embed = self.shared_embedding
Expand All @@ -319,7 +343,16 @@ def __call__(
y = EncoderDecoder1DBlock(
config=config, name=f'encoderdecoderblock_{lyr}'
)(y, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask)
y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y)
y = nn.LayerNorm(
dtype=config.dtype,
name='encoderdecoder_norm',
bias_init=nn.with_logical_partitioning(
nn.initializers.zeros, ('embed',)
),
scale_init=nn.with_logical_partitioning(
nn.initializers.ones, ('embed',)
),
)(y)

# Decoded Logits
if config.logits_via_embedding:
Expand All @@ -331,8 +364,10 @@ def __call__(
logits = nn.Dense(
config.output_vocab_size,
dtype=config.dtype,
kernel_init=config.kernel_init,
bias_init=config.bias_init,
kernel_init=nn.with_logical_partitioning(
config.kernel_init, ('embed', 'vocab')
),
bias_init=nn.with_logical_partitioning(config.bias_init, ('vocab',)),
name='logitdense',
)(y)
return logits
Expand Down
24 changes: 12 additions & 12 deletions examples/lm1b/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
absl-py==1.0.0
clu==0.0.6
flax==0.4.1
jax==0.3.4
absl-py==1.4.0
clu==0.0.9
flax==0.6.11
jax==0.4.13
--find-links https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.3.2+cuda11.cudnn82 # Make sure CUDA version matches the base image.
ml-collections==0.1.0
numpy==1.22.0
optax==0.1.0
sentencepiece==0.1.96
tensorflow==2.11.1
tensorflow-datasets==4.4.0
tensorflow-text==2.8.1
jaxlib==0.4.13+cuda11.cudnn82 # Make sure CUDA version matches the base image.
ml-collections==0.1.1
numpy==1.24.3
optax==0.1.5
sentencepiece==0.1.99
tensorflow==2.13.0
tensorflow-datasets==4.9.2
tensorflow-text==2.13.0
2 changes: 1 addition & 1 deletion examples/lm1b/temperature_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestTemperatureSampler(absltest.TestCase):
def test_temperature_sampler(self):
tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32)
cache = None
key = jax.random.key(0)
key = jax.random.PRNGKey(0)

def tokens_to_logits(tokens, cache):
jax.debug.print('tokens: {}', tokens)
Expand Down
Loading

0 comments on commit 8db32ae

Please sign in to comment.