From 7bddd53bb5e9afe3ffa2f1d7dae3cf67e6a6f365 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 28 Oct 2023 09:31:17 +0000 Subject: [PATCH] expose Rng --- flax/experimental/nnx/__init__.py | 1 + flax/experimental/nnx/nnx/variables.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 0d8478483c..f28b43383b 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -92,6 +92,7 @@ from .nnx.variables import Empty as Empty from .nnx.variables import Intermediate as Intermediate from .nnx.variables import Param as Param +from .nnx.variables import Rng as Rng from .nnx.variables import Variable as Variable from .nnx.variables import VariableMetadata as VariableMetadata from .nnx.variables import with_metadata as with_metadata diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 20bbc26ff1..a9d8c998c2 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -391,6 +391,8 @@ class Intermediate(Variable[A]): class Rng(Variable[jax.Array]): + tag: str + def __init__( self, value: jax.Array,