diff --git a/flax/nnx/object.py b/flax/nnx/object.py index b8f35ba752..0c8c60601f 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -16,6 +16,7 @@ import dataclasses import inspect +import sys import threading import typing as tp from abc import ABCMeta @@ -157,7 +158,9 @@ def __init_subclass__(cls) -> None: init=cls._graph_node_init, # type: ignore ) - cls.__signature__ = inspect.signature(cls.__init__) + if 'sphinx-build' in sys.argv[0]: + # set correct signature for sphinx + cls.__signature__ = inspect.signature(cls.__init__) if not tp.TYPE_CHECKING: