From b2d23e4885fbab91cb8856185cc7073f11ce0ed0 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Wed, 30 Oct 2024 01:30:25 -0700 Subject: [PATCH] Remove non-lazy RNG compat mode and flag from flax. PiperOrigin-RevId: 691326667 --- CHANGELOG.md | 2 +- flax/configurations.py | 3 --- flax/core/scope.py | 22 ---------------------- tests/run_all_tests.sh | 1 - 4 files changed, 1 insertion(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index edeee4155f..7f6aeccf92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ vNext - - - -- +- removed FLAX_LAZY_RNG flag support for old non-lazy PRNG derivation mode - - - diff --git a/flax/configurations.py b/flax/configurations.py index 4f61170f16..ba19a572fc 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -162,9 +162,6 @@ def temp_flip_flag(var_name: str, var_value: bool): # Flax Global Configuration Variables: -# Whether to use the lazy rng implementation. -flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True) - flax_filter_frames = bool_flag( name='flax_filter_frames', default=True, diff --git a/flax/core/scope.py b/flax/core/scope.py index 836dbdef36..b46affcbf3 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -38,7 +38,6 @@ from jax import random, tree_util from flax import config as config -from flax import configurations as legacy_config # only for flax_lazy_rng from flax import errors, struct, traceback_util from flax.ids import uuid from flax.typing import ( @@ -98,11 +97,6 @@ def as_jax_rng(self) -> PRNGKey: def create( rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable ) -> 'LazyRng': - if not legacy_config.flax_lazy_rng: - if isinstance(rng, LazyRng): - assert not rng.suffix - rng = rng.rng - return LazyRng(_legacy_rng_fold_in(rng, suffix), ()) if isinstance(rng, LazyRng): return LazyRng(rng.rng, rng.suffix + suffix) else: @@ -113,22 +107,6 @@ def fold(self): return LazyRng(key, ()) -def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: - """Legacy RNG folding.""" - for x in data: - if isinstance(x, str): - m = hashlib.sha1() - m.update(x.encode('utf-8')) - d = m.digest() - hash_int = int.from_bytes(d[:4], byteorder='big') - rng = random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore - elif isinstance(x, int): - rng = random.fold_in(rng, x) - else: - raise ValueError(f'Expected int or string, got: {x}') - return rng - - def _fold_in_static( rng: PRNGKey, data: typing.Collection[PRNGFoldable] ) -> PRNGKey: diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index e2ded604d5..920d71017b 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -108,7 +108,6 @@ assert_error="flax is not running on editable mode." # env vars must be set after doctest export JAX_NUMPY_RANK_PROMOTION=raise export FLAX_PROFILE=1 -export FLAX_LAZY_RNG=1 if $RUN_PYTEST; then echo "=== RUNNING PYTESTS ==="