Skip to content

Commit

Permalink
Remove non-lazy RNG compat mode and flag from flax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691326667
  • Loading branch information
levskaya authored and Flax Authors committed Oct 30, 2024
1 parent bcc12b5 commit b2d23e4
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ vNext
-
-
-
-
- removed FLAX_LAZY_RNG flag support for old non-lazy PRNG derivation mode
-
-
-
Expand Down
3 changes: 0 additions & 3 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def temp_flip_flag(var_name: str, var_value: bool):

# Flax Global Configuration Variables:

# Whether to use the lazy rng implementation.
flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)

flax_filter_frames = bool_flag(
name='flax_filter_frames',
default=True,
Expand Down
22 changes: 0 additions & 22 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from jax import random, tree_util

from flax import config as config
from flax import configurations as legacy_config # only for flax_lazy_rng
from flax import errors, struct, traceback_util
from flax.ids import uuid
from flax.typing import (
Expand Down Expand Up @@ -98,11 +97,6 @@ def as_jax_rng(self) -> PRNGKey:
def create(
rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable
) -> 'LazyRng':
if not legacy_config.flax_lazy_rng:
if isinstance(rng, LazyRng):
assert not rng.suffix
rng = rng.rng
return LazyRng(_legacy_rng_fold_in(rng, suffix), ())
if isinstance(rng, LazyRng):
return LazyRng(rng.rng, rng.suffix + suffix)
else:
Expand All @@ -113,22 +107,6 @@ def fold(self):
return LazyRng(key, ())


def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
"""Legacy RNG folding."""
for x in data:
if isinstance(x, str):
m = hashlib.sha1()
m.update(x.encode('utf-8'))
d = m.digest()
hash_int = int.from_bytes(d[:4], byteorder='big')
rng = random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore
elif isinstance(x, int):
rng = random.fold_in(rng, x)
else:
raise ValueError(f'Expected int or string, got: {x}')
return rng


def _fold_in_static(
rng: PRNGKey, data: typing.Collection[PRNGFoldable]
) -> PRNGKey:
Expand Down
1 change: 0 additions & 1 deletion tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ assert_error="flax is not running on editable mode."
# env vars must be set after doctest
export JAX_NUMPY_RANK_PROMOTION=raise
export FLAX_PROFILE=1
export FLAX_LAZY_RNG=1

if $RUN_PYTEST; then
echo "=== RUNNING PYTESTS ==="
Expand Down

0 comments on commit b2d23e4

Please sign in to comment.