Skip to content

Commit

Permalink
[nnx] add Dropout.rngs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Apr 2, 2024
1 parent 30fb7ff commit 1a1ca48
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
10 changes: 0 additions & 10 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,6 @@ def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
if isinstance(module, _HasSetup):
module.setup()

assert isinstance(module, Module)

for field in dataclasses.fields(module):
if not field.init:
continue
value = vars(module)[field.name]
# set Rngs instances to None
if isinstance(value, Rngs):
vars(module)[field.name] = None

return module


Expand Down
13 changes: 8 additions & 5 deletions flax/experimental/nnx/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Optional, Sequence
import dataclasses
from typing import Sequence

import jax
import jax.numpy as jnp
from jax import lax, random

from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
import dataclasses


@dataclasses.dataclass
Expand All @@ -38,15 +39,16 @@ class Dropout(Module):

rate: float
broadcast_dims: Sequence[int] = ()
deterministic: Optional[bool] = None
deterministic: bool | None = None
rng_collection: str = 'dropout'
rngs: rnglib.Rngs | None = None

def __call__(
self,
inputs,
*,
deterministic: Optional[bool] = None,
rngs: Optional[rnglib.Rngs] = None,
deterministic: bool | None = None,
rngs: rnglib.Rngs | None = None,
) -> jax.Array:
"""Applies a random dropout mask to the input.
Expand All @@ -59,6 +61,7 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
rngs = rngs or self.rngs
deterministic = first_from(
deterministic,
self.deterministic,
Expand Down
27 changes: 27 additions & 0 deletions flax/experimental/nnx/tests/nn/test_stochastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

import jax.numpy as jnp

from flax.experimental import nnx


class TestStochastic:
def test_dropout_internal_rngs(self):
n = 0
m = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))

@nnx.jit
def f(m, x):
nonlocal n
n += 1
return m(x)

x = jnp.ones((1, 10))
assert m.rngs is not None and m.rngs.dropout.count.value == 0

y = f(m, x)
assert n == 1
assert m.rngs.dropout.count.value == 1

y = f(m, x)
assert n == 1
assert m.rngs.dropout.count.value == 2
4 changes: 1 addition & 3 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ class Foo(nnx.Module):
assert state.d == nnx.Variable(4)
assert state.e == nnx.BatchStat(5)

def test_context_none_after_init(self):
def test_post_init(self):
@dataclasses.dataclass
class DFoo(nnx.Module):
din: int
Expand All @@ -566,7 +566,6 @@ def __call__(self, x):
m = DFoo(1, 1, rngs=nnx.Rngs(0))

assert hasattr(m, 'bar')
assert m.rngs is None

def test_setup_is_called(self):
@dataclasses.dataclass
Expand All @@ -584,7 +583,6 @@ def __call__(self, x):
m = DFoo(1, 1, rngs=nnx.Rngs(0))

assert hasattr(m, 'bar')
assert m.rngs is None


class TestModuleDef:
Expand Down

0 comments on commit 1a1ca48

Please sign in to comment.