Skip to content

Commit

Permalink
[nnx] fix UpdateContextManager
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636934066
  • Loading branch information
Cristian Garcia authored and Flax Authors committed May 24, 2024
1 parent 9b90e98 commit b1cb952
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,21 +990,23 @@ def merge(
@dataclasses.dataclass
class UpdateContextManager:
tag: str
ctx: UpdateContext | None

def __enter__(self):
self.ctx = UpdateContext(self.tag, None, None)
GRAPH_CONTEXT.update_context_stacks[self.tag].append(self.ctx)
return self.ctx
ctx = UpdateContext(self.tag, None, None)
GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx)
return ctx

def __exit__(self, *args):
if self.ctx is None:
raise RuntimeError('ctx should not be None, this is a bug.')
stack = GRAPH_CONTEXT.update_context_stacks[self.tag]
if not stack:
raise RuntimeError(
f'No update context found for tag {self.tag!r}, this is a bug.'
)

GRAPH_CONTEXT.update_context_stacks[self.tag].pop()
self.ctx.refmap = None
self.ctx.idxmap = None
self.ctx = None
ctx = GRAPH_CONTEXT.update_context_stacks[self.tag].pop()
# clear references
ctx.refmap = None
ctx.idxmap = None

def __call__(self, f: F) -> F:
@functools.wraps(f)
Expand Down Expand Up @@ -1107,7 +1109,7 @@ def update_context(tag: str):
Args:
tag: A string tag to identify the context.
"""
return UpdateContextManager(tag, None)
return UpdateContextManager(tag)


def current_update_context(tag: str) -> UpdateContext:
Expand Down

0 comments on commit b1cb952

Please sign in to comment.