diff --git a/dag_example_module.png b/dag_example_module.png index 5351fb719..6bacf4647 100644 Binary files a/dag_example_module.png and b/dag_example_module.png differ diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py index 63d3a8bf6..1c79ce6e3 100644 --- a/hamilton/async_driver.py +++ b/hamilton/async_driver.py @@ -5,7 +5,7 @@ import time import typing import uuid -from types import ModuleType +from types import FunctionType, ModuleType from typing import Any, Dict, Optional, Tuple import hamilton.lifecycle.base as lifecycle_base @@ -199,6 +199,7 @@ def __init__( result_builder: Optional[base.ResultMixin] = None, adapters: typing.List[lifecycle.LifecycleAdapter] = None, allow_module_overrides: bool = False, + functions: typing.List[FunctionType] = None, ): """Instantiates an asynchronous driver. @@ -249,6 +250,7 @@ def __init__( *async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later ], allow_module_overrides=allow_module_overrides, + functions=functions, ) self.initialized = False diff --git a/hamilton/driver.py b/hamilton/driver.py index 868764e58..0646673dc 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -13,7 +13,7 @@ import typing import uuid from datetime import datetime -from types import ModuleType +from types import FunctionType, ModuleType from typing import ( Any, Callable, @@ -402,6 +402,7 @@ def __init__( self, config: Dict[str, Any], *modules: ModuleType, + functions: List[FunctionType] = None, adapter: Optional[ Union[lifecycle_base.LifecycleAdapter, List[lifecycle_base.LifecycleAdapter]] ] = None, @@ -435,13 +436,15 @@ def __init__( if adapter.does_hook("pre_do_anything", is_async=False): adapter.call_all_lifecycle_hooks_sync("pre_do_anything") error = None + self.graph_functions = functions if functions is not None else [] self.graph_modules = modules try: - self.graph = graph.FunctionGraph.from_modules( - *modules, + self.graph = graph.FunctionGraph.compile( + modules=list(modules), + functions=functions if functions is not None else [], config=config, adapter=adapter, - allow_module_overrides=allow_module_overrides, + allow_node_overrides=allow_module_overrides, ) if _materializers: materializer_factories, extractor_factories = self._process_materializers( @@ -1866,6 +1869,7 @@ def __init__(self): # common fields self.config = {} self.modules = [] + self.functions = [] self.materializers = [] # Allow later modules to override nodes of the same name @@ -1927,6 +1931,17 @@ def with_modules(self, *modules: ModuleType) -> "Builder": self.modules.extend(modules) return self + def with_functions(self, *functions: FunctionType) -> "Builder": + """Adds the specified functions to the list. + This can be called multiple times. If you have allow_module_overrides + set this will enabl overwriting modules or previously added functions. + + :param functions: + :return: self + """ + self.functions.extend(functions) + return self + def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder": """Sets the adapter to use. @@ -2168,6 +2183,7 @@ def build(self) -> Driver: _graph_executor=graph_executor, _use_legacy_adapter=False, allow_module_overrides=self._allow_module_overrides, + functions=self.functions, ) def copy(self) -> "Builder": diff --git a/hamilton/graph.py b/hamilton/graph.py index 43ccd24ca..68ae94d5f 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -13,7 +13,7 @@ import pathlib import uuid from enum import Enum -from types import ModuleType +from types import FunctionType, ModuleType from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type import hamilton.lifecycle.base as lifecycle_base @@ -142,17 +142,18 @@ def update_dependencies( return nodes -def create_function_graph( +def compile_to_nodes( *functions: List[Tuple[str, Callable]], config: Dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, fg: Optional["FunctionGraph"] = None, - allow_module_overrides: bool = False, + allow_node_level_overrides: bool = False, ) -> Dict[str, node.Node]: """Creates a graph of all available functions & their dependencies. :param modules: A set of modules over which one wants to compute the function graph :param config: Dictionary that we will inspect to get values from in building the function graph. :param adapter: The adapter that adapts our node type checking based on the context. + :param allow_node_level_overrides: Whether or not to allow node names to override each other :return: list of nodes in the graph. If it needs to be more complicated, we'll return an actual networkx graph and get all the rest of the logic for free """ @@ -170,7 +171,7 @@ def create_function_graph( for n in fm_base.resolve_nodes(f, config): if n.name in config: continue # This makes sure we overwrite things if they're in the config... - if n.name in nodes and not allow_module_overrides: + if n.name in nodes and not allow_node_level_overrides: raise ValueError( f"Cannot define function {n.name} more than once." f" Already defined by function {f}" @@ -713,13 +714,42 @@ def __init__( self.nodes = nodes self.adapter = adapter + @staticmethod + def compile( + modules: List[ModuleType], + functions: List[FunctionType], + config: Dict[str, Any], + adapter: lifecycle_base.LifecycleAdapterSet = None, + allow_node_overrides: bool = False, + ) -> "FunctionGraph": + """Base level static function for compiling a function graph. Note + that this can both use functions (E.G. passing them directly) and modules + (passing them in and crawling. + + :param modules: Modules to use + :param functions: Functions to use + :param config: Config to use for setting up the DAG + :param adapter: Adapter to use for node resolution + :param allow_node_overrides: Whether or not to allow node level overrides. + :return: The compiled function graph + """ + module_functions = sum([find_functions(module) for module in modules], []) + nodes = compile_to_nodes( + *module_functions, + *functions, + config=config, + adapter=adapter, + allow_node_level_overrides=allow_node_overrides, + ) + return FunctionGraph(nodes, config, adapter) + @staticmethod def from_modules( *modules: ModuleType, config: Dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, allow_module_overrides: bool = False, - ): + ) -> "FunctionGraph": """Initializes a function graph from the specified modules. Note that this was the old way we constructed FunctionGraph -- this is not a public-facing API, so we replaced it with a constructor that takes in nodes directly. If you hacked in something using @@ -732,28 +762,28 @@ def from_modules( :return: a function graph. """ - functions = sum([find_functions(module) for module in modules], []) - return FunctionGraph.from_functions( - *functions, + return FunctionGraph.compile( + modules=modules, + functions=[], config=config, adapter=adapter, - allow_module_overrides=allow_module_overrides, + allow_node_overrides=allow_module_overrides, ) @staticmethod def from_functions( - *functions, + *functions: FunctionType, config: Dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, allow_module_overrides: bool = False, ) -> "FunctionGraph": - nodes = create_function_graph( - *functions, + return FunctionGraph.compile( + modules=[], + functions=functions, config=config, adapter=adapter, - allow_module_overrides=allow_module_overrides, + allow_node_overrides=allow_module_overrides, ) - return FunctionGraph(nodes, config, adapter) def with_nodes(self, nodes: Dict[str, Node]) -> "FunctionGraph": """Creates a new function graph with the additional specified nodes. diff --git a/pyproject.toml b/pyproject.toml index 24f8da0e9..e7945e15d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,7 @@ docs = [ "diskcache", # required for all the plugins "dlt", - # furo -- install from main for now until the next release is out: - "furo @ git+https://github.com/pradyunsg/furo@main", + "furo", "gitpython", # Required for parsing git info for generation of data-adapter docs "grpcio-status", "lightgbm",