Open
Labels
enhancement (New feature or request)
Description
Please describe the purpose of the feature. Is it related to a problem?
The jaxtyping library is gaining traction in the JAX ecosystem. It supports dtype and shape annotations for arrays/tensors as follows:

    from jaxtyping import Float, Array

    def __call__(self, x: Float[Array, " batch channel height width"]) -> Float[Array, " batch features"]: ...

It even supports optional runtime type checking. These annotations are super helpful for glancing at an argument and easily reading off its dimensions.
Describe the solution you'd like
I'd love to see incremental adoption of jaxtyping annotations throughout the Stoix codebase. I think this would be good for readability.
Describe alternatives you've considered
As far as I'm aware, there aren't any serious contenders to jaxtyping.
How do we know when implementation of this feature is complete?
Checklist:
- All Array arguments have annotated dtypes and shapes.