diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 7eee470630..6bfd970472 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -35,6 +35,8 @@ sys.path.insert(0, os.path.abspath('..')) # Include local extension. sys.path.append(os.path.abspath('./_ext')) +# Set environment variable to indicate that we are building the docs. +os.environ['FLAX_DOC_BUILD'] = 'true' # patch sphinx # -- Project information ----------------------------------------------------- diff --git a/flax/nnx/object.py b/flax/nnx/object.py index b8f35ba752..68b718f7f7 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -16,6 +16,7 @@ import dataclasses import inspect +import os import threading import typing as tp from abc import ABCMeta @@ -157,7 +158,9 @@ def __init_subclass__(cls) -> None: init=cls._graph_node_init, # type: ignore ) - cls.__signature__ = inspect.signature(cls.__init__) + if 'FLAX_DOC_BUILD' in os.environ: + # set correct signature for sphinx + cls.__signature__ = inspect.signature(cls.__init__) if not tp.TYPE_CHECKING: