Skip to content

Commit

Permalink
Actually use JAX to vectorize function
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2025
1 parent 1bfdbc0 commit 6d19b9d
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions src/_gettsim/policy_environment_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import inspect
from typing import TYPE_CHECKING

import numpy

from _gettsim.aggregation import (
all_by_p_id,
any_by_p_id,
Expand All @@ -25,6 +23,7 @@
from _gettsim.config import (
SUPPORTED_GROUPINGS,
TYPES_INPUT_VARIABLES,
USE_JAX,
)
from _gettsim.functions.derived_function import DerivedFunction
from _gettsim.functions.policy_function import PolicyFunction
Expand All @@ -35,6 +34,7 @@
remove_group_suffix,
)
from _gettsim.time_conversion import create_time_conversion_functions
from _gettsim.vectorization import make_vectorizable

Check warning on line 37 in src/_gettsim/policy_environment_postprocessor.py

View check run for this annotation

Codecov / codecov/patch

src/_gettsim/policy_environment_postprocessor.py#L37

Added line #L37 was not covered by tests

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -588,17 +588,8 @@ def _vectorize_func(func):
if isinstance(func, PolicyFunction):
return func

# What should work once that Jax backend is fully supported
signature = inspect.signature(func)
func_vec = numpy.vectorize(func)

@functools.wraps(func)
def wrapper_vectorize_func(*args, **kwargs):
return func_vec(*args, **kwargs)

wrapper_vectorize_func.__signature__ = signature

return wrapper_vectorize_func
backend = "jax" if USE_JAX else "numpy"
return make_vectorizable(func, backend=backend)


def _fail_if_targets_are_not_among_functions(functions, targets):
Expand Down

0 comments on commit 6d19b9d

Please sign in to comment.