Skip to content

Commit 8826787

Browse files
Implemented transformation-buffer
commit-id:2fbffaa2
1 parent ac28934 commit 8826787

25 files changed

+1760
-19
lines changed

packages/transformation-buffer/pyproject.toml

+20-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,28 @@ name = "transformation-buffer"
33
version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
6-
authors = [
7-
{ name = "Jan Smółka", email = "[email protected]" }
8-
]
6+
authors = [{ name = "Jan Smółka", email = "[email protected]" }]
97
requires-python = ">=3.12.7"
10-
dependencies = []
8+
dependencies = [
9+
"jaxtyping>=0.2.37",
10+
"kornia>=0.8.0",
11+
"more-itertools>=10.6.0",
12+
"plum-dispatch>=2.5.7",
13+
"pyserde[numpy]>=0.23.0",
14+
"rustworkx>=0.16.0",
15+
]
16+
17+
[dependency-groups]
18+
dev = [
19+
"hypothesis>=6.125.3",
20+
"pytest>=8.3.4",
21+
"syrupy>=4.8.1",
22+
]
1123

1224
[build-system]
1325
requires = ["hatchling"]
1426
build-backend = "hatchling.build"
27+
28+
[tool.pytest.ini_options]
29+
python_files = "tests/**/*.py"
30+
python_functions = "test_*"
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
from .buffer import Buffer
2+
from .transformation import Transformation
13

