Skip to content

Commit

Permalink
Avoid making a copy of large frames.
Browse files Browse the repository at this point in the history
This isn't very significant compared to the cost of compression.

It can make a real difference for decompression.
  • Loading branch information
aaugustin committed Sep 27, 2024
1 parent baadc33 commit 524dd4a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 14 deletions.
14 changes: 14 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ Backwards-incompatible changes
Aliases for deprecated API were removed from ``__all__``. As a consequence,
they cannot be imported e.g. with ``from websockets import *`` anymore.

.. admonition:: :attr:`Frame.data <frames.Frame.data>` is now a bytes-like object.
:class: note

In addition to :class:`bytes`, it may be a :class:`bytearray` or a
:class:`memoryview`.

If you wrote an :class:`extension <extensions.Extension>` that relies on
methods not provided by these new types, you may need to update your code.

Improvements
............

* Sending or receiving large compressed frames is now faster.

.. _13.1:

13.1
Expand Down
30 changes: 20 additions & 10 deletions src/websockets/extensions/permessage_deflate.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,22 @@ def decode(
# Uncompress data. Protect against zip bombs by preventing zlib from
# decompressing more than max_length bytes (except when the limit is
# disabled with max_size = None).
data = frame.data
if frame.fin:
data += _EMPTY_UNCOMPRESSED_BLOCK
if frame.fin and len(frame.data) < 2044:
# Profiling shows that appending four bytes, which makes a copy, is
# faster than calling decompress() again when data is less than 2kB.
data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK
else:
data = frame.data
max_length = 0 if max_size is None else max_size
try:
data = self.decoder.decompress(data, max_length)
if self.decoder.unconsumed_tail:
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
if frame.fin and len(frame.data) >= 2044:
# This cannot generate additional data.
self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK)
except zlib.error as exc:
raise ProtocolError("decompression failed") from exc
if self.decoder.unconsumed_tail:
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")

# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
Expand Down Expand Up @@ -176,11 +182,15 @@ def encode(self, frame: frames.Frame) -> frames.Frame:

# Compress data.
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK:
# Making a copy is faster than memoryview(a)[:-4] until about 2kB.
# On larger messages, it's slower but profiling shows that it's
# marginal compared to compress() and flush(). Keep it simple.
data = data[:-4]
if frame.fin:
# Sync flush generates between 5 or 6 bytes, ending with the bytes
# 0x00 0x00 0xff 0xff, which must be removed.
assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK
# Making a copy is faster than memoryview(a)[:-4] until 2kB.
if len(data) < 2048:
data = data[:-4]
else:
data = memoryview(data)[:-4]

# Allow garbage collection of the encoder if it won't be reused.
if frame.fin and self.local_no_context_takeover:
Expand Down
8 changes: 4 additions & 4 deletions src/websockets/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import secrets
import struct
from collections.abc import Generator, Sequence
from typing import Callable
from typing import Callable, Union

from .exceptions import PayloadTooBig, ProtocolError

Expand Down Expand Up @@ -139,7 +139,7 @@ class Frame:
"""

opcode: Opcode
data: bytes
data: Union[bytes, bytearray, memoryview]
fin: bool = True
rsv1: bool = False
rsv2: bool = False
Expand All @@ -160,7 +160,7 @@ def __str__(self) -> str:
if self.opcode is OP_TEXT:
# Decoding only the beginning and the end is needlessly hard.
# Decode the entire payload then elide later if necessary.
data = repr(self.data.decode())
data = repr(bytes(self.data).decode())
elif self.opcode is OP_BINARY:
# We'll show at most the first 16 bytes and the last 8 bytes.
# Encode just what we need, plus two dummy bytes to elide later.
Expand All @@ -178,7 +178,7 @@ def __str__(self) -> str:
# binary. If self.data is a memoryview, it has no decode() method,
# which raises AttributeError.
try:
data = repr(self.data.decode())
data = repr(bytes(self.data).decode())
coding = "text"
except (UnicodeDecodeError, AttributeError):
binary = self.data
Expand Down
11 changes: 11 additions & 0 deletions tests/extensions/test_permessage_deflate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import os
import unittest

from websockets.exceptions import (
Expand Down Expand Up @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self):
self.assertEqual(dec_frame1, frame1)
self.assertEqual(dec_frame2, frame2)

def test_encode_decode_large_frame(self):
# There is a separate code path that avoids copying data
# when frames are larger than 2kB. Test it for coverage.
frame = Frame(OP_BINARY, os.urandom(4096))

enc_frame = self.extension.encode(frame)
dec_frame = self.extension.decode(enc_frame)

self.assertEqual(dec_frame, frame)

def test_no_decode_text_frame(self):
frame = Frame(OP_TEXT, "café".encode())

Expand Down

0 comments on commit 524dd4a

Please sign in to comment.