Skip to content

Commit ecf64e7

Browse files
committed
Handle non-contiguous memoryviews in C extension.
This avoids the special-case in Python code.
1 parent 965f8ec commit ecf64e7

File tree

5 files changed

+32
-93
lines changed

5 files changed

+32
-93
lines changed

src/websockets/frames.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]:
263263
"""
264264
if isinstance(data, str):
265265
return OP_TEXT, data.encode("utf-8")
266-
elif isinstance(data, (bytes, bytearray)):
266+
elif isinstance(data, (bytes, bytearray, memoryview)):
267267
return OP_BINARY, data
268-
elif isinstance(data, memoryview):
269-
if data.c_contiguous:
270-
return OP_BINARY, data
271-
else:
272-
return OP_BINARY, data.tobytes()
273268
else:
274269
raise TypeError("data must be bytes-like or str")
275270

@@ -290,10 +285,8 @@ def prepare_ctrl(data: Data) -> bytes:
290285
"""
291286
if isinstance(data, str):
292287
return data.encode("utf-8")
293-
elif isinstance(data, (bytes, bytearray)):
288+
elif isinstance(data, (bytes, bytearray, memoryview)):
294289
return bytes(data)
295-
elif isinstance(data, memoryview):
296-
return data.tobytes()
297290
else:
298291
raise TypeError("data must be bytes-like or str")
299292

src/websockets/speedups.c

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,35 @@ static const Py_ssize_t MASK_LEN = 4;
1313
/* Similar to PyBytes_AsStringAndSize, but accepts more types */
1414

1515
static int
16-
_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length)
16+
_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length)
1717
{
18-
// This supports bytes, bytearrays, and C-contiguous memoryview objects,
19-
// which are the most useful data structures for handling byte streams.
20-
// websockets.framing.prepare_data() returns only values of these types.
21-
// Any object implementing the buffer protocol could be supported, however
22-
// that would require allocation or copying memory, which is expensive.
18+
// This supports bytes, bytearrays, and memoryview objects,
19+
// which are common data structures for handling byte streams.
20+
// websockets.framing.prepare_data() returns only these types.
21+
// If *tmp isn't NULL, the caller gets a new reference.
2322
if (PyBytes_Check(obj))
2423
{
24+
*tmp = NULL;
2525
*buffer = PyBytes_AS_STRING(obj);
2626
*length = PyBytes_GET_SIZE(obj);
2727
}
2828
else if (PyByteArray_Check(obj))
2929
{
30+
*tmp = NULL;
3031
*buffer = PyByteArray_AS_STRING(obj);
3132
*length = PyByteArray_GET_SIZE(obj);
3233
}
3334
else if (PyMemoryView_Check(obj))
3435
{
35-
Py_buffer *mv_buf;
36-
mv_buf = PyMemoryView_GET_BUFFER(obj);
37-
if (PyBuffer_IsContiguous(mv_buf, 'C'))
38-
{
39-
*buffer = mv_buf->buf;
40-
*length = mv_buf->len;
41-
}
42-
else
36+
*tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
37+
if (*tmp == NULL)
4338
{
44-
PyErr_Format(
45-
PyExc_TypeError,
46-
"expected a contiguous memoryview");
4739
return -1;
4840
}
41+
Py_buffer *mv_buf;
42+
mv_buf = PyMemoryView_GET_BUFFER(*tmp);
43+
*buffer = mv_buf->buf;
44+
*length = mv_buf->len;
4945
}
5046
else
5147
{
@@ -74,15 +70,17 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
7470
// A pointer to a char * + length will be extracted from the data and mask
7571
// arguments, possibly via a Py_buffer.
7672

73+
PyObject *input_tmp = NULL;
7774
char *input;
7875
Py_ssize_t input_len;
76+
PyObject *mask_tmp = NULL;
7977
char *mask;
8078
Py_ssize_t mask_len;
8179

8280
// Initialize a PyBytesObject then get a pointer to the underlying char *
8381
// in order to avoid an extra memory copy in PyBytes_FromStringAndSize.
8482

85-
PyObject *result;
83+
PyObject *result = NULL;
8684
char *output;
8785

8886
// Other variables.
@@ -94,31 +92,31 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
9492
if (!PyArg_ParseTupleAndKeywords(
9593
args, kwds, "OO", kwlist, &input_obj, &mask_obj))
9694
{
97-
return NULL;
95+
goto exit;
9896
}
9997

100-
if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1)
98+
if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1)
10199
{
102-
return NULL;
100+
goto exit;
103101
}
104102

105-
if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1)
103+
if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1)
106104
{
107-
return NULL;
105+
goto exit;
108106
}
109107

110108
if (mask_len != MASK_LEN)
111109
{
112110
PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes");
113-
return NULL;
111+
goto exit;
114112
}
115113

116114
// Create output.
117115

118116
result = PyBytes_FromStringAndSize(NULL, input_len);
119117
if (result == NULL)
120118
{
121-
return NULL;
119+
goto exit;
122120
}
123121

124122
// Since we juste created result, we don't need error checks.
@@ -172,6 +170,9 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
172170
output[i] = input[i] ^ mask[i & (MASK_LEN - 1)];
173171
}
174172

173+
exit:
174+
Py_XDECREF(input_tmp);
175+
Py_XDECREF(mask_tmp);
175176
return result;
176177

177178
}

tests/legacy/test_protocol.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,6 @@ def test_send_binary_from_memoryview(self):
580580
self.loop.run_until_complete(self.protocol.send(memoryview(b"tea")))
581581
self.assertOneFrameSent(True, OP_BINARY, b"tea")
582582