4+
__all__ = ['Buffer', 'Transformation']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from collections.abc import Callable, Hashable, Iterable
2+
from copy import deepcopy
3+
from functools import reduce
4+
from typing import Any, Self
5+
6+
import serde
7+
from more_itertools import pairwise
8+
from plum import dispatch, overload
9+
from rustworkx import (
10+
PyDiGraph,
11+
digraph_dijkstra_search,
12+
digraph_dijkstra_shortest_paths,
13+
is_connected,
14+
)
15+
from rustworkx.visit import DijkstraVisitor
16+
17+
from transformation_buffer.rigid_model import RigidModel
18+
from transformation_buffer.serialization import graph, tensor
19+
from transformation_buffer.transformation import Transformation
20+
21+
# Initialize custom global class serializers & deserializers.
22+
graph.init()
23+
tensor.init()
24+
25+
# TODO: Implement path compression and verification.
26+
27+
28+
@serde.serde
29+
class Buffer[T: Hashable]:
30+
__frames_of_reference: set[T]
31+
__frame_names_to_node_indices: dict[T, int]
32+
__connections: PyDiGraph[T, Transformation] = serde.field(
33+
serializer=graph.serialize_graph,
34+
deserializer=graph.deserialize_graph,
35+
)
36+
37+
@overload
38+
def __init__(self, frames_of_reference: Iterable[T] | None = None) -> None: # noqa: F811
39+
nodes = (
40+
set() if frames_of_reference is None else deepcopy(set(frames_of_reference))
41+
)
42+
n_nodes = len(nodes)
43+
44+
connections = PyDiGraph[T, Transformation](
45+
multigraph=False, # Think about it later
46+
node_count_hint=n_nodes,
47+
edge_count_hint=n_nodes,
48+
)
49+
50+
node_indices = connections.add_nodes_from(nodes)
51+
52+
self.__frames_of_reference = nodes
53+
self.__frame_names_to_node_indices = dict(zip(nodes, node_indices))
54+
self.__connections = connections
55+
56+
@overload
57+
def __init__( # noqa: F811
58+
self,
59+
frames_of_reference: set[T],
60+
frame_names_to_node_indices: dict[T, int],
61+
connections: PyDiGraph[T, Transformation],
62+
) -> None:
63+
self.__frames_of_reference = frames_of_reference
64+
self.__frame_names_to_node_indices = frame_names_to_node_indices
65+
self.__connections = connections
66+
67+
@dispatch
68+
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: F811
69+
pass
70+
71+
@property
72+
def connected(self) -> bool:
73+
return is_connected(self.__connections.to_undirected(multigraph=False))
74+
75+
@property
76+
def frames_of_reference(self) -> frozenset[T]:
77+
# Construction of frozenset from set of strings is *somehow* optimized.
78+
# Perhaps there's no need for caching this property.
79+
return frozenset(self.__frames_of_reference)
80+
81+
def add_object(self, model: RigidModel[T]) -> Self:
82+
for from_to, transformation in model.transformations().items():
83+
self[from_to] = transformation
84+
85+
return self
86+
87+
def add_frame_of_reference(self, frame: T) -> Self:
88+
self.__add_frame_of_reference(frame)
89+
return self
90+
91+
def add_transformation(
92+
self,
93+
from_frame: T,
94+
to_frame: T,
95+
transformation: Transformation,
96+
) -> Self:
97+
if self[from_frame, to_frame] is not None:
98+
return self
99+
100+
self[from_frame, to_frame] = transformation
101+
102+
return self
103+
104+
def frames_visible_from(self, frame: T) -> list[T]:
105+
if (frame_index := self.__frame_index(frame)) is None:
106+
return []
107+
108+
class Visitor(DijkstraVisitor):
109+
visible_nodes: list[int] = []
110+
111+
def discover_vertex(self, v: int, score: float) -> None:
112+
self.visible_nodes.append(v)
113+
114+
visitor = Visitor()
115+
digraph_dijkstra_search(self.__connections, [frame_index], visitor=visitor)
116+
117+
visible_node_indices = visitor.visible_nodes
118+
visible_node_indices.remove(frame_index)
119+
120+
return [self.__connections.get_node_data(node) for node in visible_node_indices]
121+
122+
def __add_frame_of_reference(self, frame: T) -> int:
123+
if (index := self.__frame_index(frame)) is not None:
124+
return index
125+
126+
self.__frames_of_reference.add(frame)
127+
index = self.__connections.add_node(frame)
128+
self.__frame_names_to_node_indices[frame] = index
129+
130+
return index
131+
132+
def __frame_index(self, frame: T) -> int | None:
133+
return self.__frame_names_to_node_indices.get(frame, None)
134+
135+
def __setitem__(self, from_to: tuple[T, T], transformation: Transformation) -> None:
136+
from_id, to_id = from_to
137+
138+
from_index = self.__frame_index(from_id) or self.__add_frame_of_reference(from_id)
139+
to_index = self.__frame_index(to_id) or self.__add_frame_of_reference(to_id)
140+
141+
self.__connections.add_edge(from_index, to_index, transformation)
142+
self.__connections.add_edge(to_index, from_index, transformation.inverse())
143+
144+
def __getitem__(self, from_to: tuple[T, T]) -> Transformation | None:
145+
from_id, to_id = from_to
146+
147+
from_index = self.__frame_index(from_id)
148+
to_index = self.__frame_index(to_id)
149+
150+
if from_index is None or to_index is None:
151+
return None
152+
153+
if from_index == to_index:
154+
return Transformation.identity().clone()
155+
156+
connections = self.__connections
157+
158+
if self.__connections.has_edge(from_index, to_index):
159+
return self.__connections.get_edge_data(from_index, to_index)
160+
161+
path_mapping = digraph_dijkstra_shortest_paths(
162+
connections,
163+
from_index,
164+
to_index,
165+
)
166+
167+
if to_index not in path_mapping:
168+
return None
169+
170+
shortest_path = path_mapping[to_index]
171+
172+
return map_reduce( # type: ignore[no-any-return] # MyPy complains about the possible `Any`.
173+
lambda nodes: connections.get_edge_data(*nodes),
174+
lambda t1, t2: t1 @ t2,
175+
pairwise(shortest_path),
176+
)
177+
178+
def __contains__(self, from_to: tuple[T, T]) -> bool:
179+
from_id, to_id = from_to
180+
181+
from_index = self.__frame_index(from_id)
182+
to_index = self.__frame_index(to_id)
183+
184+
if from_index is None or to_index is None:
185+
return False
186+
187+
return self.__connections.has_edge(from_index, to_index)
188+
189+
190+
def map_reduce[U, V](
191+
map_function: Callable[[U], V],
192+
reduce_function: Callable[[V, V], V],
193+
iterable: Iterable[U],
194+
) -> V:
195+
return reduce(reduce_function, map(map_function, iterable))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from math import pi
2+
from typing import Hashable, Protocol, runtime_checkable
3+
4+
import torch
5+
from attrs import frozen
6+
7+
from transformation_buffer.transformation import Transformation
8+
9+
10+
@runtime_checkable
11+
class RigidModel[T: Hashable](Protocol):
12+
def transformations(self) -> dict[tuple[T, T], Transformation]: ...
13+
14+
15+
@frozen
16+
class Cube[T: Hashable]:
17+
width: float
18+
walls: tuple[T, T, T, T, T, T]
19+
20+
# Cube Walls:
21+
# 4
22+
# 5 1 6
23+
# 2
24+
# 3
25+
26+
def transformations(self) -> dict[tuple[T, T], Transformation]:
27+
first, second, third, forth, fifth, sixth = self.walls
28+
29+
half_pi = pi / 2.0
30+
half_width = self.width / 2.0
31+
32+
dtype = torch.float64
33+
34+
down = Transformation.active(
35+
intrinsic_euler_angles=torch.tensor((-half_pi, 0.0, 0.0)),
36+
translation=torch.tensor((0.0, half_width, -half_width)),
37+
).to(dtype=dtype)
38+
39+
left = Transformation.active(
40+
intrinsic_euler_angles=torch.tensor((0.0, half_pi, 0.0)),
41+
translation=torch.tensor((half_width, 0.0, -half_width)),
42+
).to(dtype=dtype)
43+
44+
right = Transformation.active(
45+
intrinsic_euler_angles=torch.tensor((0.0, -half_pi, 0.0)),
46+
translation=torch.tensor((-half_width, 0.0, -half_width)),
47+
).to(dtype=dtype)
48+
49+
return {
50+
(first, second): down.clone(),
51+
(second, third): down.clone(),
52+
(third, forth): down,
53+
(first, fifth): left,
54+
(first, sixth): right,
55+
}

