diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 7eee4706..6bfd9704 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -35,6 +35,8 @@ sys.path.insert(0, os.path.abspath('..')) # Include local extension. sys.path.append(os.path.abspath('./_ext')) +# Set environment variable to indicate that we are building the docs. +os.environ['FLAX_DOC_BUILD'] = 'true' # patch sphinx # -- Project information ----------------------------------------------------- diff --git a/flax/nnx/object.py b/flax/nnx/object.py index b8f35ba7..e88786b7 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -16,6 +16,7 @@ import dataclasses import inspect +import os import threading import typing as tp from abc import ABCMeta @@ -38,6 +39,7 @@ G = tp.TypeVar('G', bound='Object') +BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ def _collect_stats( node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]] @@ -157,7 +159,9 @@ def __init_subclass__(cls) -> None: init=cls._graph_node_init, # type: ignore ) - cls.__signature__ = inspect.signature(cls.__init__) + if BUILDING_DOCS: + # set correct signature for sphinx + cls.__signature__ = inspect.signature(cls.__init__) if not tp.TYPE_CHECKING: