
Half precision in flax models #2385

Answered by mar-muel
mar-muel asked this question in Q&A

Nevermind - I just found the solution: use param_dtype on the modules. So using

from typing import Sequence

import flax.linen as nn
import jax.numpy as jnp


class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      # dtype sets the computation dtype, param_dtype the stored parameter dtype.
      x = nn.relu(nn.Dense(feat, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16)(x))
    x = nn.Dense(self.features[-1], dtype=jnp.bfloat16, param_dtype=jnp.bfloat16)(x)
    return x
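
For anyone landing here later, a minimal sketch of what this looks like in use - the feature sizes and input shape below are made up purely for illustration. Initializing the module should yield parameters stored directly in bfloat16:

import jax
import jax.numpy as jnp

# Hypothetical sizes, just for illustration.
model = MLP(features=[16, 8, 1])
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 32)))

# Every kernel and bias should report bfloat16.
print(jax.tree_util.tree_map(lambda p: p.dtype, params))

With both arguments set, the parameters themselves live in bfloat16 rather than being stored in float32 and cast at call time, which also halves parameter memory.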

Can be closed.

Answer selected by marcvanzee

This discussion was converted from issue #2384 on August 08, 2022 09:54.