diff --git a/Controllers/__init__.py b/Controllers/__init__.py index 3a8e427..bb968f9 100644 --- a/Controllers/__init__.py +++ b/Controllers/__init__.py @@ -175,4 +175,4 @@ def update_logs(self, logging_values: "dict[str, TensorType]") -> None: if var is not None: self.logs[name].append( var.numpy().copy() if hasattr(var, "numpy") else var.copy() - ) \ No newline at end of file + ) diff --git a/Controllers/controller_fpga.py b/Controllers/controller_fpga.py index da85d09..157601e 100644 --- a/Controllers/controller_fpga.py +++ b/Controllers/controller_fpga.py @@ -1,6 +1,4 @@ import os -import sys -import glob import serial import struct import time @@ -11,12 +9,12 @@ from Control_Toolkit.Controllers import template_controller +from Control_Toolkit.serial_interface_helper import get_serial_port, set_ftdi_latency_timer + try: from SI_Toolkit_ASF.ToolkitCustomization.predictors_customization import STATE_INDICES except ModuleNotFoundError: - print("SI_Toolkit_ASF not yet created") - -from SI_Toolkit.Functions.General.Initialization import load_net_info_from_txt_file + raise Exception("SI_Toolkit_ASF not yet created") class controller_fpga(template_controller): @@ -27,58 +25,55 @@ def configure(self): SERIAL_PORT = get_serial_port(serial_port_number=self.config_controller["SERIAL_PORT"]) SERIAL_BAUD = self.config_controller["SERIAL_BAUD"] - set_ftdi_latency_timer(serial_port_number=self.config_controller["SERIAL_PORT"]) + set_ftdi_latency_timer(SERIAL_PORT=self.config_controller["SERIAL_PORT"]) self.InterfaceInstance = Interface() self.InterfaceInstance.open(SERIAL_PORT, SERIAL_BAUD) - NET_NAME = self.config_controller["net_name"] - PATH_TO_MODELS = self.config_controller["PATH_TO_MODELS"] - path_to_model_info = os.path.join(PATH_TO_MODELS, NET_NAME, NET_NAME + ".txt") - - self.input_at_input = self.config_controller["input_at_input"] + # --- PC↔SoC handshake: SoC declares input names and output count --- + self.spec_version, self.input_names, self.n_outputs = self.InterfaceInstance.get_spec() - self.net_info = load_net_info_from_txt_file(path_to_model_info) - - self.state_2_input_idx = [] - self.remaining_inputs = self.net_info.inputs.copy() - for key in self.net_info.inputs: - if key in STATE_INDICES.keys(): - self.state_2_input_idx.append(STATE_INDICES.get(key)) - self.remaining_inputs.remove(key) - else: - break # state inputs must be adjacent in the current implementation + self._state_idx = dict(STATE_INDICES) self.just_restarted = True - print('Configured fpga controller with {} network with {} library\n'.format(self.net_info.net_full_name, self.lib.lib)) + print('Configured SoC controller (spec v{}) with {} library\n'.format(self.spec_version, self.lib.lib)) - def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorType]" = {}): + def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorType]" = None): self.just_restarted = False - if self.input_at_input: - net_input = s - else: - self.update_attributes(updated_attributes) - net_input = s[..., self.state_2_input_idx] - for key in self.remaining_inputs: - net_input = np.append(net_input, getattr(self.variable_parameters, key)) - - net_input = self.lib.to_tensor(net_input, self.lib.float32) - - if self.lib.lib == 'Pytorch': - net_input = net_input.to(self.device) - - net_input = self.lib.reshape(net_input, (-1, 1, len(self.net_info.inputs))) - net_input = self.lib.to_numpy(net_input) - - net_output = self.get_net_output_from_fpga(net_input) + if updated_attributes is None: + updated_attributes = {} + self.update_attributes(updated_attributes) + + # Build inputs *exactly* in the wire order requested by the SoC. + # Precedence: updated_attributes > state vector > variable_parameters > 0.0 + arr = np.empty(len(self.input_names), dtype=np.float32) + for i, name in enumerate(self.input_names): + if name == "time": + if time is None: + raise Exception("Controller input 'time' is required but not provided.") + else: + val = float(time) # use simulator's timestamp (seconds, monotonic in sim time) + arr[i] = val + continue + + if name in updated_attributes: # external override wins + val = float(updated_attributes[name]) + elif name in self._state_idx: # pick from s by name→index map + val = float(s[..., self._state_idx[name]]) + elif hasattr(self, 'variable_parameters') and hasattr(self.variable_parameters, name): + val = float(getattr(self.variable_parameters, name)) + else: + val = 0.0 # explicit default to prevent UB + arr[i] = val - net_output = self.lib.to_tensor(net_output, self.lib.float32) - net_output = net_output[self.lib.newaxis, self.lib.newaxis, :] + controller_output = self.get_controller_output_from_fpga(arr) # raw float32 bytes over UART + controller_output = self.lib.to_tensor(controller_output, self.lib.float32) + controller_output = controller_output[self.lib.newaxis, self.lib.newaxis, :] if self.lib.lib == 'Pytorch': - net_output = net_output.detach().numpy() + controller_output = controller_output.detach().numpy() - Q = net_output + Q = controller_output return Q @@ -87,40 +82,30 @@ def controller_reset(self): if not self.just_restarted: self.configure() - def get_net_output_from_fpga(self, net_input): - self.InterfaceInstance.send_net_input(net_input) - net_output = self.InterfaceInstance.receive_net_output(len(self.net_info.outputs)) - return net_output + def get_controller_output_from_fpga(self, controller_input): + self.InterfaceInstance.send_controller_input(controller_input) + controller_output = self.InterfaceInstance.receive_controller_output(self.n_outputs) + + # if a cookie-triggered GET_SPEC happened, adopt it for NEXT step + if self.InterfaceInstance.pending_spec is not None: + self.spec_version, self.input_names, self.n_outputs = self.InterfaceInstance.pending_spec + self.InterfaceInstance.pending_spec = None + print(f"Refreshed SoC spec (v{self.spec_version}): " + f"{len(self.input_names)} inputs, {self.n_outputs} outputs") + return controller_output -def get_serial_port(serial_port_number=''): - import platform - import subprocess - serial_port_number = str(serial_port_number) - SERIAL_PORT = None - try: - system = platform.system() - if system == 'Darwin': # Mac - SERIAL_PORT = subprocess.check_output(f'ls -a /dev/tty.usbserial*{serial_port_number}', shell=True).decode("utf-8").strip() # Probably '/dev/tty.usbserial-110' - elif system == 'Linux': - SERIAL_PORT = '/dev/ttyUSB' + serial_port_number # You might need to change the USB number - elif system == 'Windows': - SERIAL_PORT = 'COM' + serial_port_number - else: - raise NotImplementedError('For system={} connection to serial port is not implemented.') - except Exception as err: - print(err) - return SERIAL_PORT PING_TIMEOUT = 1.0 # Seconds -CALIBRATE_TIMEOUT = 10.0 # Seconds READ_STATE_TIMEOUT = 1.0 # Seconds SERIAL_SOF = 0xAA CMD_PING = 0xC0 +CMD_GET_SPEC = 0xC6 +NAME_TOKEN_LEN = 16 # fixed ASCII token length per name class Interface: def __init__(self): @@ -131,11 +116,14 @@ def __init__(self): self.encoderDirection = None + self.pending_spec = None + def open(self, port, baud): self.port = port self.baud = baud self.device = serial.Serial(port, baudrate=baud, timeout=None) self.device.reset_input_buffer() + self.device.reset_output_buffer() def close(self): if self.device: @@ -150,27 +138,97 @@ def ping(self): msg = [SERIAL_SOF, CMD_PING, 4] msg.append(self._crc(msg)) self.device.write(bytearray(msg)) - return self._receive_reply(CMD_PING, 4, PING_TIMEOUT) == msg - - def send_net_input(self, net_input): + return self._receive_reply(4, PING_TIMEOUT) == msg + + def get_spec(self): + """ + Request SoC declaration of its input wire-order and output count. + + SoC reply (raw, no frame): 4-byte header + names block + byte 0: version (u8) + byte 1: n_inputs (u8) + byte 2: n_outputs (u8) + byte 3: token_len (u8) == NAME_TOKEN_LEN + bytes 4.. : n_inputs * token_len ASCII names (NUL-padded), wire order + """ + self.clear_read_buffer() + # Send framed request (SOF, CMD, LEN, CRC) to stay consistent with existing protocol. + msg = bytearray([SERIAL_SOF, CMD_GET_SPEC, 4]) + msg.append(self._crc(msg)) + self.device.write(msg) + + # Handshake is a control exchange: use a bounded timeout so we fail fast instead of hanging. + old_timeout = self.device.timeout + try: + self.device.timeout = 1.0 + hdr = self.device.read(4) + if len(hdr) != 4: + raise IOError("GET_SPEC: short header") + version, n_inputs, n_outputs, token_len = hdr[0], hdr[1], hdr[2], hdr[3] + if token_len != NAME_TOKEN_LEN: + raise IOError(f"GET_SPEC: unexpected token_len={token_len} (expected {NAME_TOKEN_LEN})") + + need = n_inputs * token_len + raw = self.device.read(need) + if len(raw) != need: + raise IOError("GET_SPEC: short names block") + + names = [] + for i in range(n_inputs): + chunk = raw[i*token_len:(i+1)*token_len] + # Cut at first NUL; ignore non-ASCII silently. + names.append(chunk.split(b'\x00', 1)[0].decode('ascii', errors='ignore')) + return version, names, n_outputs + finally: + self.device.timeout = old_timeout # restore streaming behavior + + def send_controller_input(self, controller_input): self.device.reset_output_buffer() - bytes_written = self.device.write(bytearray(net_input)) - # print(bytes_written) - - def receive_net_output(self, net_output_length): - net_output_length_bytes = net_output_length * 4 # We assume float32 - net_output = self.device.read(size=net_output_length_bytes) - net_output = struct.unpack(f'<{net_output_length}f', net_output) - # net_output=reply - return net_output + if not isinstance(controller_input, np.ndarray) or controller_input.dtype != np.float32: + controller_input = np.asarray(controller_input, dtype=np.float32) + self.device.write(controller_input.tobytes()) + + def receive_controller_output(self, controller_output_length): + """ + Reads controller outputs. If a spec-change cookie arrives, we immediately + re-handshake (GET_SPEC) for the next cycle, then still read and return + THIS cycle's outputs (old spec) so the control loop doesn't stall. + """ + # Peek first 4 bytes + head = self.device.read(size=4) + if len(head) != 4: + raise IOError(f"receive_controller_output: expected 4 bytes head, got {len(head)}") + + # Check for spec-change cookie: [SOF, CMD_SPEC_COOKIE, gen, CRC] + if head[0] == SERIAL_SOF and head[1] == 0xC7 and head[3] == self._crc(head[:3]): + # Re-handshake now so *next* step uses the new spec + version, names, n_outputs = self.get_spec() + # Stash for the controller to pick up after this receive + self.pending_spec = (version, names, n_outputs) + # Now read THIS cycle's outputs (old spec) and return them + nbytes = controller_output_length * 4 + data = self.device.read(size=nbytes) + if len(data) != nbytes: + raise IOError(f"receive_controller_output: expected {nbytes} bytes after cookie, got {len(data)}") + return struct.unpack(f'<{controller_output_length}f', data) + + # No cookie: head belongs to outputs; read the rest + rest_bytes = controller_output_length * 4 - 4 + if rest_bytes < 0: + raise ValueError("controller_output_length must be >= 1") + rest = self.device.read(size=rest_bytes) if rest_bytes else b"" + if len(rest) != rest_bytes: + raise IOError(f"receive_controller_output: expected {rest_bytes} tail bytes, got {len(rest)}") + data = head + rest + return struct.unpack(f'<{controller_output_length}f', data) def _receive_reply(self, cmdLen, timeout=None, crc=True): self.device.timeout = timeout self.start = False + self.msg = [] while True: - c = self.device.read() - # Timeout: reopen device, start stream, reset msg and try again + c = self.device.read(1) if len(c) == 0: print('\nReconnecting.') self.device.close() @@ -180,8 +238,9 @@ def _receive_reply(self, cmdLen, timeout=None, crc=True): self.msg = [] self.start = False else: - self.msg.append(ord(c)) - if self.start == False: + # Py3: bytes→int via c[0]; ord() on bytes is a TypeError. + self.msg.append(c[0]) + if self.start is False: self.start = time.time() while len(self.msg) >= cmdLen: @@ -222,26 +281,3 @@ def _crc(self, msg): val >>= 1 return crc8 - - -import subprocess -def set_ftdi_latency_timer(serial_port_number): - print('\nSetting FTDI latency timer') - ftdi_timer_latency_requested_value = 1 - command_ftdi_timer_latency_set = f"sh -c 'echo {ftdi_timer_latency_requested_value} > /sys/bus/usb-serial/devices/ttyUSB{serial_port_number}/latency_timer'" - command_ftdi_timer_latency_check = f'cat /sys/bus/usb-serial/devices/ttyUSB{serial_port_number}/latency_timer' - try: - subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - print(e.stderr) - if "Permission denied" in e.stderr: - print("Trying with sudo...") - command_ftdi_timer_latency_set = "sudo " + command_ftdi_timer_latency_set - try: - subprocess.run("echo Teresa | sudo -S :", shell=True) - subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - print(e.stderr) - - ftdi_latency_timer_value = subprocess.run(command_ftdi_timer_latency_check, shell=True, capture_output=True, text=True).stdout.rstrip() - print(f'FTDI latency timer value (tested only for FTDI with Zybo and with Linux on PC side): {ftdi_latency_timer_value} ms \n') diff --git a/Controllers/controller_neural_imitator.py b/Controllers/controller_neural_imitator.py index c019487..4efdcdd 100644 --- a/Controllers/controller_neural_imitator.py +++ b/Controllers/controller_neural_imitator.py @@ -22,7 +22,9 @@ def configure(self): path_to_models=self.config_controller["PATH_TO_MODELS"], batch_size=1, # It makes sense only for testing (Brunton plot for Q) of not rnn networks to make bigger batch, this is not implemented input_precision=self.config_controller["input_precision"], - hls4ml=self.config_controller["hls4ml"]) + nn_evaluator_mode=self.config_controller["nn_evaluator_mode"]) + + self.clip_output = self.config_controller.get("clip_output", False) self._computation_library = self.net_evaluator.lib @@ -31,7 +33,7 @@ def configure(self): # Prepare input mapping self.input_mapping = self._create_input_mapping() - if self.controller_logging and self.lib.lib == "TF" and not self.net_evaluator.hls4ml: + if self.controller_logging and self.lib.lib == "TF" and self.net_evaluator.nn_evaluator_mode == 'normal': self.controller_data_for_csv = FunctionalDict(get_memory_states(self.net_evaluator.net)) print('Configured neural imitator with {} network with {} library'.format(self.net_evaluator.net_info.net_full_name, self.net_evaluator.net_info.library)) @@ -61,7 +63,8 @@ def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorTy Q = self.net_evaluator.step(net_input) - Q = np.clip(Q, -1.0, 1.0) # Ensure Q is within the range [-1, 1] + if self.clip_output: + Q = np.clip(Q, -1.0, 1.0) # Ensure Q is within the range [-1, 1] return Q diff --git a/Controllers/controller_remote.py b/Controllers/controller_remote.py new file mode 100644 index 0000000..947cffc --- /dev/null +++ b/Controllers/controller_remote.py @@ -0,0 +1,124 @@ +from __future__ import annotations +import numpy as np +import zmq +import zmq.error + +from SI_Toolkit.computation_library import NumpyLibrary +from Control_Toolkit.Controllers import template_controller +from Control_Toolkit.others.globals_and_utils import import_controller_by_name + +ENFORCE_TIMEOUT = True # Set to False to disable the timeout feature +DEFAULT_RCVTIMEO = 50 # [ms] + + +class controller_remote(template_controller): + _computation_library = NumpyLibrary() + """ + ZeroMQ DEALER proxy. + • Sends each state to the server together with a monotonically + increasing *request-id* (`rid`). + • Drops or purges every reply whose rid ≠ last request’s rid. + • After a timeout the motor command falls back to 0 or to a local controller. + """ + + def configure(self): + # ─── remote socket setup ──────────────────────────────────────── + self.endpoint = self.config_controller.get( + "remote_endpoint", "tcp://localhost:5555" + ) + self._ctx = zmq.Context() + self._sock = self._ctx.socket(zmq.DEALER) + self._sock.connect(self.endpoint) + if ENFORCE_TIMEOUT: + self._sock.setsockopt(zmq.RCVTIMEO, DEFAULT_RCVTIMEO) + + self._next_rid: int = 0 + print(f"Neural-imitator proxy connected to {self.endpoint}") + + # ─── fallback to a local controller or 0 control ────────────────────── + # retrieve fallback-controller parameters from config + self.fallback_controller_name = self.config_controller["fallback_controller_name"] + + if self.fallback_controller_name is not None: + # dynamically import and instantiate the local controller + # e.g. import_controller_by_name("controller-neural-imitator") + Controller = import_controller_by_name( + f"controller-{self.fallback_controller_name}".replace("-", "_") + ) + self._fallback_controller = Controller( + self.environment_name, self.control_limits, self.initial_environment_attributes + ) + self._fallback_controller.configure() + + # ------------------------------------------------------------------ STEP + def step( + self, + s: np.ndarray, + time=None, + updated_attributes: "dict[str, np.ndarray]" = {}, + ): + """ + Serialises the data, ships it to the server, waits up to 50 ms for Q, + and returns it—or falls back on timeout to the fallback controller or zero control. + """ + if updated_attributes is None: + updated_attributes = {} + + rid = self._next_rid # snapshot current rid + self._next_rid += 1 # prepare for next call + + self._sock.send_json( + { + "rid": rid, + "state": s.tolist(), # JSON-friendly + "time": time, + "updated_attributes": updated_attributes, + } + ) + + # ❷ -- receive with timeout + try: + resp = self._sock.recv_json() # may raise zmq.Again + except zmq.error.Again: + self._purge_stale() # clear the queue + if self.fallback_controller_name is not None: + # use local controller on timeout + return self._fallback_controller.step( + s, time=time, updated_attributes=updated_attributes + ) + return np.array(0.0, dtype=np.float32) + + # —— discard stale packets —————————— + while resp.get("rid") != rid: + try: + resp = self._sock.recv_json() + except zmq.error.Again: + # genuine timeout – treat as lost reply + if self.fallback_controller_name is not None: + return self._fallback_controller.step( + s, time=time, updated_attributes=updated_attributes + ) + return np.array(0.0, dtype=np.float32) + + if "error" in resp: + # Re-raise server-side exceptions locally for easier debugging + raise RuntimeError(f"Remote controller error: {resp['error']}") + + # ❸ -- final result + return np.asarray(resp["Q"], dtype=np.float32) + + # ---------------------------------------------------------- helpers + def _purge_stale(self) -> None: + """Discard every pending message in the inbound queue.""" + while True: + try: + self._sock.recv(flags=zmq.DONTWAIT) + except zmq.error.Again: + break + + # ---------------------------------------------------------------- RESET + def controller_reset(self): + """ + Nothing to reset locally; the server keeps the network state. + """ + pass diff --git a/controller_server/__init__.py b/controller_server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/controller_server/controller_server.py b/controller_server/controller_server.py new file mode 100644 index 0000000..08b8d62 --- /dev/null +++ b/controller_server/controller_server.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +remote_nn_controller_server.py + +ZeroMQ ROUTER server that uses gui_selection to pick controller and optimizer, +then serves step requests. +""" + +import sys +import numpy as np +import zmq +import json + +from Control_Toolkit.controller_server.gui import choose_controller_and_optimizer +from Control_Toolkit.others.globals_and_utils import import_controller_by_name + +# Hardcoded ZeroMQ endpoint +ENDPOINT = "tcp://*:5555" + + +initial_environment_attributes = { + "target_position": 0.0, + "target_equilibrium": 0.0, + "m_pole": 0.0, + "L": 0.0, + "Q_ccrc": 0.0, + "Q_applied_-1": 0.0, +} + +def main(): + # Launch the GUI to get controller/optimizer + ctrl_name, opt_name = choose_controller_and_optimizer() + print(f"[server] ▶️ Controller: {ctrl_name} Optimizer: {opt_name}") + + # Dynamically import & instantiate + ControllerClass = import_controller_by_name(ctrl_name) + ctrl = ControllerClass( + environment_name="CartPole", + control_limits=(-1.0, 1.0), + initial_environment_attributes=initial_environment_attributes, # populate as needed + ) + + # Configure with or without optimizer + if ctrl.has_optimizer: + ctrl.configure(optimizer_name=opt_name) + else: + ctrl.configure() + + # ─── ZeroMQ ROUTER socket ──────────────────────────────────────── + ctx = zmq.Context() + sock = ctx.socket(zmq.ROUTER) + sock.bind(ENDPOINT) + print(f"[server] 🚀 listening on {ENDPOINT}") + + while True: + # Receive either [identity, payload] or [identity, b"", payload] + parts = sock.recv_multipart() + if len(parts) == 2: + client_identity, payload = parts + elif len(parts) == 3 and parts[1] == b"": + client_identity, _empty, payload = parts + else: + # Unexpected framing; skip it + continue + + try: + req = json.loads(payload.decode("utf-8")) + rid = req["rid"] + s = np.asarray(req["state"], dtype=np.float32) + t = req.get("time") + upd = req.get("updated_attributes", {}) + + Q = ctrl.step(s, t, upd) + if isinstance(Q, np.ndarray): + Q_payload = Q.tolist() + else: + # covers Python floats *and* tf.Tensor scalars via .numpy() + Q_payload = float(Q) if not isinstance(Q, (list, tuple)) else Q + + reply = json.dumps({"rid": rid, "Q": Q_payload}).encode("utf-8") + + sock.send_multipart([client_identity, reply]) + + except Exception as e: + print(f"[server] ⚠️ controller exception – no reply sent: {e}", file=sys.stderr) + continue # do NOT send anything back + + +if __name__ == "__main__": + main() diff --git a/controller_server/gui.py b/controller_server/gui.py new file mode 100644 index 0000000..d9c9cb3 --- /dev/null +++ b/controller_server/gui.py @@ -0,0 +1,89 @@ +from PyQt6.QtWidgets import ( + QApplication, + QDialog, + QVBoxLayout, + QGroupBox, + QRadioButton, + QDialogButtonBox, +) +from PyQt6.QtCore import Qt + +from Control_Toolkit.others.globals_and_utils import ( + get_available_controller_names, + get_available_optimizer_names, + get_controller_name, + get_optimizer_name, +) + + +class SelectionDialog(QDialog): + def __init__(self): + super().__init__() + self.setWindowTitle("Select Controller & Optimizer") + self.resize(400, 300) + + layout = QVBoxLayout(self) + + # Controllers group + ctrl_names = get_available_controller_names() + box_ctrl = QGroupBox("Controllers") + vbox_ctrl = QVBoxLayout() + self.rbs_controllers = [] + for name in ctrl_names: + rb = QRadioButton(name) + vbox_ctrl.addWidget(rb) + self.rbs_controllers.append(rb) + if self.rbs_controllers: + self.rbs_controllers[0].setChecked(True) + box_ctrl.setLayout(vbox_ctrl) + layout.addWidget(box_ctrl) + + # Optimizers group + opt_names = get_available_optimizer_names() + box_opt = QGroupBox("Optimizers") + vbox_opt = QVBoxLayout() + self.rbs_optimizers = [] + for name in opt_names: + rb = QRadioButton(name) + vbox_opt.addWidget(rb) + self.rbs_optimizers.append(rb) + if self.rbs_optimizers: + self.rbs_optimizers[0].setChecked(True) + box_opt.setLayout(vbox_opt) + layout.addWidget(box_opt) + + # OK / Cancel buttons + btns = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, + Qt.Orientation.Horizontal, + self, + ) + btns.accepted.connect(self.accept) + btns.rejected.connect(self.reject) + layout.addWidget(btns) + + def get_selection(self): + """ + Returns: + Tuple[str, str]: (controller_name, optimizer_name) + """ + ctrl = None + for idx, rb in enumerate(self.rbs_controllers): + if rb.isChecked(): + ctrl, _ = get_controller_name(controller_idx=idx) + break + opt = None + for idx, rb in enumerate(self.rbs_optimizers): + if rb.isChecked(): + opt, _ = get_optimizer_name(optimizer_idx=idx) + break + return ctrl, opt + + +def choose_controller_and_optimizer(): + import sys + app = QApplication(sys.argv) + dlg = SelectionDialog() + if dlg.exec() != QDialog.DialogCode.Accepted: + sys.exit(0) + return dlg.get_selection() diff --git a/serial_interface_helper.py b/serial_interface_helper.py new file mode 100644 index 0000000..95457c9 --- /dev/null +++ b/serial_interface_helper.py @@ -0,0 +1,95 @@ +import getpass +import platform +import subprocess + +import serial + +SUDO_PASSWORD = None # Required to set FTDI latency timer on Linux systems, can be set to a hardcoded password for convenience or left as None to prompt the user via terminal. + +def get_serial_port(chip_type="STM", serial_port_number=None): + """ + Finds the cartpole serial port, or throws exception if not present + :param chip_type: "ZYNQ" or "STM" depending on which one you use + :param serial_port_number: Only used if serial port not found using chip type, can be left None, for normal operation + :returns: the string name of the COM port + """ + + from serial.tools import list_ports + ports = list(serial.tools.list_ports.comports()) + serial_ports_names = [] + print('\nAvailable serial ports:') + for index, port in enumerate(ports): + serial_ports_names.append(port.device) + print(f'{index}: port={port.device}; description={port.description}') + print() + + if chip_type == "STM": + expected_descriptions = ['USB Serial'] + elif chip_type == "ZYNQ": + expected_descriptions = ['Digilent Adept USB Device - Digilent Adept USB Device', 'Digilent Adept USB Device'] + else: + raise ValueError(f'Unknown chip type: {chip_type}') + + possible_ports = [] + for port in ports: + if port.description in expected_descriptions: + possible_ports.append(port.device) + + SERIAL_PORT = None + if not possible_ports: + message = f"Searching serial port by its expected descriptions - {expected_descriptions} - not successful." + if serial_port_number is not None: + print(message) + else: + raise Exception(message) + else: + if serial_port_number < len(possible_ports): + SERIAL_PORT = possible_ports[serial_port_number] + else: + print( + f"Requested serial port number {serial_port_number} is out of range. Available ports: {len(possible_ports)}") + print(f"Using the first available port: {possible_ports[0]}") + SERIAL_PORT = possible_ports[0] + + if SERIAL_PORT is None and serial_port_number is not None: + if len(serial_ports_names) == 0: + print(f'No serial ports') + else: + print(f"Setting serial port with requested number ({serial_port_number})\n") + SERIAL_PORT = serial_ports_names[serial_port_number] + + return SERIAL_PORT + + +def set_ftdi_latency_timer(SERIAL_PORT): + print('\nSetting FTDI latency timer') + requested_value = 1 # in ms + + if platform.system() == 'Linux': + # check for hardcoded sudo password or prompt the user + if SUDO_PASSWORD: + password = SUDO_PASSWORD + else: + password = getpass.getpass('Enter sudo password: ') + + serial_port = SERIAL_PORT.split('/')[-1] + ftdi_timer_latency_requested_value = 1 + command_ftdi_timer_latency_set = f"sh -c 'echo {ftdi_timer_latency_requested_value} > /sys/bus/usb-serial/devices/{serial_port}/latency_timer'" + command_ftdi_timer_latency_check = f'cat /sys/bus/usb-serial/devices/{serial_port}/latency_timer' + try: + subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + print(e.stderr) + if "Permission denied" in e.stderr: + print("Trying with sudo...") + command_ftdi_timer_latency_set = f"echo {password} | sudo -S {command_ftdi_timer_latency_set}" + try: + subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, + text=True) + except subprocess.CalledProcessError as e: + print(e.stderr) + + ftdi_latency_timer_value = subprocess.run(command_ftdi_timer_latency_check, shell=True, capture_output=True, + text=True).stdout.rstrip() + print( + f'FTDI latency timer value (tested only for FTDI with Zybo and with Linux on PC side): {ftdi_latency_timer_value} ms \n')