Skip to content

Commit 7682421

Browse files
authored
Merge pull request #218 from Neradoc/more-compatible-api
Bring more compatibility with native sockets
2 parents 3fcea23 + 426af8b commit 7682421

File tree

2 files changed

+87
-18
lines changed

2 files changed

+87
-18
lines changed

adafruit_esp32spi/adafruit_esp32spi.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,8 @@ def socket_connected(self, socket_num):
825825
return self.socket_status(socket_num) == SOCKET_ESTABLISHED
826826

827827
def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
828-
"""Write the bytearray buffer to a socket"""
828+
"""Write the bytearray buffer to a socket.
829+
Returns the number of bytes written"""
829830
if self._debug:
830831
print("Writing:", buffer)
831832
self._socknum_ll[0][0] = socket_num
@@ -853,7 +854,7 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
853854
resp = self._send_command_get_response(_SEND_UDP_DATA_CMD, self._socknum_ll)
854855
if resp[0][0] != 1:
855856
raise ConnectionError("Failed to send UDP data")
856-
return
857+
return sent
857858

858859
if sent != len(buffer):
859860
self.socket_close(socket_num)
@@ -863,6 +864,8 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE):
863864
if resp[0][0] != 1:
864865
raise ConnectionError("Failed to verify data sent")
865866

867+
return sent
868+
866869
def socket_available(self, socket_num):
867870
"""Determine how many bytes are waiting to be read on the socket"""
868871
self._socknum_ll[0][0] = socket_num

adafruit_esp32spi/adafruit_esp32spi_socketpool.py

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
try:
17-
from typing import TYPE_CHECKING, Optional
17+
from typing import TYPE_CHECKING, Optional, Tuple
1818

1919
if TYPE_CHECKING:
2020
from esp32spi.adafruit_esp32spi import ESP_SPIcontrol # noqa: UP007
@@ -36,11 +36,15 @@
3636
class SocketPool:
3737
"""ESP32SPI SocketPool library"""
3838

39-
SOCK_STREAM = const(0)
40-
SOCK_DGRAM = const(1)
39+
# socketpool constants
40+
SOCK_STREAM = const(1)
41+
SOCK_DGRAM = const(2)
4142
AF_INET = const(2)
42-
NO_SOCKET_AVAIL = const(255)
43+
SOL_SOCKET = const(0xFFF)
44+
SO_REUSEADDR = const(0x0004)
4345

46+
# implementation specific constants
47+
NO_SOCKET_AVAIL = const(255)
4448
MAX_PACKET = const(4000)
4549

