22
33from abc import ABC , abstractmethod
44from dataclasses import dataclass , field
5- from typing import Optional , Tuple
5+ from typing import Tuple
66
77import torch
88
@@ -44,7 +44,7 @@ def __post_init__(self):
4444 assert self .B .shape == (self .g_dim , self .u_dim ), f"dynamics_step: B must be ({ self .g_dim } , { self .u_dim } ), got { self .B .shape } "
4545 object .__setattr__ (self , "B_T" , self .B .T )
4646
47- def project_to_motion_constraints (self , g : torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
47+ def project_to_motion_constraints (self , _g : torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
4848 """Project the state to the motion constraints."""
4949 return xi
5050
@@ -189,8 +189,9 @@ def right_minus(cls, g_start: torch.Tensor, g_end: torch.Tensor) -> torch.Tensor
189189 def right_invariant_error (cls , estimated_state : torch .Tensor , true_state : torch .Tensor ) -> torch .Tensor :
190190 """Computes the right invariant error between the estimated state and
191191 the true state."""
192- assert estimated_state .shape == true_state .shape , f"right_invariant_error: mismatched shapes { estimated_state .shape } vs { true_state .shape } "
193- return estimated_state @ cls .inverse (true_state )
192+ # assert estimated_state.shape == true_state.shape, f"right_invariant_error: mismatched shapes {estimated_state.shape} vs {true_state.shape}"
193+ # return estimated_state @ cls.inverse(true_state)
194+ return true_state @ cls .inverse (estimated_state )
194195
195196 @classmethod
196197 def left_invariant_error (cls , estimated_state : torch .Tensor , true_state : torch .Tensor ) -> torch .Tensor :
@@ -199,7 +200,7 @@ def left_invariant_error(cls, estimated_state: torch.Tensor, true_state: torch.T
199200 assert estimated_state .shape == true_state .shape , f"right_invariant_error: mismatched shapes { estimated_state .shape } vs { true_state .shape } "
200201 return cls .inverse (true_state ) @ estimated_state
201202
202- def f (self , g : torch .Tensor , xi : torch .Tensor , u : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
203+ def f (self , _g : torch .Tensor , xi : torch .Tensor , u : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
203204 """Update using Euler-Poincare equations."""
204205 D = self .g_dim
205206 assert xi .ndim == 2 and xi .shape [1 ] == D , f"xi must be (N, { D } ), got { xi .shape } "
@@ -232,36 +233,44 @@ def update_configuration(self, g: torch.Tensor, xi: torch.Tensor, dt: float) ->
232233
233234 def update_velocity (self , xi : torch .Tensor , dxi : torch .Tensor , dt : float ) -> torch .Tensor :
234235 """Updates the velocity (Lie algebra element xi) using the Lie algebra
235- element dxi.
236- """
236+ element dxi."""
237237 assert xi .ndim == 2 and xi .shape [- 1 ] == self .g_dim , f"update_velocity: xi must be (N, { self .g_dim } ), got { xi .shape } "
238238 assert dxi .ndim == 2 and dxi .shape [- 1 ] == self .g_dim , f"update_velocity: dxi must be (N, { self .g_dim } ), got { dxi .shape } "
239239 return xi + dxi * dt
240240
241+ @staticmethod
241242 @abstractmethod
242243 def map_q_to_configuration (q : torch .Tensor ) -> torch .Tensor :
243244 """Map the configuration vector to the Lie group element."""
244245 raise NotImplementedError
245246
247+ @staticmethod
246248 @abstractmethod
247249 def map_configuration_to_q (g : torch .Tensor ) -> torch .Tensor :
248250 """Map the Lie Group element to configuration space."""
249251 raise NotImplementedError
250252
253+ @staticmethod
251254 @abstractmethod
252255 def map_dq_to_velocity (q : torch .Tensor , dq : torch .Tensor ) -> torch .Tensor :
253- """ Map the velocity in configuration space to the Lie Algebra velocity."""
256+ """Map the velocity in configuration space to the Lie Algebra
257+ velocity."""
254258 raise NotImplementedError
255259
260+ @staticmethod
256261 @abstractmethod
257262 def map_velocity_to_dq (q : torch .Tensor , velocity : torch .Tensor ) -> torch .Tensor :
258- """ Map the velocity in Lie Algebra to the configuration space velocity."""
263+ """Map the velocity in Lie Algebra to the configuration space
264+ velocity."""
259265 raise NotImplementedError
260266
267+
268+ @dataclass (frozen = True )
261269class NonholonomicGroup (MatrixLieGroup ):
262270 """Base class for nonholonomic matrix Lie groups."""
263271
264- constraint_projection_matrix : Optional [torch .Tensor ] = field (init = False , repr = False ) # Enforces A @ xi = 0 constraint
272+ constraint_projection_matrix_velocity : torch .Tensor = field (init = False , repr = False ) # Enforces A @ xi = 0 constraint
273+ constraint_projection_matrix_wrench : torch .Tensor = field (init = False , repr = False ) # Enforces
265274
266275 def __init__ (self , * args , ** kwargs ):
267276 super ().__init__ (* args , ** kwargs )
@@ -276,25 +285,29 @@ def __init__(self, *args, **kwargs):
276285 P = self .inertia_matrix_inv @ A_matrix .T @ lambda_solver @ A_matrix # inertia_matrix^-1 @ A^T @ (A @ inertia_matrix^-1 @ A^T)^-1 @ A
277286 I_minus_P = torch .eye (self .g_dim , device = self .inertia_matrix .device ) - P
278287
279- object .__setattr__ (self , "constraint_projection_matrix" , I_minus_P .T )
288+ PI = (
289+ torch .eye (self .g_dim , device = self .inertia_matrix .device )
290+ - A_matrix .T @ torch .linalg .inv (A_matrix @ self .inertia_matrix_inv @ A_matrix .T ) @ A_matrix @ self .inertia_matrix_inv
291+ )
292+
293+ object .__setattr__ (self , "constraint_projection_matrix_velocity" , I_minus_P ) # Storing transpose for batch multiplication purposes
294+ object .__setattr__ (self , "constraint_projection_matrix_wrench" , PI )
280295
281- def project_to_motion_constraints (self , g : torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
296+ def project_to_motion_constraints (self , _g : torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
282297 """Project the state to the motion constraints."""
283- return xi @ self .constraint_projection_matrix
298+ return xi @ self .constraint_projection_matrix_velocity . T
284299
285300 @abstractmethod
286- def get_Pfaffian_A (self , g :torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
301+ def get_Pfaffian_A (self , g : torch .Tensor , xi : torch .Tensor ) -> torch .Tensor :
287302 """Computes the Pfaffian A of the nonholonomic group."""
288303 raise NotImplementedError
289- # TODO: can this be made a property on subclass if they are constant???
290304
291305 def f (self , g : torch .Tensor , xi : torch .Tensor , u : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
292306 """Update using Euler-Poincare equations."""
293307 # _, xi_dot = super().f(g, xi, u)
294308 D = self .g_dim
295309 assert xi .ndim == 2 and xi .shape [1 ] == D , f"xi must be (N, { D } ), got { xi .shape } "
296310 assert u .shape [1 ] == self .u_dim and u .ndim == 2 , f"u must be (N, { self .u_dim } ), got { u .shape } "
297- # TODO: clean this and holonomic to use same interface as state-space dynamics (i.e. get forces, ...)
298311 Bu = u @ self .B_T # Bu in batched form
299312
300313 coad = self .coadjoint_operator (xi ) # (N, D, D)
@@ -303,4 +316,4 @@ def f(self, g: torch.Tensor, xi: torch.Tensor, u: torch.Tensor) -> Tuple[torch.T
303316 xi_dot = (self .inertia_matrix_inv @ rhs .T ).T
304317 xi_dot_proj = self .project_to_motion_constraints (g , xi_dot )
305318
306- return xi , xi_dot_proj
319+ return xi , xi_dot_proj
0 commit comments