Skip to content

Commit a0e4e48

Browse files
Implemented transformation-buffer
commit-id:2fbffaa2
1 parent 5695ba6 commit a0e4e48

File tree

13 files changed

+785
-19
lines changed

13 files changed

+785
-19
lines changed

packages/transformation-buffer/pyproject.toml

+19-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,27 @@ name = "transformation-buffer"
33
version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
6-
authors = [
7-
{ name = "Jan Smółka", email = "[email protected]" }
8-
]
6+
authors = [{ name = "Jan Smółka", email = "[email protected]" }]
97
requires-python = ">=3.12.7"
10-
dependencies = []
8+
dependencies = [
9+
"jaxtyping>=0.2.37",
10+
"kornia>=0.8.0",
11+
"more-itertools>=10.6.0",
12+
"pyserde[numpy]>=0.23.0",
13+
"rustworkx>=0.16.0",
14+
]
15+
16+
[dependency-groups]
17+
dev = [
18+
"hypothesis>=6.125.3",
19+
"pytest>=8.3.4",
20+
"syrupy>=4.8.1",
21+
]
1122

1223
[build-system]
1324
requires = ["hatchling"]
1425
build-backend = "hatchling.build"
26+
27+
[tool.pytest.ini_options]
28+
python_files = "tests/**/*.py"
29+
python_functions = "test_*"
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
from .buffer import Buffer
2+
from .transformation import Transformation
13