4650
def __new__(cls, iface: ESP_SPIcontrol):
@@ -73,7 +77,13 @@ def socket(
7377

7478
class Socket:
7579
"""A simplified implementation of the Python 'socket' class, for connecting
76-
through an interface to a remote device"""
80+
through an interface to a remote device. Has properties specific to the
81+
implementation.
82+
83+
:param SocketPool socket_pool: The underlying socket pool.
84+
:param Optional[int] socknum: Allows wrapping a Socket instance around a socket
85+
number returned by the nina firmware. Used internally.
86+
"""
7787

7888
def __init__(
7989
self,
@@ -82,15 +92,17 @@ def __init__(
8292
type: int = SocketPool.SOCK_STREAM,
8393
proto: int = 0,
8494
fileno: Optional[int] = None, # noqa: UP007
95+
socknum: Optional[int] = None, # noqa: UP007
8596
):
8697
if family != SocketPool.AF_INET:
8798
raise ValueError("Only AF_INET family supported")
8899
self._socket_pool = socket_pool
89100
self._interface = self._socket_pool._interface
90101
self._type = type
91102
self._buffer = b""
92-
self._socknum = self._interface.get_socket()
93-
self.settimeout(0)
103+
self._socknum = socknum if socknum is not None else self._interface.get_socket()
104+
self._bound = ()
105+
self.settimeout(None)
94106

95107
def __enter__(self):
96108
return self
@@ -121,13 +133,14 @@ def send(self, data):
121133
conntype = self._interface.UDP_MODE
122134
else:
123135
conntype = self._interface.TCP_MODE
124-
self._interface.socket_write(self._socknum, data, conn_mode=conntype)
136+
sent = self._interface.socket_write(self._socknum, data, conn_mode=conntype)
125137
gc.collect()
138+
return sent
126139

127140
def sendto(self, data, address):
128141
"""Connect and send some data to the socket."""
129142
self.connect(address)
130-
self.send(data)
143+
return self.send(data)
131144

132145
def recv(self, bufsize: int) -> bytes:
133146
"""Reads some bytes from the connected remote address. Will only return
@@ -150,12 +163,12 @@ def recv_into(self, buffer, nbytes: int = 0):
150163
if not 0 <= nbytes <= len(buffer):
151164
raise ValueError("nbytes must be 0 to len(buffer)")
152165

153-
last_read_time = time.monotonic()
166+
last_read_time = time.monotonic_ns()
154167
num_to_read = len(buffer) if nbytes == 0 else nbytes
155168
num_read = 0
156169
while num_to_read > 0:
157170
# we might have read socket data into the self._buffer with:
158-
# esp32spi_wsgiserver: socket_readline
171+
# adafruit_wsgi.esp32spi_wsgiserver: socket_readline
159172
if len(self._buffer) > 0:
160173
bytes_to_read = min(num_to_read, len(self._buffer))
161174
buffer[num_read : num_read + bytes_to_read] = self._buffer[:bytes_to_read]
@@ -167,7 +180,7 @@ def recv_into(self, buffer, nbytes: int = 0):
167180

168181
num_avail = self._available()
169182
if num_avail > 0:
170-
last_read_time = time.monotonic()
183+
last_read_time = time.monotonic_ns()
171184
bytes_read = self._interface.socket_read(self._socknum, min(num_to_read, num_avail))
172185
buffer[num_read : num_read + len(bytes_read)] = bytes_read
173186
num_read += len(bytes_read)
@@ -176,15 +189,27 @@ def recv_into(self, buffer, nbytes: int = 0):
176189
# We got a message, but there are no more bytes to read, so we can stop.
177190
break
178191
# No bytes yet, or more bytes requested.
179-
if self._timeout > 0 and time.monotonic() - last_read_time > self._timeout:
192+
193+
if self._timeout == 0: # if in non-blocking mode, stop now.
194+
break
195+
196+
# Time out if there's a positive timeout set.
197+
delta = (time.monotonic_ns() - last_read_time) // 1_000_000
198+
if self._timeout > 0 and delta > self._timeout:
180199
raise OSError(errno.ETIMEDOUT)
181200
return num_read
182201

183202
def settimeout(self, value):
184-
"""Set the read timeout for sockets.
185-
If value is 0 socket reads will block until a message is available.
203+
"""Set the read timeout for sockets in seconds.
204+
``0`` means non-blocking. ``None`` means block indefinitely.
186205
"""
187-
self._timeout = value
206+
if value is None:
207+
self._timeout = -1
208+
else:
209+
if value < 0:
210+
raise ValueError("Timeout cannot be a negative number")
211+
# internally in milliseconds as an int
212+
self._timeout = int(value * 1000)
188213

189214
def _available(self):
190215
"""Returns how many bytes of data are available to be read (up to the MAX_PACKET length)"""
@@ -217,3 +242,44 @@ def _connected(self):
217242
def close(self):
218243
"""Close the socket, after reading whatever remains"""
219244
self._interface.socket_close(self._socknum)
245+
246+
def accept(self):
247+
"""Accept a connection on a listening socket of type SOCK_STREAM,
248+
creating a new socket of type SOCK_STREAM. Returns a tuple of
249+
(new_socket, remote_address)
250+
"""
251+
client_sock_num = self._interface.socket_available(self._socknum)
252+
if client_sock_num != SocketPool.NO_SOCKET_AVAIL:
253+
sock = Socket(self._socket_pool, socknum=client_sock_num)
254+
# get remote information (addr and port)
255+
remote = self._interface.get_remote_data(client_sock_num)
256+
ip_address = "{}.{}.{}.{}".format(*remote["ip_addr"])
257+
port = remote["port"]
258+
client_address = (ip_address, port)
259+
return sock, client_address
260+
raise OSError(errno.ECONNRESET)
261+
262+
def bind(self, address: tuple[str, int]):
263+
"""Bind a socket to an address"""
264+
self._bound = address
265+
266+
def listen(self, backlog: int): # pylint: disable=unused-argument
267+
"""Set socket to listen for incoming connections.
268+
:param int backlog: length of backlog queue for waiting connections (ignored)
269+
"""
270+
if not self._bound:
271+
self._bound = (self._interface.ip_address, 80)
272+
port = self._bound[1]
273+
self._interface.start_server(port, self._socknum)
274+
275+
def setblocking(self, flag: bool):
276+
"""Set the blocking behaviour of this socket.
277+
:param bool flag: False means non-blocking, True means block indefinitely.
278+
"""
279+
if flag:
280+
self.settimeout(None)
281+
else:
282+
self.settimeout(0)
283+
284+
def setsockopt(self, *opts, **kwopts):
285+
"""Dummy call for compatibility."""

0 commit comments

Comments
 (0)