583-
def test_send_binary_from_non_contiguous_memoryview(self):
584-
self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2]))
585-
self.assertOneFrameSent(True, OP_BINARY, b"tea")
586-
587583
def test_send_dict(self):
588584
with self.assertRaises(TypeError):
589585
self.loop.run_until_complete(self.protocol.send({"not": "encoded"}))
@@ -624,14 +620,6 @@ def test_send_iterable_binary_from_memoryview(self):
624620
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
625621
)
626622

627-
def test_send_iterable_binary_from_non_contiguous_memoryview(self):
628-
self.loop.run_until_complete(
629-
self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]])
630-
)
631-
self.assertFramesSent(
632-
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
633-
)
634-
635623
def test_send_empty_iterable(self):
636624
self.loop.run_until_complete(self.protocol.send([]))
637625
self.assertNoFrameSent()
@@ -697,16 +685,6 @@ def test_send_async_iterable_binary_from_memoryview(self):
697685
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
698686
)
699687

700-
def test_send_async_iterable_binary_from_non_contiguous_memoryview(self):
701-
self.loop.run_until_complete(
702-
self.protocol.send(
703-
async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]])
704-
)
705-
)
706-
self.assertFramesSent(
707-
(False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"")
708-
)
709-
710688
def test_send_empty_async_iterable(self):
711689
self.loop.run_until_complete(self.protocol.send(async_iterable([])))
712690
self.assertNoFrameSent()
@@ -799,10 +777,6 @@ def test_ping_binary_from_memoryview(self):
799777
self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea")))
800778
self.assertOneFrameSent(True, OP_PING, b"tea")
801779

802-
def test_ping_binary_from_non_contiguous_memoryview(self):
803-
self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2]))
804-
self.assertOneFrameSent(True, OP_PING, b"tea")
805-
806780
def test_ping_type_error(self):
807781
with self.assertRaises(TypeError):
808782
self.loop.run_until_complete(self.protocol.ping(42))
@@ -856,10 +830,6 @@ def test_pong_binary_from_memoryview(self):
856830
self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea")))
857831
self.assertOneFrameSent(True, OP_PONG, b"tea")
858832

859-
def test_pong_binary_from_non_contiguous_memoryview(self):
860-
self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2]))
861-
self.assertOneFrameSent(True, OP_PONG, b"tea")
862-
863833
def test_pong_type_error(self):
864834
with self.assertRaises(TypeError):
865835
self.loop.run_until_complete(self.protocol.pong(42))

tests/test_frames.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,6 @@ def test_prepare_data_memoryview(self):
218218
(OP_BINARY, memoryview(b"tea")),
219219
)
220220

221-
def test_prepare_data_non_contiguous_memoryview(self):
222-
self.assertEqual(
223-
prepare_data(memoryview(b"tteeaa")[::2]),
224-
(OP_BINARY, b"tea"),
225-
)
226-
227221
def test_prepare_data_list(self):
228222
with self.assertRaises(TypeError):
229223
prepare_data([])
@@ -246,9 +240,6 @@ def test_prepare_ctrl_bytearray(self):
246240
def test_prepare_ctrl_memoryview(self):
247241
self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea")
248242

249-
def test_prepare_ctrl_non_contiguous_memoryview(self):
250-
self.assertEqual(prepare_ctrl(memoryview(b"tteeaa")[::2]), b"tea")
251-
252243
def test_prepare_ctrl_list(self):
253244
with self.assertRaises(TypeError):
254245
prepare_ctrl([])

tests/test_utils.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,18 @@ def test_apply_mask(self):
4343
self.assertEqual(result, data_out)
4444

4545
def test_apply_mask_memoryview(self):
46-
for data_type, mask_type in self.apply_mask_type_combos:
46+
for mask_type in [bytes, bytearray]:
4747
for data_in, mask, data_out in self.apply_mask_test_values:
48-
data_in, mask = data_type(data_in), mask_type(mask)
49-
data_in, mask = memoryview(data_in), memoryview(mask)
48+
data_in, mask = memoryview(data_in), mask_type(mask)
5049

5150
with self.subTest(data_in=data_in, mask=mask):
5251
result = self.apply_mask(data_in, mask)
5352
self.assertEqual(result, data_out)
5453

5554
def test_apply_mask_non_contiguous_memoryview(self):
56-
for data_type, mask_type in self.apply_mask_type_combos:
55+
for mask_type in [bytes, bytearray]:
5756
for data_in, mask, data_out in self.apply_mask_test_values:
58-
data_in, mask = data_type(data_in), mask_type(mask)
59-
data_in, mask = memoryview(data_in), memoryview(mask)
60-
data_in, mask = data_in[::-1], mask[::-1]
57+
data_in, mask = memoryview(data_in)[::-1], mask_type(mask)[::-1]
6158
data_out = data_out[::-1]
6259

6360
with self.subTest(data_in=data_in, mask=mask):
@@ -92,16 +89,3 @@ class SpeedupsTests(ApplyMaskTests):
9289
@staticmethod
9390
def apply_mask(*args, **kwargs):
9491
return c_apply_mask(*args, **kwargs)
95-
96-
def test_apply_mask_non_contiguous_memoryview(self):
97-
for data_type, mask_type in self.apply_mask_type_combos:
98-
for data_in, mask, data_out in self.apply_mask_test_values:
99-
data_in, mask = data_type(data_in), mask_type(mask)
100-
data_in, mask = memoryview(data_in), memoryview(mask)
101-
data_in, mask = data_in[::-1], mask[::-1]
102-
data_out = data_out[::-1]
103-
104-
with self.subTest(data_in=data_in, mask=mask):
105-
# The C extension only supports contiguous memoryviews.
106-
with self.assertRaises(TypeError):
107-
self.apply_mask(data_in, mask)

0 commit comments

Comments
 (0)