Skip to content

Commit

Permalink
Handle non-contiguous memoryviews in C extension.
Browse files Browse the repository at this point in the history
This avoids the special-case in Python code.
  • Loading branch information
aaugustin committed Nov 30, 2020
1 parent 965f8ec commit ecf64e7
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 93 deletions.
11 changes: 2 additions & 9 deletions src/websockets/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]:
"""
if isinstance(data, str):
return OP_TEXT, data.encode("utf-8")
elif isinstance(data, (bytes, bytearray)):
elif isinstance(data, (bytes, bytearray, memoryview)):
return OP_BINARY, data
elif isinstance(data, memoryview):
if data.c_contiguous:
return OP_BINARY, data
else:
return OP_BINARY, data.tobytes()
else:
raise TypeError("data must be bytes-like or str")

Expand All @@ -290,10 +285,8 @@ def prepare_ctrl(data: Data) -> bytes:
"""
if isinstance(data, str):
return data.encode("utf-8")
elif isinstance(data, (bytes, bytearray)):
elif isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
elif isinstance(data, memoryview):
return data.tobytes()
else:
raise TypeError("data must be bytes-like or str")

Expand Down
51 changes: 26 additions & 25 deletions src/websockets/speedups.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,35 @@ static const Py_ssize_t MASK_LEN = 4;
/* Similar to PyBytes_AsStringAndSize, but accepts more types */

static int
_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length)
_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length)
{
// This supports bytes, bytearrays, and C-contiguous memoryview objects,
// which are the most useful data structures for handling byte streams.
// websockets.framing.prepare_data() returns only values of these types.
// Any object implementing the buffer protocol could be supported, however
// that would require allocation or copying memory, which is expensive.
// This supports bytes, bytearrays, and memoryview objects,
// which are common data structures for handling byte streams.
// websockets.framing.prepare_data() returns only these types.
// If *tmp isn't NULL, the caller gets a new reference.
if (PyBytes_Check(obj))
{
*tmp = NULL;
*buffer = PyBytes_AS_STRING(obj);
*length = PyBytes_GET_SIZE(obj);
}
else if (PyByteArray_Check(obj))
{
*tmp = NULL;
*buffer = PyByteArray_AS_STRING(obj);
*length = PyByteArray_GET_SIZE(obj);
}
else if (PyMemoryView_Check(obj))
{
Py_buffer *mv_buf;
mv_buf = PyMemoryView_GET_BUFFER(obj);
if (PyBuffer_IsContiguous(mv_buf, 'C'))
{
*buffer = mv_buf->buf;
*length = mv_buf->len;
}
else
*tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
if (*tmp == NULL)
{
PyErr_Format(
PyExc_TypeError,
"expected a contiguous memoryview");
return -1;
}
Py_buffer *mv_buf;
mv_buf = PyMemoryView_GET_BUFFER(*tmp);
*buffer = mv_buf->buf;
*length = mv_buf->len;
}
else
{
Expand Down Expand Up @@ -74,15 +70,17 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
// A pointer to a char * + length will be extracted from the data and mask
// arguments, possibly via a Py_buffer.

PyObject *input_tmp = NULL;
char *input;
Py_ssize_t input_len;
PyObject *mask_tmp = NULL;
char *mask;
Py_ssize_t mask_len;

// Initialize a PyBytesObject then get a pointer to the underlying char *
// in order to avoid an extra memory copy in PyBytes_FromStringAndSize.

PyObject *result;
PyObject *result = NULL;
char *output;

// Other variables.
Expand All @@ -94,31 +92,31 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
if (!PyArg_ParseTupleAndKeywords(
args, kwds, "OO", kwlist, &input_obj, &mask_obj))
{
return NULL;
goto exit;
}

if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1)
if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1)
{
return NULL;
goto exit;
}

if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1)
if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1)
{
return NULL;
goto exit;
}

if (mask_len != MASK_LEN)
{
PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes");
return NULL;
goto exit;
}

// Create output.

result = PyBytes_FromStringAndSize(NULL, input_len);
if (result == NULL)
{
return NULL;
goto exit;
}

// Since we just created result, we don't need error checks.
Expand Down Expand Up @@ -172,6 +170,9 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
output[i] = input[i] ^ mask[i & (MASK_LEN - 1)];
}

exit:
Py_XDECREF(input_tmp);
Py_XDECREF(mask_tmp);
return result;

}
Expand Down
30 changes: 0 additions & 30 deletions tests/legacy/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,6 @@ def test_send_binary_from_memoryview(self):
self.loop.run_until_complete(self.protocol.send(memoryview(b"tea")))
self.assertOneFrameSent(True, OP_BINARY, b"tea")

def test_send_binary_from_non_contiguous_memoryview(self):
self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2]))
self.assertOneFrameSent(True, OP_BINARY, b"tea")

