Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit d8b183b

Browse files
committed
Add the mirror_defaults decorator
1 parent a5714b4 commit d8b183b

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
### Added
10+
- `mirror_defaults` decorator for mirroring the default arguments of another
11+
function.
12+
913
## [0.3.0] - 2019-06-10
1014

1115
### Added

easypy/decorations.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from functools import wraps, partial, update_wrapper
66
from operator import attrgetter
77
from abc import ABCMeta, abstractmethod
8+
import inspect
9+
10+
from .tokens import AUTO
811

912

1013
def parametrizeable_decorator(deco):
@@ -138,3 +141,60 @@ def foo(self):
138141
def wrapper(func):
139142
return LazyDecoratorDescriptor(decorator_factory, func, cached)
140143
return wrapper
144+
145+
146+
def mirror_defaults(mirrored):
147+
"""
148+
Copy the default values of arguments from another function.
149+
150+
Set an argument's default to ``AUTO`` to copy the default value from the
151+
mirrored function.
152+
153+
>>> from easypy.decorations import mirror_defaults
154+
155+
>>> def foo(a=1, b=2, c=3):
156+
... print(a, b, c)
157+
158+
>>> @mirror_defaults(foo)
159+
... def bar(a=AUTO, b=4, c=AUTO):
160+
... foo(a, b, c)
161+
162+
>>> bar()
163+
1 4 3
164+
"""
165+
defaults = {
166+
p.name: p.default
167+
for p in inspect.signature(mirrored).parameters.values()
168+
if p.default is not inspect._empty}
169+
170+
def new_params_generator(params, defaults_to_override):
171+
for param in params:
172+
if param.default is AUTO:
173+
try:
174+
default_value = defaults[param.name]
175+
except KeyError:
176+
raise TypeError('%s has no default value for %s' % (mirrored.__name__, param.name))
177+
defaults_to_override.add(param.name)
178+
yield param.replace(default=default_value)
179+
else:
180+
yield param
181+
182+
def outer(func):
183+
orig_signature = inspect.signature(func)
184+
defaults_to_override = set()
185+
new_parameters = new_params_generator(orig_signature.parameters.values(), defaults_to_override)
186+
new_signature = orig_signature.replace(parameters=new_parameters)
187+
188+
@wraps(func)
189+
def inner(*args, **kwargs):
190+
binding = new_signature.bind(*args, **kwargs)
191+
192+
# NOTE: `apply_defaults` was added in Python 3.5, so we cannot use it
193+
for name in defaults_to_override - binding.arguments.keys():
194+
binding.arguments[name] = defaults[name]
195+
196+
return func(*binding.args, **binding.kwargs)
197+
inner.signature = new_signature
198+
199+
return inner
200+
return outer

tests/test_decorations.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from functools import wraps
44

5-
from easypy.decorations import lazy_decorator
5+
from easypy.decorations import lazy_decorator, mirror_defaults
66
from easypy.misc import kwargs_resilient
7+
from easypy.tokens import AUTO
78

89

910
def test_kwargs_resilient():
@@ -136,3 +137,26 @@ def counter(self):
136137
foo2.ts += 1
137138
assert [foo1.inc(), foo2.inc()] == [2, 2]
138139
assert [foo1.counter, foo2.counter] == [1, 2] # foo1 was not updated since last sync - only foo2
140+
141+
142+
def test_mirror_defaults():
143+
def foo(a, b, c=1, d=2, *args, e=3, f=4, **kwargs):
144+
return locals()
145+
146+
@mirror_defaults(foo)
147+
def bar(a, b=100, c=AUTO, d=20, *args, e=AUTO, f=40, **kwargs):
148+
return foo(a, b, c, d, *args, e=e, f=f, **kwargs)
149+
150+
assert bar(300) == dict(
151+
a=300, b=100,
152+
c=1, d=20,
153+
args=(),
154+
e=3, f=40,
155+
kwargs={})
156+
157+
assert bar(300, 400, 500, 600, 700, e=800, f=900, g=1000) == dict(
158+
a=300, b=400,
159+
c=500, d=600,
160+
args=(700,),
161+
e=800, f=900,
162+
kwargs=dict(g=1000))

0 commit comments

Comments
 (0)