Skip to content

Commit fc6be98

Browse files
Implemented transformation-buffer
commit-id:2fbffaa2
1 parent 4eb1c4b commit fc6be98

File tree

12 files changed

+672
-19
lines changed

12 files changed

+672
-19
lines changed

packages/transformation-buffer/pyproject.toml

+19-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,27 @@ name = "transformation-buffer"
33
version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
6-
authors = [
7-
{ name = "Jan Smółka", email = "[email protected]" }
8-
]
6+
authors = [{ name = "Jan Smółka", email = "[email protected]" }]
97
requires-python = ">=3.12.7"
10-
dependencies = []
8+
dependencies = [
9+
"jaxtyping>=0.2.37",
10+
"kornia>=0.8.0",
11+
"more-itertools>=10.6.0",
12+
"pyserde[numpy]>=0.23.0",
13+
"rustworkx>=0.16.0",
14+
]
15+
16+
[dependency-groups]
17+
dev = [
18+
"hypothesis>=6.125.3",
19+
"pytest>=8.3.4",
20+
"syrupy>=4.8.1",
21+
]
1122

1223
[build-system]
1324
requires = ["hatchling"]
1425
build-backend = "hatchling.build"
26+
27+
[tool.pytest.ini_options]
28+
python_files = "tests/**/*.py"
29+
python_functions = "test_*"
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
from .buffer import Buffer
2+
from .transformation import Transformation
13