def test_send_dict(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.send({"not": "encoded"}))
Expand Down Expand Up @@ -624,14 +620,6 @@ def test_send_iterable_binary_from_memoryview(self):
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
)

def test_send_iterable_binary_from_non_contiguous_memoryview(self):
self.loop.run_until_complete(
self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]])
)
self.assertFramesSent(
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
)

def test_send_empty_iterable(self):
self.loop.run_until_complete(self.protocol.send([]))
self.assertNoFrameSent()
Expand Down Expand Up @@ -697,16 +685,6 @@ def test_send_async_iterable_binary_from_memoryview(self):
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
)

def test_send_async_iterable_binary_from_non_contiguous_memoryview(self):
self.loop.run_until_complete(
self.protocol.send(
async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]])
)
)
self.assertFramesSent(
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
)

def test_send_empty_async_iterable(self):
self.loop.run_until_complete(self.protocol.send(async_iterable([])))
self.assertNoFrameSent()
Expand Down Expand Up @@ -799,10 +777,6 @@ def test_ping_binary_from_memoryview(self):
self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea")))
self.assertOneFrameSent(True, OP_PING, b"tea")

def test_ping_binary_from_non_contiguous_memoryview(self):
self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2]))
self.assertOneFrameSent(True, OP_PING, b"tea")

def test_ping_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.ping(42))
Expand Down Expand Up @@ -856,10 +830,6 @@ def test_pong_binary_from_memoryview(self):
self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea")))
self.assertOneFrameSent(True, OP_PONG, b"tea")

def test_pong_binary_from_non_contiguous_memoryview(self):
self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2]))
self.assertOneFrameSent(True, OP_PONG, b"tea")

def test_pong_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.pong(42))
Expand Down
9 changes: 0 additions & 9 deletions tests/test_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,6 @@ def test_prepare_data_memoryview(self):
(OP_BINARY, memoryview(b"tea")),
)

def test_prepare_data_non_contiguous_memoryview(self):
self.assertEqual(
prepare_data(memoryview(b"tteeaa")[::2]),
(OP_BINARY, b"tea"),
)

def test_prepare_data_list(self):
with self.assertRaises(TypeError):
prepare_data([])
Expand All @@ -246,9 +240,6 @@ def test_prepare_ctrl_bytearray(self):
def test_prepare_ctrl_memoryview(self):
self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea")

def test_prepare_ctrl_non_contiguous_memoryview(self):
self.assertEqual(prepare_ctrl(memoryview(b"tteeaa")[::2]), b"tea")

def test_prepare_ctrl_list(self):
with self.assertRaises(TypeError):
prepare_ctrl([])
Expand Down
24 changes: 4 additions & 20 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,18 @@ def test_apply_mask(self):
self.assertEqual(result, data_out)

def test_apply_mask_memoryview(self):
for data_type, mask_type in self.apply_mask_type_combos:
for mask_type in [bytes, bytearray]:
for data_in, mask, data_out in self.apply_mask_test_values:
data_in, mask = data_type(data_in), mask_type(mask)
data_in, mask = memoryview(data_in), memoryview(mask)
data_in, mask = memoryview(data_in), mask_type(mask)

with self.subTest(data_in=data_in, mask=mask):
result = self.apply_mask(data_in, mask)
self.assertEqual(result, data_out)

def test_apply_mask_non_contiguous_memoryview(self):
for data_type, mask_type in self.apply_mask_type_combos:
for mask_type in [bytes, bytearray]:
for data_in, mask, data_out in self.apply_mask_test_values:
data_in, mask = data_type(data_in), mask_type(mask)
data_in, mask = memoryview(data_in), memoryview(mask)
data_in, mask = data_in[::-1], mask[::-1]
data_in, mask = memoryview(data_in)[::-1], mask_type(mask)[::-1]
data_out = data_out[::-1]

with self.subTest(data_in=data_in, mask=mask):
Expand Down Expand Up @@ -92,16 +89,3 @@ class SpeedupsTests(ApplyMaskTests):
@staticmethod
def apply_mask(*args, **kwargs):
return c_apply_mask(*args, **kwargs)

def test_apply_mask_non_contiguous_memoryview(self):
for data_type, mask_type in self.apply_mask_type_combos:
for data_in, mask, data_out in self.apply_mask_test_values:
data_in, mask = data_type(data_in), mask_type(mask)
data_in, mask = memoryview(data_in), memoryview(mask)
data_in, mask = data_in[::-1], mask[::-1]
data_out = data_out[::-1]

with self.subTest(data_in=data_in, mask=mask):
# The C extension only supports contiguous memoryviews.
with self.assertRaises(TypeError):
self.apply_mask(data_in, mask)

0 comments on commit ecf64e7

Please sign in to comment.