Skip to content

Commit

Permalink
Implement zerocopy writes (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 1, 2024
1 parent 4bea46b commit ba05d38
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 54 deletions.
2 changes: 1 addition & 1 deletion aioesphomeapi/_frame_helper/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cdef class APIFrameHelper:
cdef object _loop
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _writelines
cdef public object ready_future
cdef bytes _buffer
cdef unsigned int _buffer_len
Expand Down
21 changes: 13 additions & 8 deletions aioesphomeapi/_frame_helper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import abstractmethod
import asyncio
from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING, Callable, cast

Expand Down Expand Up @@ -31,7 +32,7 @@ class APIFrameHelper:
"_loop",
"_connection",
"_transport",
"_writer",
"_writelines",
"ready_future",
"_buffer",
"_buffer_len",
Expand All @@ -51,7 +52,9 @@ def __init__(
self._loop = loop
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._writelines: (
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
) = None
self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None
self._buffer_len = 0
Expand Down Expand Up @@ -146,7 +149,7 @@ def write_packets(
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write
self._writelines = self._transport.writelines

def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc)
Expand All @@ -172,20 +175,22 @@ def close(self) -> None:
if self._transport:
self._transport.close()
self._transport = None
self._writer = None
self._writelines = None

def pause_writing(self) -> None:
"""Stub."""

def resume_writing(self) -> None:
"""Stub."""

def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
"""Write bytes to the socket."""
if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
_LOGGER.debug(
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
)

if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
assert self._writelines is not None, "Writer is not set"

self._writer(data)
self._writelines(data)
4 changes: 2 additions & 2 deletions aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _send_hello_handshake(self) -> None:
frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
self._write_bytes(
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
(NOISE_HELLO, header, b"\x00", handshake_frame),
_LOGGER.isEnabledFor(logging.DEBUG),
)

Expand Down Expand Up @@ -346,7 +346,7 @@ def write_packets(
out.append(header)
out.append(frame)

self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)

def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
Expand Down
5 changes: 3 additions & 2 deletions aioesphomeapi/_frame_helper/plain_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def write_packets(
out.append(b"\0")
out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_))
out.append(data)
if data:
out.append(data)

self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)

def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data)
Expand Down
18 changes: 11 additions & 7 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ class Estr(str):
"""A subclassed string."""


def generate_plaintext_packet(msg: message.Message) -> bytes:
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString()
return (
b"\0"
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(type_)
+ bytes_
)
return [
b"\0",
_cached_varuint_to_bytes(len(bytes_)),
_cached_varuint_to_bytes(type_),
bytes_,
]


def generate_plaintext_packet(msg: message.Message) -> bytes:
return b"".join(generate_split_plaintext_packet(msg))


def as_utc(dattim: datetime) -> datetime:
Expand Down
23 changes: 12 additions & 11 deletions tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import base64
from collections.abc import Iterable
import sys
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -132,7 +133,7 @@ def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None
"""Swallow args."""
super().__init__(*args, **kwargs)
transport = MagicMock()
transport.write = writer or MagicMock()
transport.writelines = writer or MagicMock()
self.__transport = transport
self.connection_made(transport)

Expand All @@ -147,7 +148,7 @@ def mock_write_frame(self, frame: bytes) -> None:
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
self._writelines([header, frame])
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}"
Expand Down Expand Up @@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))

connection, _ = _make_mock_connection()

Expand All @@ -448,7 +449,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down Expand Up @@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))

connection, packets = _make_mock_connection()

Expand All @@ -497,7 +498,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down Expand Up @@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))

connection, packets = _make_mock_connection()

Expand All @@ -559,7 +560,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down
25 changes: 18 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
from .common import (
Estr,
generate_plaintext_packet,
generate_split_plaintext_packet,
get_mock_zeroconf,
mock_data_received,
)
Expand Down Expand Up @@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234'
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"K",
b'\x08\xd2\t\x10\xd2\t"\x041234',
]

with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
Expand Down Expand Up @@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234"
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"M",
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
]

with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
Expand Down Expand Up @@ -2042,8 +2053,8 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None

cancel = await connect_task
assert states == [(True, 23, 0)]
transport.write.assert_called_once_with(
generate_plaintext_packet(
transport.writelines.assert_called_once_with(
generate_split_plaintext_packet(
BluetoothDeviceRequest(
address=1234,
request_type=method,
Expand Down Expand Up @@ -2133,13 +2144,13 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
# Now that we timed out, the disconnect
# request should be written
assert len(transport.write.mock_calls) == 2
assert len(transport.writelines.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
Expand Down Expand Up @@ -2177,7 +2188,7 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
connect_task.cancel()
with pytest.raises(asyncio.CancelledError):
await connect_task
Expand Down
Loading

0 comments on commit ba05d38

Please sign in to comment.