Hello team, I'm trying to run experiments in half precision (bfloat16) on TPU v3-8 VMs. I'm seeing the following behaviour and wanted to check with you whether this is expected:

```python
from typing import Sequence
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat, dtype=jnp.bfloat16)(x))
        x = nn.Dense(self.features[-1], dtype=jnp.bfloat16)(x)
        return x


model = MLP([12, 8, 4])
batch = jnp.ones((32, 10), dtype=jnp.bfloat16)
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

print('Input dtype: ', batch.dtype)
print('Output dtype: ', output.dtype)
print('Model dtype: ', jax.tree_util.tree_map(lambda s: s.dtype, variables))
```

Output:
I understand that it is somewhat uncommon to keep the model weights in half precision (especially for batch norm layers). But how would one force the model weights to be bfloat16 (say, in case the model is too large to fit on a single device)? I guess one could explicitly cast all variables after init, as sketched below.
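For reference, a minimal sketch of that explicit cast, assuming the `variables` pytree from the snippet above; it just maps `astype` over every leaf and is not specific to any Flax helper:

```python
# Cast every leaf of the variables pytree to bfloat16 after init.
variables_bf16 = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.bfloat16), variables)
```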
Casting the params this way results (as expected) in all of the variables being bfloat16. Thanks for your help!

Versions:
Python version 3.8.13
Replies: 1 comment
Nevermind - I just found the solution: use `param_dtype` on the modules! So using `param_dtype=jnp.bfloat16` on the Dense layers keeps the weights in bfloat16. Can be closed.
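For anyone finding this later, a minimal sketch of what that looks like on the MLP above (`param_dtype` is the `flax.linen` argument mentioned here; it controls the dtype the parameters are created in, while `dtype` controls the computation dtype and `param_dtype` defaults to float32):

```python
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            # param_dtype=bfloat16 creates the kernel/bias in bfloat16
            # instead of the default float32.
            x = nn.relu(nn.Dense(feat, dtype=jnp.bfloat16,
                                 param_dtype=jnp.bfloat16)(x))
        x = nn.Dense(self.features[-1], dtype=jnp.bfloat16,
                     param_dtype=jnp.bfloat16)(x)
        return x


model = MLP([12, 8, 4])
variables = model.init(jax.random.PRNGKey(0),
                       jnp.ones((32, 10), dtype=jnp.bfloat16))
print(jax.tree_util.tree_map(lambda s: s.dtype, variables))  # all bfloat16
```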