Skip to content

Commit

Permalink
Fix test_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2025
1 parent 33ef921 commit 1bfdbc0
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions src/_gettsim_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,22 @@
import importlib

import pytest

import _gettsim


def test_default_backend():
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "numpy"
def test_conftest_set_array_backend_updates_use_jax(request):
expected = request.config.option.USE_JAX
from _gettsim.config import USE_JAX

assert expected == USE_JAX

def test_set_backend():
is_jax_installed = importlib.util.find_spec("jax") is not None

# expect default backend
def test_conftest_set_array_backend_updates_backend(request):
use_jax = request.config.option.USE_JAX
expected = "jax.numpy" if use_jax else "numpy"
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "numpy"

if is_jax_installed:
# set jax backend
_gettsim.config.set_array_backend("jax")
from _gettsim.config import numpy_or_jax

assert numpy_or_jax.__name__ == "jax.numpy"

from _gettsim.config import USE_JAX

assert USE_JAX
else:
with pytest.raises(AssertionError):
_gettsim.config.set_array_backend("jax")
got = numpy_or_jax.__name__
assert expected == got


@pytest.mark.parametrize("backend", ["dask", "jax.numpy"])
Expand Down

0 comments on commit 1bfdbc0

Please sign in to comment.