4+
__all__ = ['Buffer', 'Transformation']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from collections.abc import Callable, Hashable, Iterable
2+
from copy import deepcopy
3+
from functools import reduce
4+
from typing import Self
5+
6+
from more_itertools import pairwise
7+
from rustworkx import PyDiGraph, digraph_dijkstra_shortest_paths, is_connected
8+
9+
from transformation_buffer.rigid_model import RigidModel
10+
from transformation_buffer.transformation import Transformation
11+
12+
# TODO: Implement `pyserde` serialization.
13+
# TODO: Implement path compression and verification.
14+
15+
16+
class Buffer[T: Hashable]:
17+
__frames_of_reference: set[T]
18+
__frame_names_to_node_indices: dict[T, int]
19+
__connections: PyDiGraph[T, Transformation]
20+
21+
def __init__(self, frames_of_reference: Iterable[T] | None = None) -> None:
22+
nodes = (
23+
set() if frames_of_reference is None else deepcopy(set(frames_of_reference))
24+
)
25+
n_nodes = len(nodes)
26+
27+
connections = PyDiGraph[T, Transformation](
28+
multigraph=True, # Think about it later
29+
node_count_hint=n_nodes,
30+
edge_count_hint=n_nodes,
31+
)
32+
33+
node_indices = connections.add_nodes_from(nodes)
34+
35+
self.__frames_of_reference = nodes
36+
self.__frame_names_to_node_indices = dict(zip(nodes, node_indices))
37+
self.__connections = connections
38+
39+
@property
40+
def connected(self) -> bool:
41+
return is_connected(self.__connections.to_undirected(multigraph=False))
42+
43+
@property
44+
def frames_of_reference(self) -> frozenset[T]:
45+
# Construction of frozenset from set of strings is *somehow* optimized.
46+
# Perhaps there's no need for caching this property.
47+
return frozenset(self.__frames_of_reference)
48+
49+
def add_object(self, model: RigidModel[T]) -> Self:
50+
for from_to, transformation in model.transformations().items():
51+
self[from_to] = transformation
52+
53+
return self
54+
55+
def add_frame_of_reference(self, frame: T) -> Self:
56+
self.__add_frame_of_reference(frame)
57+
return self
58+
59+
def __add_frame_of_reference(self, frame: T) -> int:
60+
if (index := self.__frame_index(frame)) is not None:
61+
return index
62+
63+
self.__frames_of_reference.add(frame)
64+
index = self.__connections.add_node(frame)
65+
self.__frame_names_to_node_indices[frame] = index
66+
67+
return index
68+
69+
def __frame_index(self, frame: T) -> int | None:
70+
return self.__frame_names_to_node_indices.get(frame, None)
71+
72+
def __setitem__(self, from_to: tuple[T, T], transformation: Transformation) -> None:
73+
from_id, to_id = from_to
74+
75+
from_index = self.__frame_index(from_id) or self.__add_frame_of_reference(from_id)
76+
to_index = self.__frame_index(to_id) or self.__add_frame_of_reference(to_id)
77+
78+
self.__connections.add_edge(from_index, to_index, transformation)
79+
self.__connections.add_edge(to_index, from_index, transformation.inverse())
80+
81+
def __getitem__(self, from_to: tuple[T, T]) -> Transformation | None:
82+
from_id, to_id = from_to
83+
84+
from_index = self.__frame_index(from_id)
85+
to_index = self.__frame_index(to_id)
86+
87+
if from_index is None or to_index is None:
88+
return None
89+
90+
if from_index == to_index:
91+
return Transformation.identity().clone()
92+
93+
connections = self.__connections
94+
95+
if self.__connections.has_edge(from_index, to_index):
96+
return self.__connections.get_edge_data(from_index, to_index)
97+
98+
path_mapping = digraph_dijkstra_shortest_paths(
99+
connections,
100+
from_index,
101+
to_index,
102+
)
103+
104+
if to_index not in path_mapping:
105+
return None
106+
107+
shortest_path = path_mapping[to_index]
108+
109+
return map_reduce( # type: ignore[no-any-return] # MyPy complains about the possible `Any`.
110+
lambda nodes: connections.get_edge_data(*nodes),
111+
lambda t1, t2: t1 @ t2,
112+
pairwise(shortest_path),
113+
)
114+
115+
def __contains__(self, from_to: tuple[T, T]) -> bool:
116+
from_id, to_id = from_to
117+
118+
from_index = self.__frame_index(from_id)
119+
to_index = self.__frame_index(to_id)
120+
121+
if from_index is None or to_index is None:
122+
return False
123+
124+
return self.__connections.has_edge(from_index, to_index)
125+
126+
127+
def map_reduce[U, V](
128+
map_function: Callable[[U], V],
129+
reduce_function: Callable[[V, V], V],
130+
iterable: Iterable[U],
131+
) -> V:
132+
return reduce(reduce_function, map(map_function, iterable))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from math import pi
2+
from typing import Hashable, Protocol
3+
4+
import torch
5+
from attrs import frozen
6+
7+
from transformation_buffer.transformation import Transformation
8+
9+
10+
class RigidModel[T: Hashable](Protocol):
11+
def transformations(self) -> dict[tuple[T, T], Transformation]: ...
12+
13+
14+
@frozen
15+
class Cube[T: Hashable]:
16+
width: float
17+
walls: tuple[T, T, T, T, T, T]
18+
19+
# Cube Walls:
20+
# 4
21+
# 5 1 6
22+
# 2
23+
# 3
24+
25+
def transformations(self) -> dict[tuple[T, T], Transformation]:
26+
first, second, third, forth, fifth, sixth = self.walls
27+
28+
half_pi = pi / 2.0
29+
half_width = self.width / 2.0
30+
31+
down = Transformation.active(
32+
intrinsic_euler_angles=torch.tensor((-half_pi, 0.0, 0.0)),
33+
translation=torch.tensor((0.0, half_width, -half_width)),
34+
)
35+
36+
left = Transformation.active(
37+
intrinsic_euler_angles=torch.tensor((0.0, half_pi, 0.0)),
38+
translation=torch.tensor((half_width, 0.0, -half_width)),
39+
)
40+
41+
right = Transformation.active(
42+
intrinsic_euler_angles=torch.tensor((0.0, -half_pi, 0.0)),
43+
translation=torch.tensor((-half_width, 0.0, -half_width)),
44+
)
45+
46+
return {
47+
(first, second): down.clone(),
48+
(second, third): down.clone(),
49+
(third, forth): down,
50+
(first, fifth): left,
51+
(first, sixth): right,
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from dataclasses import dataclass
2+
from typing import Literal, Self
3+
4+
import torch
5+
from jaxtyping import Float
6+
from serde import field, serde
7+
8+
9+
def _serialize_tensor(tensor: torch.Tensor) -> str:
10+
return f'torch.tensor({tensor.tolist()}, dtype={tensor.dtype})'
11+
12+
13+
def _deserialize_tensor(input: str) -> torch.Tensor:
14+
import torch as __torch
15+
16+
result = eval(input, {}, {'torch': __torch})
17+
del __torch
18+
19+
return result # type: ignore[no-any-return]
20+
21+
22+
type Axis = Literal['x', 'y', 'z']
23+
24+
25+
def rotation_around(
26+
axis: Axis,
27+
angle: float | torch.Tensor,
28+
) -> Float[torch.Tensor, '3 3']:
29+
a = torch.tensor(angle)
30+
sin = torch.sin(a)
31+
cos = torch.cos(a)
32+
33+
match axis:
34+
case 'x':
35+
return torch.tensor(
36+
[
37+
[1.0, 0.0, 0.0],
38+
[0.0, cos, -sin],
39+
[0.0, sin, cos],
40+
]
41+
)
42+
43+
case 'y':
44+
return torch.tensor(
45+
[
46+
[cos, 0.0, sin],
47+
[0.0, 1.0, 0.0],
48+
[-sin, 0.0, cos],
49+
]
50+
)
51+
52+
case 'z':
53+
return torch.tensor(
54+
[
55+
[cos, -sin, 0.0],
56+
[sin, cos, 0.0],
57+
[0.0, 0.0, 1.0],
58+
]
59+
)
60+
61+
62+
def translation_along(axis: Axis, distance: float) -> Float[torch.Tensor, '3']:
63+
t = torch.zeros(3)
64+
65+
match axis:
66+
case 'x':
67+
t[0] = distance
68+
case 'y':
69+
t[1] = distance
70+
case 'z':
71+
t[2] = distance
72+
73+
return t
74+
75+
76+
@serde
77+
@dataclass(eq=False)
78+
class Transformation:
79+
rotation_and_translation: Float[torch.Tensor, '4 4'] = field(
80+
serializer=_serialize_tensor,
81+
deserializer=_deserialize_tensor,
82+
)
83+
84+
def __init__(
85+
self,
86+
rotation_and_translation: Float[torch.Tensor, '4 4'],
87+
) -> None:
88+
self.rotation_and_translation = rotation_and_translation.clone()
89+
90+
def __eq__(self, other: object, /) -> bool:
91+
return isinstance(other, Transformation) and bool(
92+
self.rotation_and_translation.eq(other.rotation_and_translation).all()
93+
)
94+
95+
def __matmul__(self, other: 'Transformation') -> 'Transformation':
96+
return Transformation(
97+
other.rotation_and_translation @ self.rotation_and_translation
98+
)
99+
100+
def inverse(self) -> 'Transformation':
101+
return Transformation(self.rotation_and_translation.inverse())
102+
103+
@staticmethod
104+
def approx_eq(
105+
t1: 'Transformation',
106+
t2: 'Transformation',
107+
absolute_tolerance: float = 1e-8,
108+
) -> bool:
109+
return bool(
110+
torch.isclose(
111+
t1.rotation_and_translation,
112+
t2.rotation_and_translation,
113+
atol=absolute_tolerance,
114+
).all()
115+
)
116+
117+
# Cannot write `IDENTITY: ClassVar[Self] = Transformation.from_parts`.
118+
# Class properties are also deprecated.
119+
# Python... :v
120+
@classmethod
121+
def identity(cls) -> Self:
122+
return cls.from_parts(
123+
torch.eye(3, 3, dtype=torch.float32),
124+
torch.zeros(3, dtype=torch.float32),
125+
)
126+
127+
@classmethod
128+
def from_parts(
129+
cls,
130+
rotation: Float[torch.Tensor, '3 3'],
131+
translation: Float[torch.Tensor, '3'],
132+
) -> Self:
133+
rotation_and_translation = torch.zeros(
134+
(4, 4),
135+
dtype=rotation.dtype,
136+
device=rotation.device,
137+
)
138+
rotation_and_translation[:3, :3] = rotation
139+
rotation_and_translation[:3, 3] = translation
140+
rotation_and_translation[3, 3] = 1.0
141+
142+
return cls(rotation_and_translation)
143+
144+
@classmethod
145+
def active(
146+
cls,
147+
intrinsic_euler_angles: Float[torch.Tensor, '3'] | None = None,
148+
translation: Float[torch.Tensor, '3'] | None = None,
149+
) -> Self:
150+
rotation = (
151+
(
152+
rotation_around('x', intrinsic_euler_angles[0])
153+
@ rotation_around('y', intrinsic_euler_angles[1])
154+
@ rotation_around('z', intrinsic_euler_angles[2])
155+
)
156+
if intrinsic_euler_angles is not None
157+
else torch.eye(3)
158+
)
159+
160+
translation = translation if translation is not None else torch.zeros(3)
161+
162+
return cls.from_parts(rotation, translation)
163+
164+
def clone(self) -> 'Transformation':
165+
return Transformation(self.rotation_and_translation.clone())
166+
167+
def cpu(self) -> 'Transformation':
168+
return Transformation(self.rotation_and_translation.cpu())
169+
170+
def to(
171+
self,
172+
device: torch.device | None = None,
173+
dtype: torch.dtype | None = None,
174+
) -> 'Transformation':
175+
return Transformation(self.rotation_and_translation.to(device, dtype, copy=True))

packages/transformation-buffer/tests/__init__.py

Whitespace-only changes.

packages/transformation-buffer/tests/buffer/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)