diff --git a/flax/__init__.py b/flax/__init__.py index ddadc2ad..ee1af501 100644 --- a/flax/__init__.py +++ b/flax/__init__.py @@ -24,6 +24,7 @@ from . import jax_utils from . import linen from . import serialization +from . import struct from . import traverse_util # DO NOT REMOVE - Marker for internal deprecated API.