diff --git a/arrayfire/array_api/__init__.py b/arrayfire/array_api/__init__.py index 798a886..0a21bf0 100644 --- a/arrayfire/array_api/__init__.py +++ b/arrayfire/array_api/__init__.py @@ -4,9 +4,9 @@ __all__ = ["__array_api_version__"] -from ._constants import Device +from ._constants import Device, e, inf, nan, newaxis, pi -__all__ += ["Device"] +__all__ += ["Device", "e", "inf", "nan", "pi", "newaxis"] from ._creation_function import ( arange, diff --git a/arrayfire/array_api/_constants.py b/arrayfire/array_api/_constants.py index a578d13..ebbe05c 100644 --- a/arrayfire/array_api/_constants.py +++ b/arrayfire/array_api/_constants.py @@ -8,6 +8,7 @@ from __future__ import annotations +import math from dataclasses import dataclass import arrayfire as af @@ -105,3 +106,10 @@ def __post_init__(self) -> None: class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + + +e = math.e +inf = math.inf +nan = math.nan +pi = math.pi +newaxis = None