
pattern for module with no parameters or variables like positional encoding #2418

Answered by marcvanzee
aldopareja asked this question in Q&A

It seems your array `pe` is independent of your input, so it is effectively a constant. The best place to store constants is in a separate variable collection (one other than `params`), which keeps them apart from the trainable parameters. Like this:

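```python
from jax import random
import jax
import jax.numpy as jnp
from flax import linen as nn

class PositionalEncoder(nn.Module):
  d_model: int
  max_length: int

  @staticmethod
  def init_pe(d_model: int, max_length: int):
    positions = jnp.arange(max_length)[:, None]  # (max_length, 1)
    # One frequency per sin/cos pair of feature dimensions.
    div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model))

    # Columns 2i and 2i+1 share frequency div_term[i], so repeat each
    # frequency twice and trim to d_model columns.
    temp = positions * jnp.repeat(div_term, 2)[:d_model]  # (max_length, d_model)
    # Alternate over feature dimensions (not positions):
    # even dims get sin, odd dims get cos.
    even_mask = jnp.arange(d_model) % 2 == 0

    pe = jnp.where(even_mask, jnp.sin(temp), jnp.cos(temp))
    return pe
```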


Answer selected by aldopareja