Skip to content

Commit

Permalink
Add support to decompress zstd (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus authored Feb 9, 2025
1 parent 642b5dc commit fc74a93
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 23 deletions.
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ log = "0.4.25"
pyo3 = "0.23.3"
pyo3-log = "0.12.1"
svg = "0.18.0"
zstd = "0.13.2"

[dev-dependencies]
rstest = "0.24.0"
6 changes: 3 additions & 3 deletions deebot_client/commands/json/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from deebot_client.events.map import CachedMapInfoEvent
from deebot_client.logging_filter import get_logger
from deebot_client.message import HandlingResult, HandlingState, MessageBodyDataDict
from deebot_client.rs.util import decompress_7z_base64_data
from deebot_client.rs.util import decompress_base64_data

from .common import JsonCommandWithMessageHandling

Expand Down Expand Up @@ -275,7 +275,7 @@ def _handle_body_data_dict(
# This command is used by new and old bots
if data.get("compress", 0) == 1:
# Newer bot's return coordinates as base64 decoded string
coordinates = decompress_7z_base64_data(data["value"]).decode()
coordinates = decompress_base64_data(data["value"]).decode()
else:
# Older bot's return coordinates direct as comma/semicolon separated list
coordinates = data["value"]
Expand Down Expand Up @@ -305,7 +305,7 @@ def _get_subset_ids(
) -> list[int] | None:
"""Return subset ids."""
# subset is based64 7z compressed
subsets = json.loads(decompress_7z_base64_data(data["subsets"]).decode())
subsets = json.loads(decompress_base64_data(data["subsets"]).decode())

match data["type"]:
case MapSetType.ROOMS:
Expand Down
4 changes: 2 additions & 2 deletions deebot_client/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .logging_filter import get_logger
from .models import Room
from .rs.map import MapData as MapDataRs
from .rs.util import decompress_7z_base64_data
from .rs.util import decompress_base64_data
from .util import (
OnChangedDict,
OnChangedList,
Expand Down Expand Up @@ -311,7 +311,7 @@ def image(self) -> Image.Image:

def update_points(self, base64_data: str) -> None:
"""Add map piece points."""
decoded = decompress_7z_base64_data(base64_data)
decoded = decompress_base64_data(base64_data)
old_crc32 = self._crc32
self._crc32 = zlib.crc32(decoded)

Expand Down
4 changes: 2 additions & 2 deletions deebot_client/rs/util.pyi
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
def decompress_7z_base64_data(value: str) -> bytes:
"""Decompress base64 decoded 7z compressed string."""
def decompress_base64_data(value: str) -> bytes:
"""Decompress base64 decoded 7z compressed string by using lzma or zstd."""
4 changes: 2 additions & 2 deletions src/map.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::error::Error;
use std::io::Cursor;

use super::util::decompress_7z_base64_data;
use super::util::decompress_base64_data;
use base64::engine::general_purpose;
use base64::Engine;
use byteorder::{LittleEndian, ReadBytesExt};
Expand Down Expand Up @@ -41,7 +41,7 @@ fn process_trace_points(trace_points: &[u8]) -> Result<Vec<TracePoint>, Box<dyn
}

fn extract_trace_points(value: String) -> Result<Vec<TracePoint>, Box<dyn Error>> {
let decompressed_data = decompress_7z_base64_data(value)?;
let decompressed_data = decompress_base64_data(value)?;
process_trace_points(&decompressed_data)
}

Expand Down
36 changes: 29 additions & 7 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ use liblzma::stream::Stream;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

pub fn decompress_7z_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error>> {
let mut bytes = general_purpose::STANDARD.decode(value)?;
pub fn decompress_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error>> {
let bytes = general_purpose::STANDARD.decode(value)?;

if is_zstd_compressed(&bytes) {
decompress_zstd(&bytes)
} else {
decompress_lzma(bytes)
}
}

fn decompress_lzma(mut bytes: Vec<u8>) -> Result<Vec<u8>, Box<dyn Error>> {
if bytes.len() < 8 {
return Err("Invalid 7z compressed data".into());
}
Expand All @@ -27,13 +35,27 @@ pub fn decompress_7z_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error
Ok(result)
}

/// Decompress base64 decoded 7z compressed string.
#[pyfunction(name = "decompress_7z_base64_data")]
fn python_decompress_7z_base64_data(value: String) -> Result<Vec<u8>, PyErr> {
decompress_7z_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))
fn decompress_zstd(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let mut decoder = zstd::Decoder::new(bytes)?;
let mut result = Vec::new();
decoder.read_to_end(&mut result)?;

Ok(result)
}

fn is_zstd_compressed(bytes: &[u8]) -> bool {
// Implement a check to determine if the data is zstd-compressed
// That the data starts with the magic bytes of zstd-compressed data
bytes.starts_with(&[0x28, 0xb5, 0x2f, 0xfd])
}

/// Decompress base64 decoded compressed string by using lzma or zstd
#[pyfunction(name = "decompress_base64_data")]
fn python_decompress_base64_data(value: String) -> Result<Vec<u8>, PyErr> {
decompress_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))
}

pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(python_decompress_7z_base64_data, m)?)?;
m.add_function(wrap_pyfunction!(python_decompress_base64_data, m)?)?;
Ok(())
}
33 changes: 26 additions & 7 deletions tests/rs/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from deebot_client.rs.util import decompress_7z_base64_data
from deebot_client.rs.util import decompress_base64_data

if TYPE_CHECKING:
from pytest_codspeed import BenchmarkFixture
Expand Down Expand Up @@ -36,18 +36,37 @@
],
ids=["1", "2", "3", "4"],
)
def test_decompress_7z_base64_data(
def test_decompress_base64_data_lzma(
benchmark: BenchmarkFixture, input: str, expected: bytes
) -> None:
"""Test decompress_7z_base64_data function."""
"""Test decompress_base64_data function with lzma base64 values."""
# Benchmark only the production function
result = benchmark(decompress_7z_base64_data, input)
result = benchmark(decompress_base64_data, input)
assert result == expected

# Verify that the old python function is producing the same result
assert _decompress_7z_base64_data_python(input) == result


@pytest.mark.parametrize(
("input", "expected"),
[
(
"KLUv/SB//QEAMgQKDKClbQC+WNsvI/5vYPMSO6jz8h7OwN2BYlTHRR2DYgSeurlRRyp2UAgALXwANbAWWqAuACQBKiDgFiUJ",
b"-624,-774;-524,-774;-474,-724;-424,-724;-374,-674;-124,-674;-24,-774;-74,-824;2325,-824;2375,-774;2425,-774;2425,1225;-624,1225",
),
],
ids=["1"],
)
def test_decompress_base64_data_zstd(
benchmark: BenchmarkFixture, input: str, expected: bytes
) -> None:
"""Test decompress_base64_data function with zstd base64 values."""
# Benchmark only the production function
result = benchmark(decompress_base64_data, input)
assert result == expected


@pytest.mark.parametrize(
("input", "expected_error"),
[
Expand All @@ -65,10 +84,10 @@ def test_decompress_7z_base64_data(
),
],
)
def test_decompress_7z_base64_data_errors(input: str, expected_error: str) -> None:
"""Test decompress_7z_base64_data function."""
def test_decompress_base64_data_errors(input: str, expected_error: str) -> None:
"""Test decompress_base64_data function."""
with pytest.raises(ValueError, match=expected_error):
assert decompress_7z_base64_data(input)
assert decompress_base64_data(input)


def _decompress_7z_base64_data_python(data: str) -> bytes:
Expand Down

0 comments on commit fc74a93

Please sign in to comment.