Skip to content

Commit

Permalink
[nnx] fix fiddle
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae authored and RaghuSpaceRajan committed Jan 24, 2025
1 parent ac91584 commit 99d7229
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs_nnx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
sys.path.insert(0, os.path.abspath('..'))
# Include local extension.
sys.path.append(os.path.abspath('./_ext'))
# Set environment variable to indicate that we are building the docs.
os.environ['FLAX_DOC_BUILD'] = 'true'

# patch sphinx
# -- Project information -----------------------------------------------------
Expand Down
6 changes: 5 additions & 1 deletion flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dataclasses
import inspect
import os
import threading
import typing as tp
from abc import ABCMeta
Expand All @@ -38,6 +39,7 @@

G = tp.TypeVar('G', bound='Object')

BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ

def _collect_stats(
node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]]
Expand Down Expand Up @@ -157,7 +159,9 @@ def __init_subclass__(cls) -> None:
init=cls._graph_node_init, # type: ignore
)

cls.__signature__ = inspect.signature(cls.__init__)
if BUILDING_DOCS:
# set correct signature for sphinx
cls.__signature__ = inspect.signature(cls.__init__)

if not tp.TYPE_CHECKING:

Expand Down

0 comments on commit 99d7229

Please sign in to comment.