diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e598979..ae4d4143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `ensure_same_defaults` decorator for setting one function's defaults as source of truth for other function +- Decorator versions of the various wait functions (`waiting`, `iter_waiting`, `waiting_progress` and `iter_waiting_progress`) ## [0.4.0] - 2019-11-14 diff --git a/easypy/sync.py b/easypy/sync.py index e7f4c610..ac616890 100644 --- a/easypy/sync.py +++ b/easypy/sync.py @@ -8,7 +8,7 @@ import threading import time import inspect -from functools import wraps +from functools import wraps, partial import re import logging import atexit @@ -1080,7 +1080,7 @@ def sleep(): time.sleep(s_timeout) -@wraps(iter_wait) +@wraps(iter_wait, assigned=()) def wait(*args, **kwargs): """ Wait until ``pred`` returns a useful value (see below), or until ``timeout`` passes. @@ -1118,11 +1118,6 @@ def wait(*args, **kwargs): return ret -def wait_progress(*args, **kwargs): - for _ in iter_wait_progress(*args, **kwargs): - pass - - def iter_wait_progress(state_getter, advance_timeout, total_timeout=float("inf"), state_threshold=0, sleep=0.5, throw=True, allow_regression=True, advancer_name=None, progressbar=True): @@ -1173,4 +1168,124 @@ def did_advance(): yield progress # indicate success +@wraps(iter_wait_progress, assigned=()) +def wait_progress(*args, **kwargs): + for _ in iter_wait_progress(*args, **kwargs): + pass + + +def __waiting_decorator(wait_function, func=None, **default_wait_args): + def inner(func): + orig_sig = inspect.signature(func) + wait_sig = inspect.signature(wait_function) + + def generate_parameters(): + yield from orig_sig.parameters.values() + for param in wait_sig.parameters.values(): + if param.name == 'pred': + continue # The predicate is `func` + if param.name == 'message': + continue # This is a function - it should raise a `PredicateNotSatisfied` instead of returning `False` + yield param.replace(kind=param.KEYWORD_ONLY) + + new_sig = orig_sig.replace(parameters=generate_parameters()) + + def invoke_wait_function(args, kwargs): + wait_args = dict(default_wait_args) + for name in list(kwargs.keys()): + if name in kwargs: + wait_args[name] = kwargs.pop(name) + + def pred(): + result = func(*args, **kwargs) + return (result,) # force the value to be truish + + return wait_function(pred=pred, message=False, **wait_args) + + if inspect.isgeneratorfunction(wait_function): + @wraps(func) + def wrapper(*args, **kwargs): + for yielded_value in invoke_wait_function(args, kwargs): + if isinstance(yielded_value, tuple): + yielded_value, = yielded_value + yield yielded_value + else: + @wraps(func) + def wrapper(*args, **kwargs): + result, = invoke_wait_function(args, kwargs) + return result + wrapper.__signature__ = new_sig + return wrapper + + if func: + return inner(func) + else: + return inner + + +def __make_waiting_decorator(wait_function): + def inner(decorator): + print('Decorating', decorator, 'as', wait_function) + wait_sig = inspect.signature(wait_function) + + def wrapper(func=None, **kwargs): + return __waiting_decorator(wait_function, func, **kwargs) + + wrapper.__name__ = decorator.__name__ + + def gen_new_parameters(): + yield inspect.signature(wrapper).parameters['func'] + for parameter in wait_sig.parameters.values(): + if parameter.kind == parameter.POSITIONAL_OR_KEYWORD: + yield parameter.replace(kind=parameter.KEYWORD_ONLY) + else: + yield parameter + wrapper.__signature__ = wait_sig.replace(parameters=gen_new_parameters()) + + wrapper.__doc__ = """ + :py:meth:``{wait_function}`` with the decorated function as its predicate + + All the :py:meth:``{wait_function}`` arguments can be passed as either + keyword arguments to the decorator or keyword arguments to the resulting + function. So:: + + @{decorator_name}(timeout=10) + def foo(a, b, c): + ... + + foo(1, 2, 3, sleep=4) + + Is the same as:: + + def foo(a, b, c): + ... + + {wait_function}(10, lambda: foo(1, 2, 3), sleep=4) + """.format(wait_function=wait_function.__name__, decorator_name=decorator.__name__) + + return wrapper + + return inner + + +@__make_waiting_decorator(wait) +def waiting(): + pass + + +@__make_waiting_decorator(iter_wait) +def iter_waiting(): + pass + + +@__make_waiting_decorator(wait_progress) +def waiting_progress(): + pass + + +@__make_waiting_decorator(iter_wait_progress) +def iter_waiting_progress(): + pass + + from .timing import Timer # noqa; avoid import cycle diff --git a/tests/test_sync.py b/tests/test_sync.py index 4a820701..dc80ac55 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -12,6 +12,7 @@ from easypy.units import Duration from easypy.sync import iter_wait, wait, iter_wait_progress, Timer, TimeoutException, PredicateNotSatisfied +from easypy.sync import waiting, iter_waiting from easypy.sync import SynchronizationCoordinator, SYNC from easypy.sync import shared_contextmanager from easypy.sync import TagAlongThread @@ -798,3 +799,59 @@ def pred(): durations = re.findall('Still waiting after (.*?): bad attempt', get_log()) rounded_durations = [round(Duration(d), 2) for d in durations] assert rounded_durations == [0.2, 0.4], 'expected logs at 200ms and 400ms, got %s' % (durations,) + + +def test_waiting(): + class TimedOut(PredicateNotSatisfied): + pass + + i = 0 + + @waiting + def do_wait(target): + nonlocal i + i += 1 + if i < target: + raise TimedOut(a=1, b=2) + return False + + with pytest.raises(TimedOut): + # due to the short timeout and long sleep, the pred would called exactly twice + do_wait(3, timeout=.1, sleep=1) + + assert i == 2 + assert do_wait(3, timeout=.1) is False + + with pytest.raises(TimedOut): + # due to the short timeout and long sleep, the pred would called exactly twice + do_wait(6, timeout=.1, sleep=1) + + assert i == 5 + do_wait(6, timeout=.1) + + +def test_iter_waiting(): + class TimedOut(PredicateNotSatisfied): + pass + + i = 0 + + @iter_waiting(timeout=0.1, sleep=1) + def do_iter_wait(): + nonlocal i + i += 1 + if i < 3: + raise TimedOut(a=1, b=2) + return i + + with pytest.raises(TimedOut): + for ret in do_iter_wait(): + assert isinstance(ret, Duration) + + for ret in do_iter_wait(): + pass + assert ret == 3 + + i = 0 + for ret in do_iter_wait(timeout=0.2, sleep=0.1): + pass