From 10f8b20d1954f099cb153fa8eed3ed0a5c55d5d5 Mon Sep 17 00:00:00 2001 From: Ryan Rowe Date: Tue, 9 Jun 2020 14:54:11 -0700 Subject: [PATCH] Initial commit --- .github/workflows/publish.yaml | 27 ++ .github/workflows/push.yml | 27 ++ .gitignore | 133 ++++++++ LICENSE | 13 + Pipfile | 10 + README.md | 36 +++ dynamic_dispatch/__init__.py | 97 ++++++ dynamic_dispatch/_class.py | 78 +++++ dynamic_dispatch/_func.py | 153 ++++++++++ dynamic_dispatch/_typeguard.py | 13 + setup.cfg | 39 +++ setup.py | 5 + tests/__init__.py | 0 tests/test_class.py | 543 +++++++++++++++++++++++++++++++++ tests/test_func.py | 356 +++++++++++++++++++++ 15 files changed, 1530 insertions(+) create mode 100644 .github/workflows/publish.yaml create mode 100644 .github/workflows/push.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Pipfile create mode 100644 README.md create mode 100644 dynamic_dispatch/__init__.py create mode 100644 dynamic_dispatch/_class.py create mode 100644 dynamic_dispatch/_func.py create mode 100644 dynamic_dispatch/_typeguard.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/test_class.py create mode 100644 tests/test_func.py diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..5706035 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,27 @@ +name: Publish Python Package + +on: + release: + types: [published] + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools pipenv + pipenv install --dev + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + pipenv run python setup.py sdist + pipenv run twine upload dist/* diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml new file mode 100644 index 0000000..5a7c0bb --- /dev/null +++ b/.github/workflows/push.yml @@ -0,0 +1,27 @@ +name: Push CI + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install -U pip pipenv + pipenv install --dev + - name: Lint with flake8 + uses: grantmcconnaughey/lintly-flake8-github-action@v1.0 + if: github.event_name == 'pull_request' + - name: Unit tests + run: | + pipenv run python -m unittest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f436c8a --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IntelliJ +.DS_Store +.idea diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..07e8059 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +Copyright 2020 Xevo Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..2467f23 --- /dev/null +++ b/Pipfile @@ -0,0 +1,10 @@ +[[source]] +name = 'pypi' +url = 'https://pypi.org/simple' +verify_ssl = true + +[packages] +dynamic_dispatch = {editable = true, path = '.'} + +[dev-packages] +dynamic_dispatch = {editable = true, path = '.', extras = ['dev']} diff --git a/README.md b/README.md new file mode 100644 index 0000000..6c83868 --- /dev/null +++ b/README.md @@ -0,0 +1,36 @@ +# Dynamic Dispatch + +![Build status](https://img.shields.io/github/workflow/status/XevoInc/dynamic_dispatch/Push%20CI/master) +[![PyPI](https://img.shields.io/pypi/v/dynamic-dispatch)](https://pypi.org/project/dynamic-dispatch/) +![PyPI - License](https://img.shields.io/pypi/l/dynamic-dispatch) + +A lightweight, dynamic dispatch implementation for classes and functions. This allows a class or function to delegate +its implementation conditioned on the value of its first argument. This is similar to `functools.singledispatch`, +however this library dispatches over value while the other dispatches over type. + +## Install + +You may install this via the [`dynamic-dispatch`](https://pypi.org/project/dynamic-dispatch/) package on [PyPi](https://pypi.org): + +```bash +pip3 install dynamic-dispatch +``` + +## Usage + + +## Development + +When developing, it is recommended to use Pipenv. To create your development environment: + +```bash +pipenv install --dev +``` + +### Testing + +This library uses the `unittest` framework. Tests may be run with the following: + +```bash +python3 -m unittest +``` \ No newline at end of file diff --git a/dynamic_dispatch/__init__.py b/dynamic_dispatch/__init__.py new file mode 100644 index 0000000..00052cf --- /dev/null +++ b/dynamic_dispatch/__init__.py @@ -0,0 +1,97 @@ +""" Like functools.singledispatch, but dynamic, value-based dispatch. """ + +__all__ = ('dynamic_dispatch',) + +import functools +import inspect + +from typing import Union, Callable, Type, Hashable + +from dynamic_dispatch._class import class_dispatch +from dynamic_dispatch._func import func_dispatch + +from ._typeguard import typechecked + + +@typechecked(always=True) +def dynamic_dispatch(func: Union[Callable, Type, None] = None, *, default: bool = False): + """ + Value-based dynamic-dispatch class decorator. + + Allows a class or function to have different implementations depending on the + value of func's first parameter. The decorated class or function can act as + the default implementation, if desired. + + Additional implementations may be registered for dispatch using the register() + attribute of the dispatch class or function. If the implementation has a param + of the same name as the first of func, it will be passed along. + + :Example: + + >>> @dynamic_dispatch(default=True) + >>> def foo(bar: int): + >>> print(bar) + >>> + >>> @foo.dispatch(on=5) + >>> def _(bar: int, baz: int): + >>> print(bar * baz) + >>> + >>> @foo.dispatch(on=10) + >>> def _(): + >>> print(-10) + >>> + >>> foo(1) + 1 + >>> foo(5, 10) + 50 + >>> foo(10) + -10 + + :Example: + + >>> @dynamic_dispatch(default=True) + >>> class Foo: + >>> def __init__(self, foo: int): + >>> super().__init__() + >>> print(bar) + >>> + >>> @Foo.dispatch(foo=5) + >>> class Bar(Foo): + >>> def __init__(self, foo, bar): + >>> super().__init__(foo) + >>> print(foo * bar) + >>> + >>> Foo(1) + 1 + <__main__.Foo object at ...> + >>> Foo(5, 10) + 50 + <__main__.Bar object at ...> + + :param func: class or function to add dynamic dispatch to. + :param default: whether or not to use func as the default implementation. + :returns: func with dynamic dispatch + """ + # Default was specified, wait until func is here too. + if func is None: + return functools.partial(dynamic_dispatch, default=default) + + # Delegate depending on wrap type. + if inspect.isclass(func): + return class_dispatch(func, default) + + func = func_dispatch(func, default=default) + + # Alter register to hide implicit parameter. + dispatch = func.dispatch + + def replacement(impl: Callable = None, *, on: Hashable): + if impl is None: + return functools.partial(replacement, on=on) + + return dispatch(impl, arguments=inspect.signature(impl).parameters, on=on) + + # Type checker complains if we assign directly. + setattr(func, 'dispatch', replacement) + + return func diff --git a/dynamic_dispatch/_class.py b/dynamic_dispatch/_class.py new file mode 100644 index 0000000..697406c --- /dev/null +++ b/dynamic_dispatch/_class.py @@ -0,0 +1,78 @@ +""" Like functools.singledispatch, but dynamic, value-based dispatch of classes. """ + +import functools +import inspect +from typing import Hashable, Type, TypeVar, Callable, Union + +from ._typeguard import typechecked + +from ._func import func_dispatch + +T_co = TypeVar('T_co', covariant=True) + + +@typechecked +def class_dispatch(typ: Type[T_co], default: Hashable): + """ + Value-based dynamic-dispatch class decorator. + + Transforms a class into a dynamic dispatch class, which has different + behaviors depending upon the value of its first positional parameter. + The decorated class acts as the default implementation, if default is + specified, and additional classes may be registered using the dispatch() + static function of the dispatch class. + + :param typ: class to add dynamic dispatch to. + :param default: whether or not to default when given an unregistered value. + :returns: dispatch class. + """ + if inspect.isabstract(typ) and default: + raise TypeError('abstract classes cannot be used as a default implementation') + + # Dispatcher must also be class in case anyone wants to use isinstance with it, etc. + @functools.wraps(typ, updated=()) + class Dispatcher(typ): + # Dynamic dispatch on a class is equivalent to dynamic dispatch on __new__. + # Note: the parameters for dispatch here are those of __init__ instead. + @func_dispatch(default=default, clazz=typ) + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + @classmethod + @typechecked(always=True) + def dispatch(cls, wrap: Union[Type[T_co], Callable[..., T_co]] = None, *, on: Hashable): + if wrap is None: + return functools.partial(cls.dispatch, on=on) + + if not inspect.isclass(wrap): + ret = inspect.signature(wrap).return_annotation + if ret == inspect.Parameter.empty: + raise TypeError(f'function {wrap.__name__} must have annotated return type') from None + + if ret is not None and issubclass(ret, typ): + # It's a function that returns a subtype of the dispatch class, let's allow this. + cls.__new__.dispatch(wrap, arguments=inspect.signature(wrap).parameters, on=on) + return wrap + else: + raise TypeError(f'{wrap.__name__} may not be registered for dispatch on {typ.__name__}' + f'as its return type {ret!r} does not subclass the dispatch type.') + elif not issubclass(wrap, typ): + raise TypeError(f'only subclasses of {typ.__name__} can be registered for dynamic dispatch') + else: + @functools.wraps(wrap, updated=()) + class Registered(wrap): + __dispatch_init = True + + def __init__(self, *args, **kwargs): + # Certain scenarios can cause __init__ to be called twice. This prevents it. + if not self.__dispatch_init: + return + + self.__dispatch_init = False + super().__init__(*args, **kwargs) + + cls.__new__.dispatch(Registered, arguments=inspect.signature(wrap.__init__).parameters, on=on) + + return Registered + + return Dispatcher diff --git a/dynamic_dispatch/_func.py b/dynamic_dispatch/_func.py new file mode 100644 index 0000000..6b554cc --- /dev/null +++ b/dynamic_dispatch/_func.py @@ -0,0 +1,153 @@ +""" Like functools.singledispatch, but dynamic, value-based function dispatch. """ + +import functools +import inspect +from types import MappingProxyType +from typing import Callable, Hashable, Type + +from ._typeguard import typechecked + + +def _lookup(key: str, typ: Type, *args, **kwargs) -> Hashable: + """ + Gets the value of the given key in args, defaulting to the first positional. + + :param key: key to find value of in args. + :param typ: type that dispatch is being perform on. + :param args: positional args. + :param kwargs: keyword args. + :return: value of the key in the given args. + """ + if key in kwargs: + value = kwargs[key] + else: + try: + if typ.__qualname__.endswith('.__new__'): + value = args[1] + else: + value = args[0] + except IndexError: + raise TypeError(f'missing dispatch parameter {key!r} on {typ.__name__}') + + return value + + +@typechecked +def func_dispatch(func: Callable = None, *, default: bool, clazz=None): + """ + Value-based dynamic-dispatch function decorator. + + Transforms a function into a dynamic dispatch function, which has different + behaviors depending upon the value of its first positional parameter. The + decorated function acts as the default implementation, if default is specified, + and additional functions may be registered using the dispatch() attribute of + the dispatch function. + + :param func: function to add dynamic dispatch to. + :param default: whether or not to default when given an unregistered value. + :param clazz: class that func is __new__ for, or None. + :returns: dispatch function. + """ + if func is None: + return functools.partial(func_dispatch, default=default, clazz=clazz) + + if inspect.ismethod(func): + raise NotImplementedError('member functions are not supported') + + if clazz is None: + name = func.__name__ + parameters = inspect.signature(func).parameters + else: + name = clazz.__name__ + parameters = inspect.signature(clazz.__init__).parameters + + registry = {} + + # Find the first explicit (non-splat) positional argument. This is the dispatch parameter. + parameters = iter(parameters.values()) + param = None + while param is None or param.name == 'self' or param.name == 'return' or \ + param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD: + try: + param = next(parameters) + except StopIteration: + raise TypeError('dispatch function does not have any explicit positional arguments') from None + key = param.name + + @functools.wraps(func) + def dispatch(*args, **kwargs): + # If dispatching a class, the first argument indicates the type of class desired. + # If that class is not the dispatch class, someone is instantiating a derived class directly. + # In this special case, we bypass dispatch. + if clazz is not None: + if len(args) > 0 and inspect.isclass(args[0]): + klass = args[0] + + if klass is not clazz and not (klass.__qualname__ == clazz.__qualname__ and clazz in klass.__bases__): + if issubclass(klass, clazz): + return func(*args, **kwargs) + else: + raise TypeError(f'cls argument for __new__ must be subclass of {clazz!r}, got {klass!r}') + + # Find dispatch param by position or key. + value = _lookup(key, func, *args, **kwargs) + + if default: + # Allow default to dispatch func, which we know has the dispatch param at index 0. + impl, idx = registry.get(value, (func, 0)) + else: + try: + impl, idx = registry[value] + except KeyError: + raise ValueError(f'no registered implementations for {value!r} for {name}') from None + + if inspect.isclass(impl): + args = args[1:] + + if idx is not None: + idx -= 1 + + if idx is None: + # Dispatch param is not desired, remove it. + if key in kwargs: + del kwargs[key] + else: + # Not in kwargs, must be the first parameter. + args = args[1:] + elif idx > 0 and key not in kwargs: + # Dispatch param is desired and it's not the first argument, so rearrange. + args = args[1:idx + 1] + args[0:1] + args[idx + 1:] + + return impl(*args, **kwargs) + + @typechecked(always=True) + def register(impl: Callable = None, *, arguments: MappingProxyType, on: Hashable): + """ + Registers a new implementation for the given value of key. + + :param on: dispatch value to register this implementation on. + :param arguments: parameters to impl. + :param impl: implementation to associate with value. + """ + if impl is None: + return functools.partial(register, arguments=arguments, on=on) + + if on in registry: + raise ValueError(f'duplicate implementation for {on!r} for {name}') + + # Determine index of dispatch parameter in this signature. + idx = None + for i, parameter in enumerate(arguments.values()): + if parameter.name == key: + if parameter.kind == inspect.Parameter.KEYWORD_ONLY: + # Parameter is keyword-only, so it has no 'index'. + idx = -1 + else: + idx = i + + registry[on] = impl, idx + + return impl + + dispatch.dispatch = register + return dispatch diff --git a/dynamic_dispatch/_typeguard.py b/dynamic_dispatch/_typeguard.py new file mode 100644 index 0000000..68dc9ae --- /dev/null +++ b/dynamic_dispatch/_typeguard.py @@ -0,0 +1,13 @@ +""" Proxy for typeguard, in case it's not installed. """ + +__all__ = ('typechecked',) + +import functools + +try: + from typeguard import typechecked +except ModuleNotFoundError: + def typechecked(func=None, **kwargs): + if func is None: + return functools.partial(typechecked, **kwargs) + return func diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ec363c3 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,39 @@ +[metadata] +name=dynamic_dispatch +description=dynamic dispatch decorator for classes and functions +version=1.0.0 +url=https://github.com/XevoInc/dynamic_dispatch +long_description=file: README.md, +long_description_content_type=text/markdown +author=Ryan Rowe +author_email=rrowe@xevo.com +license=Apache License, Version 2.0 +license_file=LICENSE +python_requires=>=3.7 +classifiers= + Development Status :: 4 - Beta + Intended Audience :: Developers + License :: OSI Approved :: Apache Software License + Programming Language :: Python + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Topic :: Utilities + Typing :: Typed + +[options] +setup_requires = + # Minimal version with most `setup.cfg` bug fixes. + setuptools >= 38.3.0 +packages = dynamic_dispatch +test_suite = tests + +[options.extras_require] +typeguard = + typeguard >= 2.9.1 +dev = + typeguard >= 2.9.1 + flake8 + twine + +[flake8] +max-line-length = 120 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7f317b8 --- /dev/null +++ b/setup.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 + +from setuptools import setup + +setup() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_class.py b/tests/test_class.py new file mode 100644 index 0000000..129d2ff --- /dev/null +++ b/tests/test_class.py @@ -0,0 +1,543 @@ +from abc import ABC, abstractmethod +from typing import Callable +from unittest import TestCase + +from dynamic_dispatch import dynamic_dispatch + + +class EmptyInit: + def __init__(self): + pass + + +class OneArgInit: + """ Comment """ + + def __init__(self, abc): + self.abc = abc + self.abc_count = getattr(self, 'abc_count', 0) + 1 + + +class TestClassDispatch(TestCase): + def test_returns_class(self): + self.assertIsInstance(dynamic_dispatch(OneArgInit), type) + + def test_wraps(self): + wrapped = dynamic_dispatch(OneArgInit) + + self.assertEqual(OneArgInit.__doc__, wrapped.__doc__) + self.assertEqual(OneArgInit.__name__, wrapped.__name__) + self.assertEqual(OneArgInit.__module__, wrapped.__module__) + + def test_reject_splat_args(self): + class Foo: + def __init__(self, *args): + pass + + with self.assertRaises(TypeError): + dynamic_dispatch(Foo) + + def test_reject_kwargs(self): + class Foo: + def __init__(self, **kwargs): + pass + + with self.assertRaises(TypeError): + dynamic_dispatch(Foo) + + def test_requires_args(self): + with self.assertRaises(TypeError): + dynamic_dispatch(EmptyInit) + with self.assertRaises(TypeError): + dynamic_dispatch(EmptyInit, default=True) + + def test_abstract_class(self): + class Foo(OneArgInit, ABC): + @abstractmethod + def bar(self): + pass + + with self.assertRaises(TypeError): + dynamic_dispatch(Foo, default=True) + + # Abstract classes are allowed if not default. + dynamic_dispatch(Foo) + + def test_decorator(self): + @dynamic_dispatch + class Foo(OneArgInit): + pass + + self.assertIsInstance(Foo, type) + + def test_decorator_with_default(self): + @dynamic_dispatch(default=True) + class Foo(OneArgInit): + pass + + self.assertIsInstance(Foo, type) + + def test_has_register(self): + wrapped = dynamic_dispatch(OneArgInit) + self.assertTrue(hasattr(wrapped, 'dispatch'), 'wrapped class has no dispatch attribute') + self.assertIsInstance(wrapped.dispatch, Callable) + + wrapped = dynamic_dispatch(OneArgInit, default=True) + self.assertTrue(hasattr(wrapped, 'dispatch'), 'wrapped class has no dispatch attribute') + self.assertIsInstance(wrapped.dispatch, Callable) + + def test_register_wraps(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo(OneArgInit): + """ Another comment """ + pass + + reg = wrapped.dispatch(Foo, on=1) + + self.assertEqual(Foo.__doc__, reg.__doc__) + self.assertEqual(Foo.__name__, reg.__name__) + self.assertEqual(Foo.__module__, reg.__module__) + + def test_register_reject_splat_args(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo: + def __init__(self, *args): + pass + + with self.assertRaises(TypeError): + wrapped.dispatch(Foo, on=1) + + def test_register_reject_kwargs(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo: + def __init__(self, **kwargs): + pass + + with self.assertRaises(TypeError): + wrapped.dispatch(Foo, on=1) + + def test_register_non_subclass(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo: + def __init__(self, bar): + pass + + with self.assertRaises(TypeError): + wrapped.dispatch(Foo, on=1) + + def test_register_fn(self): + wrapped = dynamic_dispatch(OneArgInit) + + def foo() -> OneArgInit: + pass + + wrapped.dispatch(foo, on=1) + + def test_register_fn_subclass(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo(OneArgInit): + pass + + def foo() -> Foo: + pass + + wrapped.dispatch(foo, on=1) + + def test_register_fn_non_class(self): + wrapped = dynamic_dispatch(OneArgInit) + + def foo() -> int: + pass + + with self.assertRaises(TypeError): + wrapped.dispatch(foo, on=1) + + def test_register_fn_no_annotation(self): + wrapped = dynamic_dispatch(OneArgInit) + + with self.assertRaises(TypeError): + wrapped.dispatch(lambda: None, on=1) + + def test_register_not_hashable(self): + wrapped = dynamic_dispatch(OneArgInit) + + with self.assertRaises(TypeError): + wrapped.dispatch(type('foo', (OneArgInit,), {}), on=[]) + + def test_register_no_value(self): + wrapped = dynamic_dispatch(OneArgInit) + + with self.assertRaises(TypeError): + wrapped.dispatch(type('foo', (OneArgInit,), {})) + + def test_register_duplicate_value(self): + wrapped = dynamic_dispatch(OneArgInit) + + wrapped.dispatch(type('foo', (OneArgInit,), {}), on=1) + + with self.assertRaises(ValueError): + wrapped.dispatch(type('bar', (OneArgInit,), {}), on=1) + + def test_register_on_fn(self): + wrapped = dynamic_dispatch(lambda abc: abc) + wrapped.dispatch(type('foo', (OneArgInit,), {}), on=1) + + def test_register_multiple(self): + wrapped = dynamic_dispatch(OneArgInit) + wrapped.dispatch(type('foo', (OneArgInit,), {}), on=1) + wrapped.dispatch(type('bar', (OneArgInit,), {}), on=2) + + def test_register_multi_key(self): + wrapped = dynamic_dispatch(OneArgInit) + impl = type('foo', (OneArgInit,), {}) + + wrapped.dispatch(impl, on=1) + wrapped.dispatch(impl, on=2) + + def test_dispatch_default(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + args = (1,) + obj = wrapped(*args) + + self.assertIsInstance(obj, OneArgInit) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, 1) + + def test_decorator_dispatch_default(self): + @dynamic_dispatch(default=True) + class Foo: + def __init__(self, a): + self.a = a + self.a_count = getattr(self, 'a_count', 0) + 1 + + obj = Foo(1) + + self.assertIsInstance(obj, Foo) + self.assertEqual(obj.a, 1) + self.assertEqual(obj.a_count, 1) + + def test_dispatch_default_exc(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + with self.assertRaises(TypeError): + wrapped(1, 2) + + def test_dispatch_no_args(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc = 1 + + class Foo(OneArgInit): + def __init__(self): + super().__init__(abc) + + wrapped.dispatch(Foo, on=abc) + obj = wrapped(abc) + + self.assertIsInstance(obj, OneArgInit) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, abc) + + def test_decorator_dispatch(self): + @dynamic_dispatch(default=True) + class Foo: + def __init__(self, a): + self.a = a + self.a_count = getattr(self, 'a_count', 0) + 1 + + @Foo.dispatch(on=1) + class Bar(Foo): + def __init__(self, b): + super().__init__(1) + self.b = b + self.b_count = getattr(self, 'b_count', 0) + 1 + + args = (1, 2) + obj = Foo(*args) + + self.assertIsInstance(obj, Foo) + self.assertEqual(obj.a, args[0]) + self.assertEqual(obj.a_count, 1) + self.assertEqual(obj.b, args[1]) + self.assertEqual(obj.b_count, 1) + + def test_dispatch_no_args_type_error(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc = 1 + + class Foo(OneArgInit): + def __init__(self): + super().__init__(abc) + + wrapped.dispatch(Foo, on=abc) + with self.assertRaises(TypeError): + wrapped(abc, 2) + + # noinspection DuplicatedCode + def test_dispatch_one_arg(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc, bar_ = 1, 2 + + class Foo(OneArgInit): + def __init__(self, bar): + super().__init__(abc) + self.bar = bar + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Foo, on=abc) + obj = wrapped(abc, bar_) + + self.assertIsInstance(obj, Foo) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, abc) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, bar_) + + # noinspection DuplicatedCode + def test_dispatch_override_key(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc_, bar_ = 3, 4 + + class Baz(OneArgInit): + def __init__(self, abc, bar): + super().__init__(abc) + self.bar = bar + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=abc_) + obj = wrapped(abc_, bar_) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, abc_) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, bar_) + + # noinspection DuplicatedCode + def test_dispatch_multi_arg(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + args = (1, 2, 3, 4) + + class Baz(OneArgInit): + def __init__(self, d, e, f): + super().__init__(args[0]) + self.bar = d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=args[0]) + obj = wrapped(*args) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, args[0]) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, args[1:]) + + # noinspection DuplicatedCode + def test_dispatch_multi_arg_override_key(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + args = (1, 2, 3, 4) + + class Baz(OneArgInit): + def __init__(self, abc, d, e, f): + super().__init__(abc) + self.bar = abc, d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=args[0]) + obj = wrapped(*args) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, args[0]) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, args) + + # noinspection DuplicatedCode + def test_dispatch_multi_arg_override_key_reorder(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + args = (1, 2, 3, 4) + + class Baz(OneArgInit): + def __init__(self, d, e, abc, f): + super().__init__(abc) + self.bar = d, e, abc, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=args[0]) + obj = wrapped(*args) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, args[0]) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, (*args[1:3], args[0], *args[3:])) + + def test_dispatch_arg_type_error(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + args = (1, 2, 3) + + class Baz(OneArgInit): + def __init__(self, d, e, f): + super().__init__(args[0]) + self.bar = d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=args[0]) + + with self.assertRaises(TypeError): + wrapped(*args) + + # noinspection DuplicatedCode + def test_dispatch_kwarg(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc = 1 + kwargs = dict(f=4, e=4, d=2) + + class Baz(OneArgInit): + def __init__(self, *, d, e, f): + super().__init__(abc) + self.bar = d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=abc) + obj = wrapped(abc, **kwargs) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, abc) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, tuple(kwargs[key] for key in sorted(kwargs))) + + # noinspection DuplicatedCode + def test_dispatch_kwarg_override_key(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc = 1 + kwargs = dict(f=4, e=4, d=2) + + class Baz(OneArgInit): + def __init__(self, abc, *, d, e, f): + super().__init__(abc) + self.bar = abc, d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=abc) + obj = wrapped(abc, **kwargs) + + self.assertIsInstance(obj, Baz) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.abc, abc) + self.assertEqual(obj.bar_count, 1) + self.assertEqual(obj.bar, (abc, *tuple(kwargs[key] for key in sorted(kwargs)))) + + def test_dispatch_kwarg_type_error(self): + wrapped = dynamic_dispatch(OneArgInit, default=True) + + abc = 1 + kwargs = dict(f=4, d=2) + + class Baz(OneArgInit): + def __init__(self, *, d, e, f): + super().__init__(abc) + self.bar = d, e, f + self.bar_count = getattr(self, 'bar_count', 0) + 1 + + wrapped.dispatch(Baz, on=abc) + + with self.assertRaises(TypeError): + wrapped(abc, **kwargs) + + def test_dispatch_multi(self): + wrapped = dynamic_dispatch(OneArgInit) + + class Foo(OneArgInit): + def __init__(self, baz): + super().__init__(1) + self.baz = baz + self.baz_count = getattr(self, 'baz_count', 0) + 1 + + class Bar(OneArgInit): + def __init__(self, qux): + super().__init__(2) + self.qux = qux + self.qux_count = getattr(self, 'qux_count', 0) + 1 + + wrapped.dispatch(Foo, on=1) + wrapped.dispatch(Bar, on=2) + + foo = wrapped(1, 5) + + def check_foo(): + self.assertIsInstance(foo, Foo) + self.assertEqual(foo.abc, 1) + self.assertEqual(foo.abc_count, 1) + self.assertEqual(foo.baz, 5) + self.assertEqual(foo.baz_count, 1) + self.assertFalse(hasattr(foo, 'qux')) + check_foo() + + bar = wrapped(2, 10) + self.assertIsInstance(bar, Bar) + self.assertEqual(bar.abc, 2) + self.assertEqual(bar.abc_count, 1) + self.assertEqual(bar.qux, 10) + self.assertEqual(bar.qux_count, 1) + self.assertFalse(hasattr(bar, 'baz')) + + check_foo() + + def test_can_instantiate_registered_impl(self): + @dynamic_dispatch + class Foo(OneArgInit): + pass + + @Foo.dispatch(on=1) + class Bar(Foo): + def __init__(self, a, *, b): + super().__init__(1) + self.a = a + self.b = b + self.a_count = getattr(self, 'a_count', 0) + 1 + self.b_count = getattr(self, 'b_count', 0) + 1 + + obj = Bar(2, b=3) + self.assertIsInstance(obj, Bar) + self.assertEqual(obj.abc, 1) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.a, 2) + self.assertEqual(obj.a_count, 1) + self.assertEqual(obj.b, 3) + self.assertEqual(obj.b_count, 1) + + def test_can_instantiate_unregistered_subclass(self): + @dynamic_dispatch + class Foo(OneArgInit): + pass + + class Bar(Foo): + def __init__(self, a): + super().__init__(1) + self.a = a + self.a_count = getattr(self, 'a_count', 0) + 1 + + obj = Bar(5) + self.assertIsInstance(obj, Bar) + self.assertEqual(obj.abc, 1) + self.assertEqual(obj.abc_count, 1) + self.assertEqual(obj.a, 5) + self.assertEqual(obj.a_count, 1) diff --git a/tests/test_func.py b/tests/test_func.py new file mode 100644 index 0000000..de78993 --- /dev/null +++ b/tests/test_func.py @@ -0,0 +1,356 @@ +from typing import Callable +from unittest import TestCase +from unittest.mock import create_autospec + +from dynamic_dispatch import dynamic_dispatch + + +class TestFuncDispatch(TestCase): + def test_returns_func(self): + self.assertIsInstance(dynamic_dispatch(lambda _: _), Callable) + self.assertIsInstance(dynamic_dispatch(lambda _: _, default=True), Callable) + + def test_wraps(self): + def foo(_): + """ Comment """ + pass + + wrapped = dynamic_dispatch(foo) + + self.assertEqual(foo.__doc__, wrapped.__doc__) + self.assertEqual(foo.__name__, wrapped.__name__) + self.assertEqual(foo.__module__, wrapped.__module__) + + def test_requires_args(self): + with self.assertRaises(TypeError): + dynamic_dispatch(lambda: None) + with self.assertRaises(TypeError): + dynamic_dispatch(lambda: None, default=True) + + def test_decorator(self): + @dynamic_dispatch + def _(_): + pass + + self.assertIsInstance(_, Callable) + + def test_decorator_with_default(self): + @dynamic_dispatch(default=True) + def _(_): + pass + + self.assertIsInstance(_, Callable) + + def test_has_register(self): + wrapped = dynamic_dispatch(lambda _: _) + self.assertTrue(hasattr(wrapped, 'dispatch'), 'wrapped function has no dispatch attribute') + self.assertIsInstance(wrapped.dispatch, Callable) + + wrapped = dynamic_dispatch(lambda _: _, default=True) + self.assertTrue(hasattr(wrapped, 'dispatch'), 'wrapped function has no dispatch attribute') + self.assertIsInstance(wrapped.dispatch, Callable) + + def test_register(self): + wrapped = dynamic_dispatch(lambda _: _) + wrapped.dispatch(lambda _: _, on=1) + + def test_register_wraps(self): + wrapped = dynamic_dispatch(lambda _: _) + + def foo(): + """ Doc comment """ + pass + + reg = wrapped.dispatch(foo, on=1) + + self.assertEqual(foo.__doc__, reg.__doc__) + self.assertEqual(foo.__name__, reg.__name__) + self.assertEqual(foo.__module__, reg.__module__) + + def test_register_not_hashable(self): + wrapped = dynamic_dispatch(lambda _: _) + with self.assertRaises(TypeError): + wrapped.dispatch(lambda _: _, on=[]) + + def test_register_no_value(self): + wrapped = dynamic_dispatch(lambda _: _) + + with self.assertRaises(TypeError): + wrapped.dispatch(lambda _: _) + + def test_register_duplicate_value(self): + wrapped = dynamic_dispatch(lambda _: _) + wrapped.dispatch(lambda _: _, on=1) + + with self.assertRaises(ValueError): + wrapped.dispatch(lambda _: _, on=1) + + def test_register_multiple(self): + wrapped = dynamic_dispatch(lambda _: _) + wrapped.dispatch(lambda _: _, on=1) + wrapped.dispatch(lambda _: _, on=2) + + def test_register_multi_key(self): + wrapped = dynamic_dispatch(lambda _: _) + + def impl(_): + return _ + + wrapped.dispatch(impl, on=1) + wrapped.dispatch(impl, on=2) + + def test_dispatch_default(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default, default=True) + + args = (1,) + wrapped(*args) + + default.assert_called_once_with(*args) + + def test_dispatch_decorator_default(self): + default = create_autospec(lambda _: _) + + @dynamic_dispatch(default=True) + def foo(_): + default(_) + + args = (1,) + foo(*args) + + default.assert_called_once_with(*args) + + def test_dispatch_default_exc(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + with self.assertRaises(ValueError): + wrapped(1) + + default.assert_not_called() + + def test_dispatch_no_args(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda: None) + wrapped.dispatch(impl, on=1) + + wrapped(1) + + default.assert_not_called() + impl.assert_called_once_with() + + def test_dispatch_no_args_type_error(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda: None) + wrapped.dispatch(default, on=1) + + with self.assertRaises(TypeError): + wrapped(1, 2) + + default.assert_not_called() + impl.assert_not_called() + + def test_dispatch_one_arg(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda a: a) + wrapped.dispatch(impl, on=1) + + args = (1, 2) + wrapped(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args[1:]) + + def test_dispatch_decorator(self): + default = create_autospec(lambda _: _) + + @dynamic_dispatch + def foo(_): + default(_) + + impl = create_autospec(lambda a: a) + + @foo.dispatch(on=1) + def bar(a): + impl(a) + + args = (1, 2) + foo(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args[1:]) + + def test_dispatch_override_key(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda _: _) + wrapped.dispatch(impl, on=1) + + args = (1,) + wrapped(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args) + + def test_dispatch_multi_arg(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda a, b, c: None) + wrapped.dispatch(impl, on=1) + + args = (1, 2, 3, 4) + wrapped(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args[1:]) + + def test_dispatch_multi_arg_override_key(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda _, a, b, c: None) + wrapped.dispatch(impl, on=1) + + args = (1, 2, 3, 4) + wrapped(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args) + + # noinspection DuplicatedCode + def test_dispatch_multi_arg_override_key_reorder(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda a, b, _, c: None) + wrapped.dispatch(impl, on=1) + + args = (1, 2, 3, 4) + wrapped(*args) + + default.assert_not_called() + impl.assert_called_once_with(*args[1:3], args[0], *args[3:]) + + def test_dispatch_arg_type_error(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda a, b, c: None) + wrapped.dispatch(impl, on=1) + + args = (1, 2, 3) + with self.assertRaises(TypeError): + wrapped(*args) + + default.assert_not_called() + impl.assert_not_called() + + def test_dispatch_kwarg(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + def impl(*, a, b, c): + return a, b, c + + impl = create_autospec(impl) + wrapped.dispatch(impl, on=1) + + args = (1,) + kwargs = dict(c=4, b=3, a=2) + wrapped(*args, **kwargs) + + default.assert_not_called() + impl.assert_called_once_with(**kwargs) + + def test_dispatch_kwargs_type_error(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + def impl(*, a, b, c): + return a, b, c + + impl = create_autospec(impl) + wrapped.dispatch(impl, on=1) + + args = (1,) + kwargs = dict(a=2, b=3) + with self.assertRaises(TypeError): + wrapped(*args, **kwargs) + + default.assert_not_called() + impl.assert_not_called() + + def test_dispatch_kwarg_override_key(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + def impl(_, *, a, b, c): + return _, a, b, c + + impl = create_autospec(impl) + wrapped.dispatch(impl, on=1) + + args = (1,) + kwargs = dict(c=4, b=3, a=2) + wrapped(*args, **kwargs) + + default.assert_not_called() + impl.assert_called_once_with(*args, **kwargs) + + def test_dispatch_multi(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + a = create_autospec(lambda z: z) + b = create_autospec(lambda z: z) + + wrapped.dispatch(a, on=1) + wrapped.dispatch(b, on=2) + + a.assert_not_called() + b.assert_not_called() + + wrapped(1, -1) + a.assert_called_once_with(-1) + b.assert_not_called() + + wrapped(2, -2) + default.assert_not_called() + a.assert_called_once() + b.assert_called_once_with(-2) + + def test_dispatch_multi_key(self): + default = create_autospec(lambda _: _) + wrapped = dynamic_dispatch(default) + + impl = create_autospec(lambda _, a: _) + + wrapped.dispatch(impl, on=1) + wrapped.dispatch(impl, on=2) + + args = (1, 2) + wrapped(*args) + + impl.assert_called_once_with(*args) + + args = (2, 3) + wrapped(*args) + + default.assert_not_called() + self.assertEqual(impl.call_count, 2) + impl.assert_called_with(*args) + + def test_fails_on_members(self): + with self.assertRaises(TypeError): + class Foo: + @dynamic_dispatch + def member(self): + pass