diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 356eddbd..2ecaa045 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -88,6 +88,7 @@ reduce = toolz.curry(toolz.reduce) reduceby = toolz.curry(toolz.reduceby) remove = toolz.curry(toolz.remove) +reorder_args = toolz.curry(toolz.reorder_args) sliding_window = toolz.curry(toolz.sliding_window) sorted = toolz.curry(toolz.sorted) tail = toolz.curry(toolz.tail) diff --git a/toolz/functoolz.py b/toolz/functoolz.py index 2c75d3a4..09fec19e 100644 --- a/toolz/functoolz.py +++ b/toolz/functoolz.py @@ -1,7 +1,7 @@ -from functools import reduce, partial +from functools import reduce, partial, wraps import inspect import sys -from operator import attrgetter, not_ +from operator import itemgetter, attrgetter, not_ from importlib import import_module from types import MethodType @@ -12,7 +12,7 @@ __all__ = ('identity', 'apply', 'thread_first', 'thread_last', 'memoize', 'compose', 'compose_left', 'pipe', 'complement', 'juxt', 'do', - 'curry', 'flip', 'excepts') + 'curry', 'flip', 'excepts', 'reorder_args') PYPY = hasattr(sys, 'pypy_version_info') @@ -731,6 +731,42 @@ def flip(func, a, b): return func(b, a) +def reorder_args(func, new_args): + """ Returns a new function with a desired argument order. + + >>> def op(a, b, c): + ... return a // (b - c) + ... + >>> new_op = reorder_args(op, ('c', 'a', 'b')) + >>> new_op(1, 2, 3) == op(2, 3, 1) + True + """ + func_sig = inspect.signature(func) + arg_map = [] + parameters = [None] * len(func_sig.parameters) + for i, arg in enumerate(func_sig.parameters.values()): + if (arg.kind in {inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.POSITIONAL_OR_KEYWORD} + and arg.default is inspect._empty): + try: + new_ind = new_args.index(arg.name) + arg_map.append(new_ind) + parameters[new_ind] = arg + except ValueError: + raise ValueError(f"Unable to find positional argument `{arg.name}` in signature for `{func.__name__}`") + else: + parameters[i] = arg + + _mapper = itemgetter(*arg_map) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*_mapper(args), **kwargs) + + wrapper.__signature__ = func_sig.replace(parameters=parameters) + return wrapper + + def return_none(exc): """ Returns None. """ diff --git a/toolz/tests/test_functoolz.py b/toolz/tests/test_functoolz.py index 555cf48d..ce9e50b6 100644 --- a/toolz/tests/test_functoolz.py +++ b/toolz/tests/test_functoolz.py @@ -2,8 +2,8 @@ import toolz from toolz.functoolz import (thread_first, thread_last, memoize, curry, compose, compose_left, pipe, complement, do, juxt, - flip, excepts, apply) -from operator import add, mul, itemgetter + flip, excepts, apply, reorder_args) +from operator import add, mul, getitem, itemgetter from toolz.utils import raises from functools import partial @@ -794,3 +794,17 @@ def raise_(a): excepting = excepts(object(), object(), object()) assert excepting.__name__ == 'excepting' assert excepting.__doc__ == excepts.__doc__ + + +def test_reorder_args(): + def op(a, b, c, d=1): + return a // (b - c) + d + + new_op = reorder_args(op, ('c', 'a', 'b')) + assert new_op(1, 2, 3, d=1) == op(2, 3, 1, d=1) + + # test builtin functions (ie C functions) + getflip = reorder_args(getitem, ('b', 'a')) + get1 = curry(getflip, 1) + assert get1([1, 2, 3, 1, 1]) == 2 + \ No newline at end of file