Skip to content

Commit

Permalink
Add option to get the DAG. (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianZimpelmann authored Oct 6, 2022
1 parent 8794c21 commit ab7f8a2
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ target/
profile_default/
ipython_config.py

# VS Code
.vscode

# pyenv
.python-version

Expand Down
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: debug-statements
- id: end-of-file-fixer
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.1.0
rev: v3.8.3
hooks:
- id: reorder-python-imports
types: [python]
Expand Down Expand Up @@ -45,12 +45,12 @@ repos:
additional_dependencies: [black==22.3.0]
types: [rst]
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.8.0
hooks:
- id: black
types: [python]
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 5.0.4
hooks:
- id: flake8
types: [python]
Expand All @@ -71,7 +71,7 @@ repos:
Pygments,
]
- repo: https://github.com/PyCQA/doc8
rev: 0.11.2
rev: v1.0.0
hooks:
- id: doc8
- repo: meta
Expand All @@ -86,11 +86,11 @@ repos:
args: [--no-build-isolation]
additional_dependencies: [setuptools-scm, toml]
- repo: https://github.com/PyCQA/doc8
rev: 0.11.2
rev: v1.0.0
hooks:
- id: doc8
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v1.20.1
rev: v2.0.0
hooks:
- id: setup-cfg-fmt
- repo: https://github.com/econchick/interrogate
Expand All @@ -100,11 +100,11 @@ repos:
args: [-v, --fail-under=20]
exclude: ^(tests|docs|setup\.py)
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/asottile/pyupgrade
rev: v2.34.0
rev: v2.38.2
hooks:
- id: pyupgrade
args: [--py37-plus]
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ releases are available on `Anaconda.org
- :gh:`7` improves the examples in the test cases.
- :gh:`10` turns ``targets`` into an optional argument. All variables in the DAG are
returned by default.

- :gh:`9` Add function to return the DAG. Check for cycles in DAG.
(:ghuser:`ChristianZimpelmann`)

0.2.1 - 2022-03-29
------------------
Expand Down
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ classifiers =
Operating System :: POSIX
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Utilities

[options]
Expand Down
156 changes: 135 additions & 21 deletions src/dags/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ def concatenate_functions(
Functions that are not required to produce the targets will simply be ignored.
The arguments of the combined function are all arguments of relevant functions
that are not themselves function names, in alphabetical order.
The arguments of the combined function are all arguments of relevant functions that
are not themselves function names, in alphabetical order.
Args:
functions (dict or list): Dict or list of functions. If a list, the function
name is inferred from the __name__ attribute of the entries. If a dict,
the name of the function is set to the dictionary key.
targets (str | None): Name of the function that produces the target or list of
such function names. If the value is `None`, all variables are returned.
name is inferred from the __name__ attribute of the entries. If a dict, the
name of the function is set to the dictionary key.
targets (str or list or None): Name of the function that produces the target or
list of such function names. If the value is `None`, all variables are
returned.
return_type (str): One of "tuple", "list", "dict". This is ignored if the
targets are a single string or if an aggregator is provided.
aggregator (callable or None): Binary reduction function that is used to
Expand All @@ -45,19 +46,99 @@ def concatenate_functions(
function: A function that produces targets when called with suitable arguments.
"""
_functions = _harmonize_functions(functions)
_targets = _harmonize_targets(targets, list(_functions))
_fail_if_targets_have_wrong_types(_targets)
_fail_if_functions_are_missing(_functions, _targets)

# Create the DAG.
dag = create_dag(functions, targets)

# Build combined function.
out = _create_combined_function_from_dag(
dag, functions, targets, return_type, aggregator, enforce_signature
)

return out


def create_dag(functions, targets):
"""Build a directed acyclic graph (DAG) from functions.
Functions can depend on the output of other functions as inputs, as long as the
dependencies can be described by a directed acyclic graph (DAG).
Functions that are not required to produce the targets will simply be ignored.
Args:
functions (dict or list): Dict or list of functions. If a list, the function
name is inferred from the __name__ attribute of the entries. If a dict, the
name of the function is set to the dictionary key.
targets (str or list or None): Name of the function that produces the target or
list of such function names. If the value is `None`, all variables are
returned.
Returns:
dag: the DAG (as networkx.DiGraph object)
"""
# Harmonize and check arguments.
_functions, _targets = _harmonize_and_check_functions_and_targets(
functions, targets
)

# Create the DAG
_raw_dag = _create_complete_dag(_functions)
_dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets)
_arglist = _create_arguments_of_concatenated_function(_functions, _dag)
_exec_info = _create_execution_info(_functions, _dag)
dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets)

# Check if there are cycles in the DAG
_fail_if_dag_contains_cycle(dag)

return dag


def _create_combined_function_from_dag(
dag,
functions,
targets,
return_type="tuple",
aggregator=None,
enforce_signature=True,
):
"""Create combined function which allows to execute a complete directed acyclic
graph (DAG) in one function call.
The arguments of the combined function are all arguments of relevant functions that
are not themselves function names, in alphabetical order.
Args:
dag (networkx.DiGraph): a DAG of functions
functions (dict or list): Dict or list of functions. If a list, the function
name is inferred from the __name__ attribute of the entries. If a dict, the
name of the function is set to the dictionary key.
targets (str or list or None): Name of the function that produces the target or
list of such function names. If the value is `None`, all variables are
returned.
return_type (str): One of "tuple", "list", "dict". This is ignored if the
targets are a single string or if an aggregator is provided.
aggregator (callable or None): Binary reduction function that is used to
aggregate the targets into a single target.
enforce_signature (bool): If True, the signature of the concatenated function
is enforced. Otherwise it is only provided for introspection purposes.
Enforcing the signature has a small runtime overhead.
Returns:
function: A function that produces targets when called with suitable arguments.
"""
# Harmonize and check arguments.
_functions, _targets = _harmonize_and_check_functions_and_targets(
functions, targets
)

_arglist = _create_arguments_of_concatenated_function(_functions, dag)
_exec_info = _create_execution_info(_functions, dag)
_concatenated = _create_concatenated_function(
_exec_info, _arglist, _targets, enforce_signature
)

# Return function in specified format.
if isinstance(targets, str) or (aggregator is not None and len(_targets) == 1):
out = single_output(_concatenated)
elif aggregator is not None:
Expand All @@ -70,7 +151,7 @@ def concatenate_functions(
out = dict_output(_concatenated, keys=_targets)
else:
raise ValueError(
f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. "
f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. "
f"You provided {return_type}."
)

Expand All @@ -91,13 +172,14 @@ def get_ancestors(functions, targets, include_targets=False):
set: The ancestors
"""
_functions = _harmonize_functions(functions)
_targets = _harmonize_targets(targets, list(_functions))
_fail_if_targets_have_wrong_types(_targets)
_fail_if_functions_are_missing(_functions, _targets)

raw_dag = _create_complete_dag(_functions)
dag = _limit_dag_to_targets_and_their_ancestors(raw_dag, _targets)
# Harmonize and check arguments.
_functions, _targets = _harmonize_and_check_functions_and_targets(
functions, targets
)

# Create the DAG.
dag = create_dag(functions, targets)

ancestors = set()
for target in _targets:
Expand All @@ -107,6 +189,29 @@ def get_ancestors(functions, targets, include_targets=False):
return ancestors


def _harmonize_and_check_functions_and_targets(functions, targets):
"""Harmonize the type of specified functions and targets and do some checks.
Args:
functions (dict or list): Dict or list of functions. If a list, the function
name is inferred from the __name__ attribute of the entries. If a dict, the
name of the function is set to the dictionary key.
targets (str or list): Name of the function that produces the target or list of
such function names.
Returns:
functions_harmonized: harmonized functions
targets_harmonized: harmonized targets
"""
functions_harmonized = _harmonize_functions(functions)
targets_harmonized = _harmonize_targets(targets, list(functions_harmonized))
_fail_if_targets_have_wrong_types(targets_harmonized)
_fail_if_functions_are_missing(functions_harmonized, targets_harmonized)

return functions_harmonized, targets_harmonized


def _harmonize_functions(functions):
if isinstance(functions, (list, tuple)):
functions = {func.__name__: func for func in functions}
Expand Down Expand Up @@ -141,6 +246,15 @@ def _fail_if_functions_are_missing(functions, targets):
return functions, targets


def _fail_if_dag_contains_cycle(dag):
"""Check for cycles in DAG"""
cycles = list(nx.simple_cycles(dag))

if len(cycles) > 0:
formatted = _format_list_linewise(cycles)
raise ValueError(f"The DAG contains one or more cycles:\n{formatted}")


def _create_complete_dag(functions):
"""Create the complete DAG.
Expand Down Expand Up @@ -275,7 +389,7 @@ def concatenated(*args, **kwargs):


def _format_list_linewise(list_):
formatted_list = '",\n "'.join(list_)
formatted_list = '",\n "'.join([str(c) for c in list_])
return textwrap.dedent(
"""
[
Expand Down
35 changes: 35 additions & 0 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from dags.dag import concatenate_functions
from dags.dag import create_dag
from dags.dag import get_ancestors


Expand All @@ -22,6 +23,14 @@ def _unrelated(working_hours): # noqa: U100
raise NotImplementedError()


def _leisure_cycle(working_hours, _utility):
return 24 - working_hours + _utility


def _consumption_cycle(working_hours, wage, _utility):
return wage * working_hours + _utility


def _complete_utility(wage, working_hours, leisure_weight):
"""The function that we try to generate dynamically."""
leis = _leisure(working_hours)
Expand Down Expand Up @@ -157,3 +166,29 @@ def g(f, d):

assert list(inspect.signature(concatenated).parameters) == ["c", "d"]
assert concatenated(3, 4) == 10


@pytest.mark.parametrize(
"funcs",
[
{
"_utility": _utility,
"_leisure": _leisure,
"_consumption": _consumption_cycle,
},
{
"_utility": _utility,
"_leisure": _leisure_cycle,
"_consumption": _consumption_cycle,
},
],
)
def test_fail_if_cycle_in_dag(funcs):
with pytest.raises(
ValueError,
match="The DAG contains one or more cycles:",
):
create_dag(
functions=funcs,
targets=["_utility"],
)

0 comments on commit ab7f8a2

Please sign in to comment.