diff --git a/can/interfaces/slcan.py b/can/interfaces/slcan.py index 4cf8b5e25..f0a04e305 100644 --- a/can/interfaces/slcan.py +++ b/can/interfaces/slcan.py @@ -6,7 +6,8 @@ import logging import time import warnings -from typing import Any, Optional, Tuple, Union +from queue import SimpleQueue +from typing import Any, Optional, Tuple, Union, cast from can import BitTiming, BitTimingFd, BusABC, CanProtocol, Message, typechecking from can.exceptions import ( @@ -131,6 +132,7 @@ def __init__( timeout=timeout, ) + self._queue: SimpleQueue[str] = SimpleQueue() self._buffer = bytearray() self._can_protocol = CanProtocol.CAN_20 @@ -196,7 +198,7 @@ def _read(self, timeout: Optional[float]) -> Optional[str]: # We read the `serialPortOrig.in_waiting` only once here. in_waiting = self.serialPortOrig.in_waiting for _ in range(max(1, in_waiting)): - new_byte = self.serialPortOrig.read(size=1) + new_byte = self.serialPortOrig.read(1) if new_byte: self._buffer.extend(new_byte) else: @@ -234,7 +236,10 @@ def _recv_internal( extended = False data = None - string = self._read(timeout) + if self._queue.qsize(): + string: Optional[str] = self._queue.get_nowait() + else: + string = self._read(timeout) if not string: pass @@ -300,7 +305,7 @@ def shutdown(self) -> None: def fileno(self) -> int: try: - return self.serialPortOrig.fileno() + return cast(int, self.serialPortOrig.fileno()) except io.UnsupportedOperation: raise NotImplementedError( "fileno is not implemented using current CAN bus on this platform" @@ -321,19 +326,21 @@ def get_version( int hw_version is the hardware version or None on timeout int sw_version is the software version or None on timeout """ + _timeout = serial.Timeout(timeout) cmd = "V" self._write(cmd) - string = self._read(timeout) - - if not string: - pass - elif string[0] == cmd and len(string) == 6: - # convert ASCII coded version - hw_version = int(string[1:3]) - sw_version = int(string[3:5]) - return hw_version, sw_version - + while True: + if string := self._read(_timeout.time_left()): + if string[0] == cmd: + # convert ASCII coded version + hw_version = int(string[1:3]) + sw_version = int(string[3:5]) + return hw_version, sw_version + else: + self._queue.put_nowait(string) + if _timeout.expired(): + break return None, None def get_serial_number(self, timeout: Optional[float]) -> Optional[str]: @@ -345,15 +352,17 @@ def get_serial_number(self, timeout: Optional[float]) -> Optional[str]: :return: :obj:`None` on timeout or a :class:`str` object. """ + _timeout = serial.Timeout(timeout) cmd = "N" self._write(cmd) - string = self._read(timeout) - - if not string: - pass - elif string[0] == cmd and len(string) == 6: - serial_number = string[1:-1] - return serial_number - + while True: + if string := self._read(_timeout.time_left()): + if string[0] == cmd: + serial_number = string[1:-1] + return serial_number + else: + self._queue.put_nowait(string) + if _timeout.expired(): + break return None diff --git a/pyproject.toml b/pyproject.toml index f2b6ac04f..4ee463c5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,6 @@ exclude = [ "^can/interfaces/neousys", "^can/interfaces/pcan", "^can/interfaces/serial", - "^can/interfaces/slcan", "^can/interfaces/socketcan", "^can/interfaces/systec", "^can/interfaces/udp_multicast", diff --git a/test/test_slcan.py b/test/test_slcan.py index e1531e500..220a6d7e0 100644 --- a/test/test_slcan.py +++ b/test/test_slcan.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -import unittest -from typing import cast +import unittest.mock +from typing import cast, Optional -import serial +from serial.serialutil import SerialBase import can.interfaces.slcan @@ -21,20 +21,69 @@ TIMEOUT = 0.5 if IS_PYPY else 0.01 # 0.001 is the default set in slcanBus +class SerialMock(SerialBase): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._input_buffer = b"" + self._output_buffer = b"" + + def open(self) -> None: + self.is_open = True + + def close(self) -> None: + self.is_open = False + self._input_buffer = b"" + self._output_buffer = b"" + + def read(self, size: int = -1, /) -> bytes: + if size > 0: + data = self._input_buffer[:size] + self._input_buffer = self._input_buffer[size:] + return data + return b"" + + def write(self, b: bytes, /) -> Optional[int]: + self._output_buffer = b + if b == b"N\r": + self.set_input_buffer(b"NA123\r") + elif b == b"V\r": + self.set_input_buffer(b"V1013\r") + return len(b) + + def set_input_buffer(self, expected: bytes) -> None: + self._input_buffer = expected + + def get_output_buffer(self) -> bytes: + return self._output_buffer + + def reset_input_buffer(self) -> None: + self._input_buffer = b"" + + @property + def in_waiting(self) -> int: + return len(self._input_buffer) + + @classmethod + def serial_for_url(cls, *args, **kwargs) -> SerialBase: + return cls(*args, **kwargs) + + class slcanTestCase(unittest.TestCase): + @unittest.mock.patch("serial.serial_for_url", SerialMock.serial_for_url) def setUp(self): self.bus = cast( can.interfaces.slcan.slcanBus, can.Bus("loop://", interface="slcan", sleep_after_open=0, timeout=TIMEOUT), ) - self.serial = cast(serial.Serial, self.bus.serialPortOrig) + self.serial = cast(SerialMock, self.bus.serialPortOrig) self.serial.reset_input_buffer() def tearDown(self): self.bus.shutdown() def test_recv_extended(self): - self.serial.write(b"T12ABCDEF2AA55\r") + self.serial.set_input_buffer(b"T12ABCDEF2AA55\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x12ABCDEF) @@ -44,7 +93,7 @@ def test_recv_extended(self): self.assertSequenceEqual(msg.data, [0xAA, 0x55]) # Ewert Energy Systems CANDapter specific - self.serial.write(b"x12ABCDEF2AA55\r") + self.serial.set_input_buffer(b"x12ABCDEF2AA55\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x12ABCDEF) @@ -54,15 +103,19 @@ def test_recv_extended(self): self.assertSequenceEqual(msg.data, [0xAA, 0x55]) def test_send_extended(self): + payload = b"T12ABCDEF2AA55\r" msg = can.Message( arbitration_id=0x12ABCDEF, is_extended_id=True, data=[0xAA, 0x55] ) self.bus.send(msg) + self.assertEqual(payload, self.serial.get_output_buffer()) + + self.serial.set_input_buffer(payload) rx_msg = self.bus.recv(TIMEOUT) self.assertTrue(msg.equals(rx_msg, timestamp_delta=None)) def test_recv_standard(self): - self.serial.write(b"t4563112233\r") + self.serial.set_input_buffer(b"t4563112233\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x456) @@ -72,15 +125,19 @@ def test_recv_standard(self): self.assertSequenceEqual(msg.data, [0x11, 0x22, 0x33]) def test_send_standard(self): + payload = b"t4563112233\r" msg = can.Message( arbitration_id=0x456, is_extended_id=False, data=[0x11, 0x22, 0x33] ) self.bus.send(msg) + self.assertEqual(payload, self.serial.get_output_buffer()) + + self.serial.set_input_buffer(payload) rx_msg = self.bus.recv(TIMEOUT) self.assertTrue(msg.equals(rx_msg, timestamp_delta=None)) def test_recv_standard_remote(self): - self.serial.write(b"r1238\r") + self.serial.set_input_buffer(b"r1238\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x123) @@ -89,15 +146,19 @@ def test_recv_standard_remote(self): self.assertEqual(msg.dlc, 8) def test_send_standard_remote(self): + payload = b"r1238\r" msg = can.Message( arbitration_id=0x123, is_extended_id=False, is_remote_frame=True, dlc=8 ) self.bus.send(msg) + self.assertEqual(payload, self.serial.get_output_buffer()) + + self.serial.set_input_buffer(payload) rx_msg = self.bus.recv(TIMEOUT) self.assertTrue(msg.equals(rx_msg, timestamp_delta=None)) def test_recv_extended_remote(self): - self.serial.write(b"R12ABCDEF6\r") + self.serial.set_input_buffer(b"R12ABCDEF6\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x12ABCDEF) @@ -106,19 +167,23 @@ def test_recv_extended_remote(self): self.assertEqual(msg.dlc, 6) def test_send_extended_remote(self): + payload = b"R12ABCDEF6\r" msg = can.Message( arbitration_id=0x12ABCDEF, is_extended_id=True, is_remote_frame=True, dlc=6 ) self.bus.send(msg) + self.assertEqual(payload, self.serial.get_output_buffer()) + + self.serial.set_input_buffer(payload) rx_msg = self.bus.recv(TIMEOUT) self.assertTrue(msg.equals(rx_msg, timestamp_delta=None)) def test_partial_recv(self): - self.serial.write(b"T12ABCDEF") + self.serial.set_input_buffer(b"T12ABCDEF") msg = self.bus.recv(TIMEOUT) self.assertIsNone(msg) - self.serial.write(b"2AA55\rT12") + self.serial.set_input_buffer(b"2AA55\rT12") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x12ABCDEF) @@ -130,28 +195,21 @@ def test_partial_recv(self): msg = self.bus.recv(TIMEOUT) self.assertIsNone(msg) - self.serial.write(b"ABCDEF2AA55\r") + self.serial.set_input_buffer(b"ABCDEF2AA55\r") msg = self.bus.recv(TIMEOUT) self.assertIsNotNone(msg) def test_version(self): - self.serial.write(b"V1013\r") hw_ver, sw_ver = self.bus.get_version(0) + self.assertEqual(b"V\r", self.serial.get_output_buffer()) self.assertEqual(hw_ver, 10) self.assertEqual(sw_ver, 13) - hw_ver, sw_ver = self.bus.get_version(0) - self.assertIsNone(hw_ver) - self.assertIsNone(sw_ver) - def test_serial_number(self): - self.serial.write(b"NA123\r") sn = self.bus.get_serial_number(0) + self.assertEqual(b"N\r", self.serial.get_output_buffer()) self.assertEqual(sn, "A123") - sn = self.bus.get_serial_number(0) - self.assertIsNone(sn) - if __name__ == "__main__": unittest.main()