packages/transformation-buffer/src/transformation_buffer/serialization/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import Any
2+
3+
import serde
4+
from plum import dispatch
5+
from rustworkx import PyDiGraph
6+
7+
from transformation_buffer.transformation import Transformation
8+
9+
10+
def init() -> None:
11+
serde.add_serializer(Serializer())
12+
serde.add_deserializer(Deserializer())
13+
14+
15+
class Serializer:
16+
@dispatch
17+
def serialize(self, value: PyDiGraph) -> dict[str, Any]:
18+
return serialize_graph(value)
19+
20+
21+
class Deserializer:
22+
@dispatch
23+
def deserialize(self, cls: type[PyDiGraph], value: Any) -> PyDiGraph:
24+
return deserialize_graph(value)
25+
26+
27+
def serialize_graph(graph: PyDiGraph) -> dict[str, Any]:
28+
nodes = [(index, graph.get_node_data(index)) for index in graph.node_indices()]
29+
edges = [
30+
(from_to, serde.to_dict(graph.get_edge_data(*from_to)))
31+
for from_to in graph.edge_list()
32+
]
33+
34+
return {
35+
'check_cycle': graph.check_cycle,
36+
'multigraph': graph.multigraph,
37+
'nodes': nodes,
38+
'edges': edges,
39+
}
40+
41+
42+
# INVARIANT: Properly deserializes only graphs with Transformations as edge poayloads.
43+
def deserialize_graph(value: Any) -> PyDiGraph:
44+
match value:
45+
case {
46+
'check_cycle': bool(check_cycle),
47+
'multigraph': bool(multigraph),
48+
'nodes': list(nodes),
49+
'edges': list(edges),
50+
}:
51+
graph = PyDiGraph(
52+
check_cycle,
53+
multigraph,
54+
node_count_hint=len(nodes),
55+
edge_count_hint=len(edges),
56+
)
57+
58+
old_node_indices_to_new = {}
59+
60+
for old_index, node in sorted(nodes, key=lambda entry: entry[0]):
61+
new_index = graph.add_node(node)
62+
old_node_indices_to_new[old_index] = new_index
63+
64+
for (old_a, old_b), edge in edges:
65+
new_a = old_node_indices_to_new[old_a]
66+
new_b = old_node_indices_to_new[old_b]
67+
68+
# TODO: Unhack
69+
graph.add_edge(new_a, new_b, serde.from_dict(Transformation, edge))
70+
71+
return graph
72+
73+
case _:
74+
raise serde.SerdeError('')

0 commit comments

Comments
 (0)