Skip to content

Commit

Permalink
Merge pull request #3815 from google:nnx-dropout-optional-state
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623149816
  • Loading branch information
Flax Authors committed Apr 9, 2024
2 parents 91cfc2a + 1a1ca48 commit 2718455
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 18 deletions.
10 changes: 0 additions & 10 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,6 @@ def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
if isinstance(module, _HasSetup):
module.setup()

assert isinstance(module, Module)

for field in dataclasses.fields(module):
if not field.init:
continue
value = vars(module)[field.name]
# set Rngs instances to None
if isinstance(value, Rngs):
vars(module)[field.name] = None

return module


Expand Down
27 changes: 22 additions & 5 deletions flax/experimental/nnx/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from typing import Sequence

import jax
import jax.numpy as jnp
from jax import lax, random

from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
import dataclasses


@dataclasses.dataclass
Expand All @@ -38,15 +53,16 @@ class Dropout(Module):

rate: float
broadcast_dims: Sequence[int] = ()
deterministic: Optional[bool] = None
deterministic: bool | None = None
rng_collection: str = 'dropout'
rngs: rnglib.Rngs | None = None

def __call__(
self,
inputs,
*,
deterministic: Optional[bool] = None,
rngs: Optional[rnglib.Rngs] = None,
deterministic: bool | None = None,
rngs: rnglib.Rngs | None = None,
) -> jax.Array:
"""Applies a random dropout mask to the input.
Expand All @@ -59,6 +75,7 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
rngs = rngs or self.rngs
deterministic = first_from(
deterministic,
self.deterministic,
Expand Down
41 changes: 41 additions & 0 deletions flax/experimental/nnx/tests/nn/test_stochastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import jax.numpy as jnp

from flax.experimental import nnx


class TestStochastic:
def test_dropout_internal_rngs(self):
n = 0
m = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))

@nnx.jit
def f(m, x):
nonlocal n
n += 1
return m(x)

x = jnp.ones((1, 10))
assert m.rngs is not None and m.rngs.dropout.count.value == 0

y = f(m, x)
assert n == 1
assert m.rngs.dropout.count.value == 1

y = f(m, x)
assert n == 1
assert m.rngs.dropout.count.value == 2
4 changes: 1 addition & 3 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ class Foo(nnx.Module):
assert state.d == nnx.Variable(4)
assert state.e == nnx.BatchStat(5)

def test_context_none_after_init(self):
def test_post_init(self):
@dataclasses.dataclass
class DFoo(nnx.Module):
din: int
Expand All @@ -566,7 +566,6 @@ def __call__(self, x):
m = DFoo(1, 1, rngs=nnx.Rngs(0))

assert hasattr(m, 'bar')
assert m.rngs is None

def test_setup_is_called(self):
@dataclasses.dataclass
Expand All @@ -584,7 +583,6 @@ def __call__(self, x):
m = DFoo(1, 1, rngs=nnx.Rngs(0))

assert hasattr(m, 'bar')
assert m.rngs is None


class TestModuleDef:
Expand Down

0 comments on commit 2718455

Please sign in to comment.