Skip to content

Commit

Permalink
Refactor FunctionWrapper (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
goodwanghan authored Mar 19, 2023
1 parent 797b34e commit e457e01
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 50 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release Notes

## 0.1.5

- Refactor `FunctionWrapper`, remove the Fugue contraint

## 0.1.3

- Added Fugue version constraint to avoid breaking changes
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_version() -> str:
author_email="[email protected]",
keywords="hyper parameter hyperparameter tuning tune tuner optimzation",
url="http://github.com/fugue-project/tune",
install_requires=["fugue>=0.7.0,<=0.8.1", "cloudpickle"],
install_requires=["fugue", "cloudpickle", "triad>=0.8.4"],
extras_require={
"hyperopt": ["hyperopt"],
"optuna": ["optuna"],
Expand Down
70 changes: 22 additions & 48 deletions tune/noniterative/convert.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import copy
import inspect
from typing import Any, Callable, Dict, Optional, Tuple, no_type_check

from fugue._utils.interfaceless import (
from fugue._utils.interfaceless import is_class_method
from triad import assert_or_throw
from triad.collections.function_wrapper import (
AnnotatedParam,
FunctionWrapper,
_FuncParam,
_OtherParam,
is_class_method,
function_wrapper,
)
from triad import assert_or_throw
from triad.utils.convert import get_caller_global_local_vars, to_function

from tune.concepts.flow import Trial, TrialReport
from tune.exceptions import TuneCompileError
from tune.noniterative.objective import NonIterativeObjectiveFunc
from tune.concepts.flow import Trial, TrialReport


def noniterative_objective(
Expand Down Expand Up @@ -90,66 +90,40 @@ def from_func(
return f


class _ReportParam(_FuncParam):
def __init__(self, param: Optional[inspect.Parameter], annotation: Any, code: str):
super().__init__(param, annotation, code)
@function_wrapper(None)
class _NonIterativeObjectiveWrapper(FunctionWrapper):
def __init__(self, func: Callable):
super().__init__(func, ".*", "^[r12]$")
param = self._params.get_value_by_index(0)
self._orig_input = isinstance(param, _TrialParam)
self._orig_output = isinstance(self._rt, _RawReportParam)


class _ReportParam(AnnotatedParam):
def to_report(self, v: Any, trial: Trial) -> TrialReport:
raise NotImplementedError # pragma: no cover


@_NonIterativeObjectiveWrapper.annotated_param(TrialReport, "r")
class _RawReportParam(_ReportParam):
def __init__(self, param: Optional[inspect.Parameter]):
super().__init__(param, "TrialReport", "r")

def to_report(self, v: Any, trial: Trial) -> TrialReport:
return v


@_NonIterativeObjectiveWrapper.annotated_param(float, "1")
class _MetricParam(_ReportParam):
def __init__(self, param: Optional[inspect.Parameter]):
super().__init__(param, "float", "1")

def to_report(self, v: Any, trial: Trial) -> TrialReport:
return TrialReport(trial, metric=float(v), params=trial.params, metadata={})


@_NonIterativeObjectiveWrapper.annotated_param(Tuple[float, Dict[str, Any]], "2")
class _MetricMetadataParam(_ReportParam):
def __init__(self, param: Optional[inspect.Parameter]):
super().__init__(param, "Tuple[float,Dict[str,Any]]", "2")

def to_report(self, v: Any, trial: Trial) -> TrialReport:
return TrialReport(
trial, metric=float(v[0]), params=trial.params, metadata=v[1]
)


class _TrialParam(_FuncParam):
def __init__(self, param: Optional[inspect.Parameter]):
super().__init__(param, "Trial", "t")


class _NonIterativeObjectiveWrapper(FunctionWrapper):
def __init__(self, func: Callable):
self._orig_input = False
self._orig_output = False
super().__init__(func, "^(t|([x1]+))$", "^[r12]$")

def _parse_param(
self,
annotation: Any,
param: Optional[inspect.Parameter],
none_as_other: bool = True,
) -> _FuncParam:
if annotation is float:
return _MetricParam(param)
elif annotation is Tuple[float, Dict[str, Any]]:
return _MetricMetadataParam(param)
elif annotation is TrialReport:
self._orig_output = True
return _RawReportParam(param)
elif annotation is Trial:
self._orig_input = True
return _TrialParam(param)
else:
return _OtherParam(param)
@_NonIterativeObjectiveWrapper.annotated_param(Trial, "t")
class _TrialParam(AnnotatedParam):
pass
2 changes: 1 addition & 1 deletion tune_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.3"
__version__ = "0.1.5"

0 comments on commit e457e01

Please sign in to comment.