4+
__all__ = ['Buffer', 'Transformation']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from collections.abc import Callable, Hashable, Iterable
2+
from copy import deepcopy
3+
from functools import reduce
4+
from typing import Self
5+
6+
from more_itertools import pairwise
7+
from rustworkx import PyDiGraph, digraph_dijkstra_shortest_paths, is_connected
8+
9+
from transformation_buffer.rigid_model import RigidModel
10+
from transformation_buffer.transformation import Transformation
11+
12+
# TODO: Implement `pyserde` serialization.
13+
# TODO: Implement path compression and verification.
14+
15+
16+
class Buffer[T: Hashable]:
17+
__frames_of_reference: set[T]
18+
__frame_names_to_node_indices: dict[T, int]
19+
__connections: PyDiGraph[T, Transformation]
20+
21+
def __init__(self, frames_of_reference: Iterable[T] | None = None) -> None:
22+
nodes = (
23+
set() if frames_of_reference is None else deepcopy(set(frames_of_reference))
24+
)
25+
n_nodes = len(nodes)
26+
27+
connections = PyDiGraph[T, Transformation](
28+
multigraph=False, # Think about it later
29+
node_count_hint=n_nodes,
30+
edge_count_hint=n_nodes,
31+
)
32+
33+
node_indices = connections.add_nodes_from(nodes)
34+
35+
self.__frames_of_reference = nodes
36+
self.__frame_names_to_node_indices = dict(zip(nodes, node_indices))
37+
self.__connections = connections
38+
39+
@property
40+
def connected(self) -> bool:
41+
return is_connected(self.__connections.to_undirected(multigraph=False))
42+
43+
@property
44+
def frames_of_reference(self) -> frozenset[T]:
45+
# Construction of frozenset from set of strings is *somehow* optimized.
46+
# Perhaps there's no need for caching this property.
47+
return frozenset(self.__frames_of_reference)
48+
49+
def add_object(self, model: RigidModel[T]) -> Self:
50+
for from_to, transformation in model.transformations().items():
51+
self[from_to] = transformation
52+
53+
return self
54+
55+
def add_frame_of_reference(self, frame: T) -> Self:
56+
self.__add_frame_of_reference(frame)
57+
return self
58+
59+
def add_transformation(
60+
self,
61+
from_frame: T,
62+
to_frame: T,
63+
transformation: Transformation,
64+
) -> Self:
65+
return self
66+
67+
def __add_frame_of_reference(self, frame: T) -> int:
68+
if (index := self.__frame_index(frame)) is not None:
69+
return index
70+
71+
self.__frames_of_reference.add(frame)
72+
index = self.__connections.add_node(frame)
73+
self.__frame_names_to_node_indices[frame] = index
74+
75+
return index
76+
77+
def __frame_index(self, frame: T) -> int | None:
78+
return self.__frame_names_to_node_indices.get(frame, None)
79+
80+
def __setitem__(self, from_to: tuple[T, T], transformation: Transformation) -> None:
81+
from_id, to_id = from_to
82+
83+
from_index = self.__frame_index(from_id) or self.__add_frame_of_reference(from_id)
84+
to_index = self.__frame_index(to_id) or self.__add_frame_of_reference(to_id)
85+
86+
self.__connections.add_edge(from_index, to_index, transformation)
87+
self.__connections.add_edge(to_index, from_index, transformation.inverse())
88+
89+
def __getitem__(self, from_to: tuple[T, T]) -> Transformation | None:
90+
from_id, to_id = from_to
91+
92+
from_index = self.__frame_index(from_id)
93+
to_index = self.__frame_index(to_id)
94+
95+
if from_index is None or to_index is None:
96+
return None
97+
98+
if from_index == to_index:
99+
return Transformation.identity().clone()
100+
101+
connections = self.__connections
102+
103+
if self.__connections.has_edge(from_index, to_index):
104+
return self.__connections.get_edge_data(from_index, to_index)
105+
106+
path_mapping = digraph_dijkstra_shortest_paths(
107+
connections,
108+
from_index,
109+
to_index,
110+
)
111+
112+
if to_index not in path_mapping:
113+
return None
114+
115+
shortest_path = path_mapping[to_index]
116+
117+
return map_reduce( # type: ignore[no-any-return] # MyPy complains about the possible `Any`.
118+
lambda nodes: connections.get_edge_data(*nodes),
119+
lambda t1, t2: t1 @ t2,
120+
pairwise(shortest_path),
121+
)
122+
123+
def __contains__(self, from_to: tuple[T, T]) -> bool:
124+
from_id, to_id = from_to
125+
126+
from_index = self.__frame_index(from_id)
127+
to_index = self.__frame_index(to_id)
128+
129+
if from_index is None or to_index is None:
130+
return False
131+
132+
return self.__connections.has_edge(from_index, to_index)
133+
134+
135+
def map_reduce[U, V](
136+
map_function: Callable[[U], V],
137+
reduce_function: Callable[[V, V], V],
138+
iterable: Iterable[U],
139+
) -> V:
140+
return reduce(reduce_function, map(map_function, iterable))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from math import pi
2+
from typing import Hashable, Protocol
3+
4+
import torch
5+
from attrs import frozen
6+
7+
from transformation_buffer.transformation import Transformation
8+
9+
10+
class RigidModel[T: Hashable](Protocol):
11+
def transformations(self) -> dict[tuple[T, T], Transformation]: ...
12+
13+
14+
@frozen
15+
class Cube[T: Hashable]:
16+
width: float
17+
walls: tuple[T, T, T, T, T, T]
18+
19+
# Cube Walls:
20+
# 4
21+
# 5 1 6
22+
# 2
23+
# 3
24+
25+
def transformations(self) -> dict[tuple[T, T], Transformation]:
26+
first, second, third, forth, fifth, sixth = self.walls
27+
28+
half_pi = pi / 2.0
29+
half_width = self.width / 2.0
30+
31+
down = Transformation.active(
32+
intrinsic_euler_angles=torch.tensor((-half_pi, 0.0, 0.0)),
33+
translation=torch.tensor((0.0, half_width, -half_width)),
34+
)
35+
36+
left = Transformation.active(
37+
intrinsic_euler_angles=torch.tensor((0.0, half_pi, 0.0)),
38+
translation=torch.tensor((half_width, 0.0, -half_width)),
39+
)
40+
41+
right = Transformation.active(
42+
intrinsic_euler_angles=torch.tensor((0.0, -half_pi, 0.0)),
43+
translation=torch.tensor((-half_width, 0.0, -half_width)),
44+
)
45+
46+
return {
47+
(first, second): down.clone(),
48+
(second, third): down.clone(),
49+
(third, forth): down,
50+
(first, fifth): left,
51+
(first, sixth): right,
52+
}

0 commit comments

Comments
 (0)