Skip to content

Commit

Permalink
[nnx] add tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 21, 2025
1 parent e4418e2 commit d47f9d2
Show file tree
Hide file tree
Showing 6 changed files with 628 additions and 4 deletions.
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from .summary import tabulate as tabulate
from . import traversals as traversals
2 changes: 1 addition & 1 deletion flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def check_consistent_aliasing(
for path, value in graph.iter_graph(node):
if graph.is_graph_node(value) or isinstance(value, graph.Variable):
if isinstance(value, Object):
value.check_valid_context(
value._check_valid_context(
lambda: f'Trying to extract graph node from different trace level, got {value!r}'
)
if isinstance(value, graph.Variable):
Expand Down
8 changes: 6 additions & 2 deletions flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ class Array(reprlib.Representable):
shape: tp.Tuple[int, ...]
dtype: tp.Any

@staticmethod
def from_array(array: jax.Array | np.ndarray) -> Array:
return Array(array.shape, array.dtype)

def __nnx_repr__(self):
yield reprlib.Object(type='Array', same_line=True)
yield reprlib.Attr('shape', self.shape)
Expand Down Expand Up @@ -163,12 +167,12 @@ def __setattr__(self, name: str, value: Any) -> None:
self._setattr(name, value)

def _setattr(self, name: str, value: tp.Any) -> None:
self.check_valid_context(
self._check_valid_context(
lambda: f"Cannot mutate '{type(self).__name__}' from different trace level"
)
object.__setattr__(self, name, value)

def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
if not self._object__state.trace_state.is_valid():
raise errors.TraceContextError(error_msg())

Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __post_init__(self):
raise TypeError(f'key must be a jax.Array, got {type(self.key)}')

def __call__(self) -> jax.Array:
self.check_valid_context(
self._check_valid_context(
lambda: 'Cannot call RngStream from a different trace level'
)
key = jax.random.fold_in(self.key.value, self.count.value)
Expand Down
Loading

0 comments on commit d47f9d2

Please sign in to comment.