From e457e012f07f29041f32e92f35a5a199315ce4e2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 19 Mar 2023 16:37:59 -0700 Subject: [PATCH] Refactor FunctionWrapper (#66) --- RELEASE.md | 4 +++ setup.py | 2 +- tune/noniterative/convert.py | 70 ++++++++++++------------------------ tune_version/__init__.py | 2 +- 4 files changed, 28 insertions(+), 50 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 943e26f..b635b20 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,5 +1,9 @@ # Release Notes +## 0.1.5 + +- Refactor `FunctionWrapper`, remove the Fugue contraint + ## 0.1.3 - Added Fugue version constraint to avoid breaking changes diff --git a/setup.py b/setup.py index 1c277c0..87bbaac 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def get_version() -> str: author_email="goodwanghan@gmail.com", keywords="hyper parameter hyperparameter tuning tune tuner optimzation", url="http://github.com/fugue-project/tune", - install_requires=["fugue>=0.7.0,<=0.8.1", "cloudpickle"], + install_requires=["fugue", "cloudpickle", "triad>=0.8.4"], extras_require={ "hyperopt": ["hyperopt"], "optuna": ["optuna"], diff --git a/tune/noniterative/convert.py b/tune/noniterative/convert.py index f4700f3..5dce636 100644 --- a/tune/noniterative/convert.py +++ b/tune/noniterative/convert.py @@ -1,18 +1,18 @@ import copy -import inspect from typing import Any, Callable, Dict, Optional, Tuple, no_type_check -from fugue._utils.interfaceless import ( +from fugue._utils.interfaceless import is_class_method +from triad import assert_or_throw +from triad.collections.function_wrapper import ( + AnnotatedParam, FunctionWrapper, - _FuncParam, - _OtherParam, - is_class_method, + function_wrapper, ) -from triad import assert_or_throw from triad.utils.convert import get_caller_global_local_vars, to_function + +from tune.concepts.flow import Trial, TrialReport from tune.exceptions import TuneCompileError from tune.noniterative.objective import NonIterativeObjectiveFunc -from tune.concepts.flow import Trial, TrialReport def noniterative_objective( @@ -90,66 +90,40 @@ def from_func( return f -class _ReportParam(_FuncParam): - def __init__(self, param: Optional[inspect.Parameter], annotation: Any, code: str): - super().__init__(param, annotation, code) +@function_wrapper(None) +class _NonIterativeObjectiveWrapper(FunctionWrapper): + def __init__(self, func: Callable): + super().__init__(func, ".*", "^[r12]$") + param = self._params.get_value_by_index(0) + self._orig_input = isinstance(param, _TrialParam) + self._orig_output = isinstance(self._rt, _RawReportParam) + +class _ReportParam(AnnotatedParam): def to_report(self, v: Any, trial: Trial) -> TrialReport: raise NotImplementedError # pragma: no cover +@_NonIterativeObjectiveWrapper.annotated_param(TrialReport, "r") class _RawReportParam(_ReportParam): - def __init__(self, param: Optional[inspect.Parameter]): - super().__init__(param, "TrialReport", "r") - def to_report(self, v: Any, trial: Trial) -> TrialReport: return v +@_NonIterativeObjectiveWrapper.annotated_param(float, "1") class _MetricParam(_ReportParam): - def __init__(self, param: Optional[inspect.Parameter]): - super().__init__(param, "float", "1") - def to_report(self, v: Any, trial: Trial) -> TrialReport: return TrialReport(trial, metric=float(v), params=trial.params, metadata={}) +@_NonIterativeObjectiveWrapper.annotated_param(Tuple[float, Dict[str, Any]], "2") class _MetricMetadataParam(_ReportParam): - def __init__(self, param: Optional[inspect.Parameter]): - super().__init__(param, "Tuple[float,Dict[str,Any]]", "2") - def to_report(self, v: Any, trial: Trial) -> TrialReport: return TrialReport( trial, metric=float(v[0]), params=trial.params, metadata=v[1] ) -class _TrialParam(_FuncParam): - def __init__(self, param: Optional[inspect.Parameter]): - super().__init__(param, "Trial", "t") - - -class _NonIterativeObjectiveWrapper(FunctionWrapper): - def __init__(self, func: Callable): - self._orig_input = False - self._orig_output = False - super().__init__(func, "^(t|([x1]+))$", "^[r12]$") - - def _parse_param( - self, - annotation: Any, - param: Optional[inspect.Parameter], - none_as_other: bool = True, - ) -> _FuncParam: - if annotation is float: - return _MetricParam(param) - elif annotation is Tuple[float, Dict[str, Any]]: - return _MetricMetadataParam(param) - elif annotation is TrialReport: - self._orig_output = True - return _RawReportParam(param) - elif annotation is Trial: - self._orig_input = True - return _TrialParam(param) - else: - return _OtherParam(param) +@_NonIterativeObjectiveWrapper.annotated_param(Trial, "t") +class _TrialParam(AnnotatedParam): + pass diff --git a/tune_version/__init__.py b/tune_version/__init__.py index ae73625..1276d02 100644 --- a/tune_version/__init__.py +++ b/tune_version/__init__.py @@ -1 +1 @@ -__version__ = "0.1.3" +__version__ = "0.1.5"