diff --git a/nibe/connection/nibegw.py b/nibe/connection/nibegw.py index 3edb30a..dfee5af 100644 --- a/nibe/connection/nibegw.py +++ b/nibe/connection/nibegw.py @@ -5,13 +5,14 @@ from dataclasses import dataclass import errno from functools import reduce +import io from io import BytesIO from ipaddress import ip_address import logging from operator import xor import socket import struct -from typing import Dict, Literal, Optional, Union +from typing import Any, Dict, Literal, Optional, Union from construct import ( Adapter, @@ -183,42 +184,11 @@ def connection_made(self, transport): def datagram_received(self, data: bytes, addr): """Callback when data is received.""" logger.debug(f"Received {hexlify(data).decode('utf-8')} from {addr}") + try: - msg = Response.parse(data) - - if not self._remote_ip: - logger.debug("Pump discovered at %s", addr) - self._remote_ip = addr[0] - - self.status = ConnectionStatus.CONNECTED - - logger.debug(msg.fields.value) - cmd = msg.fields.value.cmd - if cmd == "MODBUS_DATA_MSG": - data: dict[int, bytes] = { - row.coil_address: row.value - for row in msg.fields.value.data - if row.coil_address != 0xFFFF - } - self._on_raw_coil_set(data) - elif cmd == "MODBUS_READ_RESP": - row = msg.fields.value.data - self._on_raw_coil_value(row.coil_address, row.value) - elif cmd == "MODBUS_WRITE_RESP": - with suppress(InvalidStateError, CancelledError, KeyError): - self._futures["write"].set_result(msg.fields.value.data.result) - elif cmd == "RMU_DATA_MSG": - self._on_rmu_data(msg.fields.value) - elif cmd == "PRODUCT_INFO_MSG": - data = msg.fields.value.data - product_info = ProductInfo(data["model"], data["version"]) - with suppress(InvalidStateError, CancelledError, KeyError): - self._futures["product_info"].set_result(product_info) - self.notify_event_listeners( - self.PRODUCT_INFO_EVENT, product_info=product_info - ) - elif not isinstance(cmd, EnumIntegerString): - logger.debug(f"Unknown command {cmd}") + with io.BytesIO(bytes(data)) as stream: + while block := Block.parse_stream(stream): + self._on_block(block, addr) except ConstructError as e: logger.warning( f"Ignoring packet from {addr} due to parse error: {hexlify(data).decode('utf-8')}: {e}" @@ -230,6 +200,47 @@ def datagram_received(self, data: bytes, addr): f"Unexpected exception during parsing packet data '{hexlify(data).decode('utf-8')}' from {addr}" ) + def _on_block(self, block: Container[Any], addr) -> None: + if block.start_byte == "RESPONSE": + self._on_response(block, addr) + else: + logger.debug(block) + + def _on_response(self, msg: Container[Any], addr) -> None: + if not self._remote_ip: + logger.debug("Pump discovered at %s", addr) + self._remote_ip = addr[0] + + self.status = ConnectionStatus.CONNECTED + + logger.debug(msg.fields.value) + cmd = msg.fields.value.cmd + if cmd == "MODBUS_DATA_MSG": + data: dict[int, bytes] = { + row.coil_address: row.value + for row in msg.fields.value.data + if row.coil_address != 0xFFFF + } + self._on_raw_coil_set(data) + elif cmd == "MODBUS_READ_RESP": + row = msg.fields.value.data + self._on_raw_coil_value(row.coil_address, row.value) + elif cmd == "MODBUS_WRITE_RESP": + with suppress(InvalidStateError, CancelledError, KeyError): + self._futures["write"].set_result(msg.fields.value.data.result) + elif cmd == "RMU_DATA_MSG": + self._on_rmu_data(msg.fields.value) + elif cmd == "PRODUCT_INFO_MSG": + data = msg.fields.value.data + product_info = ProductInfo(data["model"], data["version"]) + with suppress(InvalidStateError, CancelledError, KeyError): + self._futures["product_info"].set_result(product_info) + self.notify_event_listeners( + self.PRODUCT_INFO_EVENT, product_info=product_info + ) + elif not isinstance(cmd, EnumIntegerString): + logger.debug(f"Unknown command {cmd}") + async def read_product_info( self, timeout: float = READ_PRODUCT_INFO_TIMEOUT ) -> ProductInfo: @@ -859,9 +870,9 @@ def _encode(self, obj, context, path): Block = FocusedSeq( "data", - "start" / Peek(StartCode), + "start_byte" / Peek(StartCode), "data" / Switch( - this.start, + this.start_byte, BlockTypes ), ) diff --git a/tests/connection/test_nibegw.py b/tests/connection/test_nibegw.py index 9fc01a9..dc35245 100644 --- a/tests/connection/test_nibegw.py +++ b/tests/connection/test_nibegw.py @@ -151,6 +151,20 @@ async def test_read_product_info(nibegw: NibeGW): assert "F1255-12 R" == product.model +async def test_read_product_info_with_extras(nibegw: NibeGW): + _enqueue_datagram( + nibegw, + "5c0019ee00f7" # token accessory version + "c0ee03ee0101c3" # accessory version from accessory + "06" # ack from pump + "5c00206d0d0124e346313235352d313220529f", + ) + product = await nibegw.read_product_info() + + assert isinstance(product, ProductInfo) + assert "F1255-12 R" == product.model + + @pytest.mark.parametrize( ("raw", "table_processing_mode", "calls"), [