Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions packages/ordeq/src/ordeq/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from itertools import chain
from typing import Generic, TypeVar, cast

from ordeq._io import AnyIO, IOIdentity, _is_io
from ordeq._io import AnyIO, _is_io
from ordeq._nodes import Node, _is_view
from ordeq._resource import Resource

Expand Down Expand Up @@ -181,29 +181,25 @@ def __repr__(self) -> str:

# TODO: remove entire class
@dataclass(frozen=True)
class NodeIOGraph(Graph[IOIdentity | Node]):
edges: dict[IOIdentity | Node, list[IOIdentity | Node]]
ios: dict[IOIdentity, AnyIO]
class NodeIOGraph(Graph[AnyIO | Node]):
edges: dict[AnyIO | Node, list[AnyIO | Node]]
ios: dict[AnyIO, AnyIO]

@classmethod
def from_nodes(cls, nodes: Sequence[Node]) -> Self:
return cls.from_graph(NodeGraph.from_nodes(nodes))

@classmethod
def from_graph(cls, base: NodeGraph) -> Self:
edges: dict[IOIdentity | Node, list[IOIdentity | Node]] = defaultdict(
list
)
ios: dict[IOIdentity, AnyIO] = {}
edges: dict[AnyIO | Node, list[AnyIO | Node]] = defaultdict(list)
ios: dict[AnyIO, AnyIO] = {}
for node in base.topological_ordering:
for input_ in node.inputs:
input_id = id(input_)
ios[input_id] = input_
edges[input_id].append(node)
ios[input_] = input_
edges[input_].append(node)
for output in node.outputs:
output_id = id(output)
ios[output_id] = output
edges[node].append(output_id)
ios[output] = output
edges[node].append(output)
return cls(edges=edges, ios=ios)

@cached_property
Expand All @@ -214,7 +210,7 @@ def __repr__(self) -> str:
# Hacky way to generate a deterministic repr of this class.
# This should move to a separate named graph class.
lines: list[str] = []
names: dict[IOIdentity | Node, str] = {
names: dict[AnyIO | Node, str] = {
**{node: f"{node.type_name}:{node.ref}" for node in self.nodes},
**{
io: f"io-{i}"
Expand Down
5 changes: 0 additions & 5 deletions packages/ordeq/src/ordeq/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,11 +747,6 @@ def __repr__(self):
# Type aliases
AnyIO: TypeAlias = Input[T] | Output[T]

# Type alias for IO identity retrieved using id(). This is used to uniquely
# identify IO instances. We cannot rely on the __eq__ and __hash__ of IO
# objects, as they may be overridden by the user.
IOIdentity: TypeAlias = Annotated[int, "Identity of an IO object"]


def _is_input(obj: object) -> TypeGuard[Input]:
return isinstance(obj, Input)
Expand Down
11 changes: 5 additions & 6 deletions packages/ordeq/src/ordeq/_process_ios.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
def _assign_io_fqns(*nodes: Node, io_fqns: IOFQNs) -> None:
for node in nodes:
for io in chain(node.inputs, node.outputs):
io_id = id(io)
if io_id in io_fqns:
if len(io_fqns[io_id]) == 1:
io._set_fqn(io_fqns[io_id][0]) # type: ignore[attr-defined]
elif len(io_fqns[io_id]) > 1:
io._set_name(io_fqns[io_id][0].name) # type: ignore[attr-defined]
if io in io_fqns:
if len(io_fqns[io]) == 1: # type: ignore[index]
io._set_fqn(io_fqns[io][0]) # type: ignore[index]
elif len(io_fqns[io]) > 1: # type: ignore[index]
io._set_name(io_fqns[io][0].name) # type: ignore[index]


def _process_ios(
Expand Down
14 changes: 6 additions & 8 deletions packages/ordeq/src/ordeq/_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ordeq._fqn import FQN, ModuleName, ObjectRef, is_object_ref
from ordeq._hook import NodeHook, RunHook, RunnerHook
from ordeq._io import AnyIO, IOIdentity, _is_io, _is_io_sequence
from ordeq._io import AnyIO, _is_io, _is_io_sequence
from ordeq._nodes import Node, _is_node

RunnableRef: TypeAlias = ObjectRef | ModuleName
Expand Down Expand Up @@ -164,19 +164,17 @@ def _resolve_refs_to_modules(


def _resolve_module_to_ios(module: ModuleType) -> dict[str, AnyIO]:
ios: dict[IOIdentity, tuple[AnyIO, str]] = {}
ios: dict[AnyIO, str] = {}
for name, obj in vars(module).items():
if _is_io(obj):
io_id = id(obj)
# TODO: Should also resolve to IO sequence
if io_id in ios:
alias = ios[io_id][1]
if obj in ios:
raise ValueError(
f"Module '{module.__name__}' contains duplicate keys "
f"for the same IO ('{name}' and '{alias}')"
f"for the same IO ('{name}' and '{ios[obj]}')"
)
ios[io_id] = (obj, name)
return {name: io for io, name in ios.values()}
ios[obj] = name
return {name: io for io, name in ios.items()}


def _resolve_package_to_ios(package: ModuleType) -> Catalog:
Expand Down
11 changes: 5 additions & 6 deletions packages/ordeq/src/ordeq/_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import TypeAlias

from ordeq._fqn import FQN
from ordeq._io import IOIdentity, _is_io
from ordeq._io import AnyIO, _is_io
from ordeq._nodes import Node, _is_node

NodeFQNs: TypeAlias = dict[Node, list[FQN]]
IOFQNs: TypeAlias = dict[IOIdentity, list[FQN]]
IOFQNs: TypeAlias = dict[AnyIO, list[FQN]]


def _scan_fqns(*modules: ModuleType) -> tuple[NodeFQNs, IOFQNs]:
Expand All @@ -16,16 +16,15 @@ def _scan_fqns(*modules: ModuleType) -> tuple[NodeFQNs, IOFQNs]:
for module in modules:
for name, obj in vars(module).items():
if _is_io(obj):
io_id = id(obj)
if io_id in io_fqns:
existing_fqn = io_fqns[io_id][0]
if obj in io_fqns:
existing_fqn = io_fqns[obj][0]
if name != existing_fqn.name:
raise ValueError(
f"Module '{module.__name__}' aliases IO "
f"'{existing_fqn.ref}' to '{name}'. "
f"IOs cannot be aliased."
)
io_fqns[io_id].append(FQN(module.__name__, name))
io_fqns[obj].append(FQN(module.__name__, name))
elif _is_node(obj):
if obj in node_fqns:
existing = node_fqns[obj][0]
Expand Down