Skip to content

Commit

Permalink
Type Annotations for tracker.py and iterativeclosestpoint.py (#896)
Browse files Browse the repository at this point in the history
* added tracker.py type annotation

* Ruff Fixes

* icp.py type annotated

* restore type ignore comment
  • Loading branch information
HossamSaberX authored Feb 23, 2025
1 parent 2d583d8 commit e1004e8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 58 deletions.
21 changes: 13 additions & 8 deletions invesalius/navigation/iterativeclosestpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# detalhes.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING

import numpy as np
import wx

Expand All @@ -25,19 +27,22 @@
import invesalius.session as ses
from invesalius.utils import Singleton

if TYPE_CHECKING:
from invesalius.navigation.navigation import Navigation
from invesalius.navigation.tracker import Tracker

class IterativeClosestPoint(metaclass=Singleton):
def __init__(self):
def __init__(self) -> None:
self.use_icp = False
self.m_icp = None
self.icp_fre = None

try:
self.LoadState()
except:
except: # noqa: E722
ses.Session().DeleteStateFile()

def SaveState(self):
def SaveState(self) -> None:
m_icp = self.m_icp.tolist() if self.m_icp is not None else None
state = {
"use_icp": self.use_icp,
Expand All @@ -48,7 +53,7 @@ def SaveState(self):
session = ses.Session()
session.SetState("icp", state)

def LoadState(self):
def LoadState(self) -> None:
session = ses.Session()
state = session.GetState("icp")

Expand All @@ -59,7 +64,7 @@ def LoadState(self):
self.m_icp = np.array(state["m_icp"])
self.icp_fre = state["icp_fre"]

def RegisterICP(self, navigation, tracker):
def RegisterICP(self, navigation: "Navigation", tracker: "Tracker") -> None:
# If ICP is already in use, return.
if self.use_icp:
return
Expand Down Expand Up @@ -97,17 +102,17 @@ def RegisterICP(self, navigation, tracker):

self.SetICP(navigation, self.use_icp)

def SetICP(self, navigation, use_icp):
def SetICP(self, navigation: "Navigation", use_icp: bool) -> None:
self.use_icp = use_icp

self.SaveState()

def ResetICP(self):
def ResetICP(self) -> None:
self.use_icp = False
self.m_icp = None
self.icp_fre = None

self.SaveState()

def GetFreForUI(self):
def GetFreForUI(self) -> str:
return f"{self.icp_fre:.2f}" if self.icp_fre else ""
104 changes: 54 additions & 50 deletions invesalius/navigation/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
# --------------------------------------------------------------------------

import threading
from typing import Dict, List, Optional, Tuple, cast

import numpy as np
from numpy.typing import NDArray # Added for proper ndarray type annotations

import invesalius.constants as const
import invesalius.data.coordinates as dco
Expand All @@ -35,37 +37,37 @@
# Only one tracker will be initialized per time. Therefore, we use
# Singleton design pattern for implementing it
class Tracker(metaclass=Singleton):
def __init__(self):
self.tracker_connection = None
self.tracker_id = const.DEFAULT_TRACKER
def __init__(self) -> None:
self.tracker_connection: Optional[tc.TrackerConnection] = None
self.tracker_id: int = const.DEFAULT_TRACKER

self.tracker_fiducials = np.full([3, 3], np.nan)
self.tracker_fiducials_raw = np.zeros((6, 6))
self.m_tracker_fiducials_raw = np.zeros((6, 4, 4))
self.tracker_fiducials: NDArray[np.float64] = np.full([3, 3], np.nan)
self.tracker_fiducials_raw: NDArray[np.float64] = np.zeros((6, 6))
self.m_tracker_fiducials_raw: NDArray[np.float64] = np.zeros((6, 4, 4))

self.tracker_connected = False
self.tracker_connected: bool = False

self.thread_coord = None
self.thread_coord: Optional[threading.Thread] = None

self.event_coord = threading.Event()
self.event_coord: threading.Event = threading.Event()

self.TrackerCoordinates = dco.TrackerCoordinates()
self.TrackerCoordinates: dco.TrackerCoordinates = dco.TrackerCoordinates()

try:
self.LoadState()
except:
except: # noqa: E722
ses.Session().DeleteStateFile()

def SaveState(self):
tracker_id = self.tracker_id
def SaveState(self) -> None:
tracker_id: int = self.tracker_id
tracker_fiducials = self.tracker_fiducials.tolist()
tracker_fiducials_raw = self.tracker_fiducials_raw.tolist()
marker_tracker_fiducials_raw = self.m_tracker_fiducials_raw.tolist()
configuration = (
configuration: Optional[Dict[str, object]] = (
self.tracker_connection.GetConfiguration() if self.tracker_connection else None
)

state = {
state: Dict[str, object] = {
"tracker_id": tracker_id,
"tracker_fiducials": tracker_fiducials,
"tracker_fiducials_raw": tracker_fiducials_raw,
Expand All @@ -75,18 +77,18 @@ def SaveState(self):
session = ses.Session()
session.SetState("tracker", state)

def LoadState(self):
def LoadState(self) -> None:
session = ses.Session()
state = session.GetState("tracker")

state: Optional[Dict[str, object]] = session.GetState("tracker")
if state is None:
return

tracker_id = state["tracker_id"]
tracker_fiducials = np.array(state["tracker_fiducials"])
tracker_fiducials_raw = np.array(state["tracker_fiducials_raw"])
m_tracker_fiducials_raw = np.array(state["marker_tracker_fiducials_raw"])
configuration = state["configuration"]
from typing import cast
tracker_id: int = cast(int, state["tracker_id"])
tracker_fiducials: NDArray[np.float64] = np.array(state["tracker_fiducials"])
tracker_fiducials_raw: NDArray[np.float64] = np.array(state["tracker_fiducials_raw"])
m_tracker_fiducials_raw: NDArray[np.float64] = np.array(state["marker_tracker_fiducials_raw"])
configuration: Optional[Dict[str, object]] = cast(Optional[Dict[str, object]], state["configuration"]) # Modified: cast configuration

self.tracker_id = tracker_id
self.tracker_fiducials = tracker_fiducials
Expand All @@ -95,15 +97,15 @@ def LoadState(self):

self.SetTracker(tracker_id=self.tracker_id, configuration=configuration)

def SetTracker(self, tracker_id, n_coils=1, configuration=None):
def SetTracker(self, tracker_id: int, n_coils: int = 1, configuration: Optional[Dict[str, object]] = None) -> None:
if tracker_id:
self.tracker_connection = tc.CreateTrackerConnection(tracker_id, n_coils)

# Configure tracker.
if configuration is not None:
success = self.tracker_connection.SetConfiguration(configuration)
success: bool = self.tracker_connection.SetConfiguration(configuration)
else:
success = self.tracker_connection.Configure()
success: bool = self.tracker_connection.Configure()

if not success:
self.tracker_connection = None
Expand All @@ -115,7 +117,7 @@ def SetTracker(self, tracker_id, n_coils=1, configuration=None):
# it happens with a different workflow than the other trackers. (See
# PolhemusTrackerConnection class for a more detailed explanation.)
if isinstance(self.tracker_connection, tc.PolhemusTrackerConnection):
reconfigure = configuration is None
reconfigure: bool = configuration is None
self.tracker_connection.Connect(reconfigure)
else:
self.tracker_connection.Connect()
Expand All @@ -135,11 +137,12 @@ def SetTracker(self, tracker_id, n_coils=1, configuration=None):
self.TrackerCoordinates,
self.event_coord,
)
self.thread_coord.start()
if self.thread_coord is not None:
self.thread_coord.start()

self.SaveState()

def DisconnectTracker(self):
def DisconnectTracker(self) -> None:
if self.tracker_connected:
Publisher.sendMessage("Update status text in GUI", label=_("Disconnecting tracker ..."))
Publisher.sendMessage("Remove sensors ID")
Expand All @@ -152,6 +155,7 @@ def DisconnectTracker(self):
self.thread_coord.join()
self.event_coord.clear()

assert self.tracker_connection is not None
self.tracker_connection.Disconnect()
if not self.tracker_connection.IsConnected():
self.tracker_connected = False
Expand All @@ -165,18 +169,18 @@ def DisconnectTracker(self):
)
print("Tracker still connected!")

def IsTrackerInitialized(self):
return self.tracker_connection and self.tracker_id and self.tracker_connected
def IsTrackerInitialized(self) -> bool:
return bool(self.tracker_connection and self.tracker_id and self.tracker_connected)

def IsTrackerFiducialSet(self, fiducial_index):
def IsTrackerFiducialSet(self, fiducial_index: int) -> bool:
return not np.isnan(self.tracker_fiducials)[fiducial_index].any()

def AreTrackerFiducialsSet(self):
def AreTrackerFiducialsSet(self) -> bool:
return not np.isnan(self.tracker_fiducials).any()

def GetTrackerCoordinates(self, ref_mode_id, n_samples=1):
coord_raw_samples = {}
coord_samples = {}
def GetTrackerCoordinates(self, ref_mode_id: int, n_samples: int = 1) -> Tuple[Tuple[bool, ...], NDArray[np.float64], NDArray[np.float64]]:
coord_raw_samples: Dict[int, NDArray[np.float64]] = {}
coord_samples: Dict[int, NDArray[np.float64]] = {}

for i in range(n_samples):
coord_raw, marker_visibilities = self.TrackerCoordinates.GetCoordinates()
Expand All @@ -190,12 +194,12 @@ def GetTrackerCoordinates(self, ref_mode_id, n_samples=1):
coord_raw_samples[i] = coord_raw
coord_samples[i] = coord

coord_raw_avg = np.median(list(coord_raw_samples.values()), axis=0)
coord_avg = np.median(list(coord_samples.values()), axis=0)
coord_raw_avg: NDArray[np.float64] = np.median(list(coord_raw_samples.values()), axis=0)
coord_avg: NDArray[np.float64] = np.median(list(coord_samples.values()), axis=0)

return marker_visibilities, coord_avg, coord_raw_avg

def SetTrackerFiducial(self, ref_mode_id, fiducial_index):
def SetTrackerFiducial(self, ref_mode_id: int, fiducial_index: int) -> bool:
marker_visibilities, coord, coord_raw = self.GetTrackerCoordinates(
ref_mode_id=ref_mode_id,
n_samples=const.CALIBRATION_TRACKER_SAMPLES,
Expand Down Expand Up @@ -233,37 +237,37 @@ def SetTrackerFiducial(self, ref_mode_id, fiducial_index):

return True

def ResetTrackerFiducials(self):
def ResetTrackerFiducials(self) -> None:
for m in range(3):
self.tracker_fiducials[m, :] = [np.nan, np.nan, np.nan]
Publisher.sendMessage("Reset tracker fiducials")
self.SaveState()

def GetTrackerFiducials(self):
def GetTrackerFiducials(self) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
return self.tracker_fiducials, self.tracker_fiducials_raw

def GetTrackerFiducialForUI(self, index, coordinate_index):
value = self.tracker_fiducials[index, coordinate_index]
def GetTrackerFiducialForUI(self, index: int, coordinate_index: int) -> float:
value: float = float(self.tracker_fiducials[index, coordinate_index])
if np.isnan(value):
value = 0

return value

def GetMatrixTrackerFiducials(self):
m_probe_ref_left = (
def GetMatrixTrackerFiducials(self) -> List[List[float]]:
m_probe_ref_left: NDArray[np.float64] = (
np.linalg.inv(self.m_tracker_fiducials_raw[1]) @ self.m_tracker_fiducials_raw[0]
)
m_probe_ref_right = (
m_probe_ref_right: NDArray[np.float64] = (
np.linalg.inv(self.m_tracker_fiducials_raw[3]) @ self.m_tracker_fiducials_raw[2]
)
m_probe_ref_nasion = (
m_probe_ref_nasion: NDArray[np.float64] = (
np.linalg.inv(self.m_tracker_fiducials_raw[5]) @ self.m_tracker_fiducials_raw[4]
)

return [m_probe_ref_left.tolist(), m_probe_ref_right.tolist(), m_probe_ref_nasion.tolist()]

def GetTrackerId(self):
def GetTrackerId(self) -> int:
return self.tracker_id

def get_trackers(self):
return const.TRACKERS
def get_trackers(self) -> List[str]:
return cast(List[str], const.TRACKERS)

0 comments on commit e1004e8

Please sign in to comment.