pattern for module with no parameters or variables like positional encoding #2418
Hi, I haven't been able to come up with the proper pattern to create a module that has no variables or parameters. In the case of positional encoding, the only variable is a static array that can stay immutable (so it shouldn't be necessary to make it a variable?), and it's also not a parameter. So far, this code gives several errors:

class PositionalEncoder(nn.Module):
    d_model: int
    max_length: int
    pe: jnp.DeviceArray = field(init=False)

    @staticmethod
    def init_pe(d_model: int, max_length: int):
        positions = jnp.arange(max_length)[:, None]
        div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0)/d_model))
        temp = positions * div_term
        even_mask = positions % 2 == 0
        pe = jnp.where(even_mask, jnp.sin(temp), jnp.cos(temp))
        return pe

    def __post_init__(self):
        self.pe = self.init_pe(self.d_model, self.max_length)

    @nn.jit
    @nn.compact
    def __call__(self, x):
        return x + self.pe[:x.shape[0]]

So, I'm trying to use pe as a static, immutable array that gets reused. JIT-compiling it and telling the compiler that the array will stay the same should increase performance quite a bit. What would be the best pattern to do this and still be able to use it as a submodule of other, bigger Modules (like a transformer)? Thank you!
Replies: 1 comment 1 reply
It seems your array pe is independent of your input, so it is treated as a constant. The best place to store constants is in a separate variable collection. So like this:

from jax import random
import jax
import jax.numpy as jnp
from flax import linen as nn

class PositionalEncoder(nn.Module):
    d_model: int
    max_length: int

    @staticmethod
    def init_pe(d_model: int, max_length: int):
        positions = jnp.arange(max_length)[:, None]
        div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0)/d_model))
        temp = positions * div_term
        even_mask = positions % 2 == 0
        pe = jnp.where(even_mask, jnp.sin(temp), jnp.cos(temp))
        return pe

    @nn.compact
    def __call__(self, x):
        # Store the table in its own 'consts' collection instead of 'params'.
        pe = self.variable('consts', 'pe', PositionalEncoder.init_pe, self.d_model, self.max_length)
        return x + pe.value[:x.shape[0]]

m = PositionalEncoder(64, 12)
variables = m.init(random.PRNGKey(0), jnp.zeros((2, 32)))
# Inspect the shapes of the initialized variables.
jax.tree_util.tree_map(jnp.shape, variables)