Skip to content

Commit

Permalink
Add support for TCP backseat driver
Browse files Browse the repository at this point in the history
  • Loading branch information
oysstu committed Sep 3, 2024
1 parent 4d97eee commit 6622a02
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 10 deletions.
42 changes: 32 additions & 10 deletions imcpy/actors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import imcpy
from imcpy.decorators import *
from imcpy.exception import AmbiguousKeyError
from imcpy.network.tcp import IMCProtocolTCPServer
from imcpy.network.udp import IMCProtocolUDP, IMCSenderUDP, get_imc_socket, get_multicast_socket
from imcpy.node import IMCNode, IMCService

Expand All @@ -26,18 +27,28 @@ class IMCBase:
Implements an event loop, subscriptions, IMC node bookkeeping
"""

def __init__(self, imc_id=0x3334, static_port=None, verbose_nodes=False, log_enable=False, log_root=None):
def __init__(
self,
imc_id: int = 0x3334,
static_port: Optional[int] = None,
backseat_port: Optional[int] = None,
verbose_nodes=False,
log_enable=False,
log_root: Optional[str] = None,
):
"""
Initialize the IMC comms. Does not start the event loop until run() is called
:param imc_id: The IMC address this node should operate under
:param imc_id: The IMC id this node should operate under
:param static_port: Optional static port to listen for IMC messages (useful if DUNE uses static transports)
:param backseat_port: Listen for TCP+IMC connections from backseat on this port (similar to DUNE backseat).
:param verbose_nodes: If true, the connected nodes are printed out every 10 seconds
:param log_enable: Enable logging of incoming and outgoing IMC messages (.lsf)
:param log_dir: Root directory for IMC logs (default: /tmp/, or equivalent)
:param log_root: Root directory for IMC logs (default: /tmp/, or equivalent)
"""
# Arguments
self.imc_id = imc_id
self.static_port = static_port
self.backseat_port: Optional[int] = backseat_port
self.verbose_nodes = verbose_nodes
self.log_enable = log_enable
self.log_root = os.path.join(tempfile.gettempdir(), 'imcpy') if log_root is None else log_root
Expand All @@ -56,6 +67,8 @@ def __init__(self, imc_id=0x3334, static_port=None, verbose_nodes=False, log_ena
self._loop: Optional[asyncio.BaseEventLoop] = None
self._task_mc: Optional[asyncio.Task] = None
self._task_imc: Optional[asyncio.Task] = None
self._task_server: Optional[asyncio.Task] = None
self._server_backseat: Optional[IMCProtocolTCPServer] = None
self._subs: Dict[Type[imcpy.Message], List[types.MethodType]] = {}

# IMC/Multicast ports (assigned when socket is created)
Expand All @@ -67,6 +80,7 @@ def __init__(self, imc_id=0x3334, static_port=None, verbose_nodes=False, log_ena

# Static transports
# Adding imcpy.Message transports all messages
# TODO: add client API for adding and removing static transports
self._static_transports: Dict[Type[imcpy.Message], List[IMCService]] = {}

# Runtime data
Expand Down Expand Up @@ -336,6 +350,10 @@ def send(self, node_id, msg, set_timestamp=True):
# Send to static destinations
self.send_static(msg)

# Send to connected backseat (if any)
if self._backseat_server is not None and self._loop is not None:
asyncio.ensure_future(self._backseat_server.write_message(msg), loop=self._loop)

def on_exception(self, loc, exc):
"""
Can be overridden in subclasses to handle uncaught exceptions in @Subscribe, @Periodic, @RunOnce functions
Expand All @@ -351,22 +369,26 @@ def _start_subscriptions(self):
"""
Add asyncio datagram endpoint for all subscriptions
"""
assert self._loop is not None, 'Event loop not initialized'

# Add datagram endpoint for multicast announce
mc_sock = get_multicast_socket()
self._port_mc = mc_sock.getsockname()[1]
multicast_listener = self._loop.create_datagram_endpoint(lambda: IMCProtocolUDP(self), sock=mc_sock)
self._task_mc = asyncio.ensure_future(multicast_listener, loop=self._loop)

# Add datagram endpoint for UDP IMC messages
imc_sock = get_imc_socket(static_port=self.static_port)
self._port_imc = imc_sock.getsockname()[1]
imc_listener = self._loop.create_datagram_endpoint(lambda: IMCProtocolUDP(self), sock=imc_sock)

if sys.version_info < (3, 4, 4):
self._task_mc = self._loop.create_task(multicast_listener)
self._task_imc = self._loop.create_task(imc_listener)
else:
self._task_mc = asyncio.ensure_future(multicast_listener, loop=self._loop)
self._task_imc = asyncio.ensure_future(imc_listener, loop=self._loop)
self._task_imc = asyncio.ensure_future(imc_listener, loop=self._loop)

# Add backseat TCP server if port is set
if self.backseat_port is not None:
self._backseat_server = IMCProtocolTCPServer(self)
logger.info(f'Setting up backseat server on 127.0.0.1:{self.backseat_port}')
backseat_co = asyncio.start_server(self._backseat_server.on_connection, '127.0.0.1', self.backseat_port)
self._task_backseat = asyncio.ensure_future(backseat_co, loop=self._loop)

def _setup_event_loop(self):
"""
Expand Down
1 change: 1 addition & 0 deletions imcpy/network/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import imcpy.network.tcp
import imcpy.network.udp
import imcpy.network.utils
72 changes: 72 additions & 0 deletions imcpy/network/tcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import asyncio
import logging

import imcpy

logger = logging.getLogger('imcpy.tcp')


class IMCProtocolTCPClientConnection:
def __init__(self, name: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self.name = name
self.reader = reader
self.writer = writer
self._parser = imcpy.Parser()

async def write_bytes(self, data: bytes):
"""Write the bytes to the connected clients and flush buffer if necessary.
:param data: The serialized IMC message to write.
"""
self.writer.write(data)
await self.writer.drain()

async def close(self):
self.writer.close()
await self.writer.wait_closed()

async def handle_data(self, instance):
"""Handle incoming data from the client."""
try:
while True:
data = await self.reader.read(4096)
if not data:
logger.error(f'Connection closed by peer ({self.name})')
await self.close()
return

data_remaining = len(data)
while data_remaining > 0:
msg, parsed_bytes = self._parser.parse(data[-data_remaining:])
data_remaining -= parsed_bytes
if msg is not None:
# Log IMC message to file if enabled
if instance.log_imc_fh and not instance.log_imc_fh.closed:
instance.log_imc_fh.write(data)

instance.post_message(msg)
except ConnectionError as e:
logger.error(f'Connection error ({self.name}): {e}')
await self.close()


class IMCProtocolTCPServer:
def __init__(self, instance) -> None:
self.instance = instance
self._clients = set()

async def on_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
logger.info(f'New connection from {writer.get_extra_info("peername")}')
client = IMCProtocolTCPClientConnection(writer.get_extra_info('peername'), reader, writer)
self._clients.add(client)
await client.handle_data(self.instance)
self._clients.remove(client)

async def write_message(self, msg: imcpy.Message):
"""Write message to all connected clients.
:param msg: The IMC message to write.
"""
b = msg.serialize()
for client in self._clients:
await client.write_bytes(b)

0 comments on commit 6622a02

Please sign in to comment.