Skip to content

Commit d46b101

Browse files
committed
pickling support
Signed-off-by: Onur Satici <[email protected]>
1 parent 47d1161 commit d46b101

File tree

11 files changed

+493
-2
lines changed

11 files changed

+493
-2
lines changed

Cargo.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ prost = "0.14"
156156
prost-build = "0.14"
157157
prost-types = "0.14"
158158
pyo3 = { version = "0.26.0" }
159+
pyo3-bytes = "0.4"
159160
pyo3-log = "0.13.0"
160161
rand = "0.9.0"
161162
rand_distr = "0.5"

vortex-python/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@ crate-type = ["rlib", "cdylib"]
2525
arrow-array = { workspace = true }
2626
arrow-data = { workspace = true }
2727
arrow-schema = { workspace = true }
28+
bytes = { workspace = true }
2829
itertools = { workspace = true }
2930
log = { workspace = true }
3031
mimalloc = { workspace = true }
3132
object_store = { workspace = true, features = ["aws", "gcp", "azure", "http"] }
3233
parking_lot = { workspace = true }
3334
pyo3 = { workspace = true, features = ["abi3", "abi3-py311"] }
35+
pyo3-bytes = { workspace = true }
3436
pyo3-log = { workspace = true }
3537
tokio = { workspace = true, features = ["fs", "rt-multi-thread"] }
3638
url = { workspace = true }
3739
vortex = { workspace = true, features = ["object_store", "python", "tokio"] }
40+
vortex-ipc = { workspace = true }

vortex-python/benchmark/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hashlib
55
import math
66
import os
7+
from typing import cast
78

89
import pyarrow as pa
910
import pytest
@@ -34,3 +35,9 @@ def vxf(tmpdir_factory: pytest.TempPathFactory, request: pytest.FixtureRequest)
3435
a = vx.array(pa.table(columns)) # pyright: ignore[reportCallIssue, reportUnknownArgumentType, reportArgumentType]
3536
vx.io.write(a, str(fname))
3637
return vx.open(str(fname))
38+
39+
40+
@pytest.fixture(scope="session", params=[10_000, 2_000_000], ids=["small", "large"])
41+
def array_fixture(request: pytest.FixtureRequest) -> vx.Array:
42+
size = cast(int, request.param)
43+
return vx.array(pa.table({"x": list(range(size))}))
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
import pickle
5+
6+
import pytest
7+
from pytest_benchmark.fixture import BenchmarkFixture # pyright: ignore[reportMissingTypeStubs]
8+
9+
import vortex as vx
10+
11+
12+
@pytest.mark.parametrize("protocol", [4, 5], ids=lambda p: f"p{p}") # pyright: ignore[reportAny]
13+
@pytest.mark.parametrize("operation", ["dumps", "loads", "roundtrip"])
14+
@pytest.mark.benchmark(disable_gc=True)
15+
def test_pickle(
16+
benchmark: BenchmarkFixture,
17+
array_fixture: vx.Array,
18+
protocol: int,
19+
operation: str,
20+
):
21+
benchmark.group = f"pickle_p{protocol}"
22+
23+
if operation == "dumps":
24+
benchmark(lambda: pickle.dumps(array_fixture, protocol=protocol))
25+
elif operation == "loads":
26+
pickled_data = pickle.dumps(array_fixture, protocol=protocol)
27+
benchmark(lambda: pickle.loads(pickled_data)) # pyright: ignore[reportAny]
28+
elif operation == "roundtrip":
29+
benchmark(lambda: pickle.loads(pickle.dumps(array_fixture, protocol=protocol))) # pyright: ignore[reportAny]

vortex-python/python/vortex/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@
7171
scalar,
7272
)
7373
from ._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
74-
from .arrays import Array, PyArray, array
74+
from .arrays import (
75+
Array,
76+
PyArray,
77+
_unpickle_array, # pyright: ignore[reportPrivateUsage]
78+
_unpickle_array_p5, # pyright: ignore[reportPrivateUsage]
79+
array,
80+
)
7581
from .file import VortexFile, open
7682
from .scan import RepeatedScan
7783

@@ -156,6 +162,9 @@
156162
"Registry",
157163
"ArrayContext",
158164
"ArrayParts",
165+
# Pickle
166+
"_unpickle_array",
167+
"_unpickle_array_p5",
159168
# File
160169
"VortexFile",
161170
"open",

vortex-python/python/vortex/_lib/serde.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from typing import final
55

66
import pyarrow as pa
77

8+
from .arrays import Array
89
from .dtype import DType
910

1011
@final
@@ -26,3 +27,8 @@ class ArrayParts:
2627
@final
2728
class ArrayContext:
2829
def __len__(self) -> int: ...
30+
31+
def decode_ipc_array(array_bytes: bytes, dtype_bytes: bytes) -> Array: ...
32+
def decode_ipc_array_buffers(
33+
array_buffers: list[bytes | memoryview], dtype_buffers: list[bytes | memoryview]
34+
) -> Array: ...

vortex-python/python/vortex/arrays.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
import vortex._lib.arrays as _arrays # pyright: ignore[reportMissingModuleSource]
1313
from vortex._lib.dtype import DType # pyright: ignore[reportMissingModuleSource]
14-
from vortex._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
14+
from vortex._lib.serde import ( # pyright: ignore[reportMissingModuleSource]
15+
ArrayContext,
16+
ArrayParts,
17+
decode_ipc_array,
18+
decode_ipc_array_buffers,
19+
)
1520

1621
try:
1722
import pandas
@@ -466,3 +471,21 @@ def decode(cls, parts: ArrayParts, ctx: ArrayContext, dtype: DType, len: int) ->
466471
current array. Implementations of this function should validate this information, and then construct
467472
a new array.
468473
"""
474+
475+
476+
def _unpickle_array(array_bytes: bytes, dtype_bytes: bytes) -> Array: # pyright: ignore[reportUnusedFunction]
477+
"""Unpickle a Vortex array from IPC-encoded bytes.
478+
479+
This is an internal function used by the pickle module.
480+
"""
481+
return decode_ipc_array(array_bytes, dtype_bytes)
482+
483+
484+
def _unpickle_array_p5(array_buffers: list[bytes | memoryview], dtype_buffers: list[bytes | memoryview]) -> Array: # pyright: ignore[reportUnusedFunction]
485+
"""Unpickle a Vortex array from out-of-band PickleBuffers.
486+
487+
This is an internal function used by the pickle module. When
488+
pickle protocol 5 or later is used, this function is called on unpickle,
489+
saving one extra copy operation for array buffers.
490+
"""
491+
return decode_ipc_array_buffers(array_buffers, dtype_buffers)

vortex-python/src/arrays/mod.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@ pub(crate) mod py;
1010
mod range_to_sequence;
1111

1212
use arrow_array::{Array as ArrowArray, ArrayRef as ArrowArrayRef};
13+
use pyo3::IntoPyObjectExt;
1314
use pyo3::exceptions::{PyTypeError, PyValueError};
1415
use pyo3::prelude::*;
1516
use pyo3::types::{PyDict, PyList, PyRange, PyRangeMethods};
17+
use pyo3_bytes::PyBytes;
1618
use vortex::arrays::ChunkedVTable;
1719
use vortex::arrow::IntoArrowArray;
1820
use vortex::compute::{Operator, compare, take};
1921
use vortex::dtype::{DType, Nullability, PType, match_each_integer_ptype};
2022
use vortex::error::VortexError;
2123
use vortex::{Array, ArrayRef, ToCanonical};
24+
use vortex_ipc::messages::{EncoderMessage, MessageEncoder};
2225

2326
use crate::arrays::native::PyNativeArray;
2427
use crate::arrays::py::{PyPythonArray, PythonArray};
@@ -653,4 +656,82 @@ impl PyArray {
653656
.map(|buffer| buffer.to_vec())
654657
.collect())
655658
}
659+
660+
/// Support for Python's pickle protocol.
661+
///
662+
/// This method serializes the array using Vortex IPC format and returns
663+
/// the data needed for pickle to reconstruct the array.
664+
fn __reduce__<'py>(
665+
slf: &'py Bound<'py, Self>,
666+
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
667+
let py = slf.py();
668+
let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();
669+
670+
let mut encoder = MessageEncoder::default();
671+
let buffers = encoder.encode(EncoderMessage::Array(&*array));
672+
673+
// Concatenate the encoded array buffers into a single contiguous byte vector.
674+
let mut serialized = Vec::new();
675+
for buf in buffers.iter() {
676+
serialized.extend_from_slice(buf);
677+
}
678+
679+
let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype()));
680+
let mut dtype_bytes = Vec::new();
681+
for buf in dtype_buffers.iter() {
682+
dtype_bytes.extend_from_slice(buf);
683+
}
684+
685+
let vortex_module = PyModule::import(py, "vortex")?;
686+
let unpickle_fn = vortex_module.getattr("_unpickle_array")?;
687+
688+
let args = (serialized, dtype_bytes).into_pyobject(py)?;
689+
Ok((unpickle_fn, args.into_any()))
690+
}
691+
692+
/// Support for Python's pickle protocol with protocol version awareness.
693+
///
694+
/// When protocol >= 5, this uses PickleBuffer for out-of-band buffer transfer,
695+
/// which avoids copying large data buffers.
696+
fn __reduce_ex__<'py>(
697+
slf: &'py Bound<'py, Self>,
698+
protocol: i32,
699+
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
700+
let py = slf.py();
701+
702+
if protocol < 5 {
703+
return Self::__reduce__(slf);
704+
}
705+
706+
let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();
707+
708+
let mut encoder = MessageEncoder::default();
709+
let array_buffers = encoder.encode(EncoderMessage::Array(&*array));
710+
let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype()));
711+
712+
let pickle_module = PyModule::import(py, "pickle")?;
713+
let pickle_buffer_class = pickle_module.getattr("PickleBuffer")?;
714+
715+
let mut pickle_buffers = Vec::new();
716+
for buf in array_buffers.into_iter() {
717+
// PyBytes wraps bytes::Bytes and implements the buffer protocol
718+
// This allows PickleBuffer to reference the data without copying
719+
let py_bytes = PyBytes::new(buf).into_py_any(py)?;
720+
let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?;
721+
pickle_buffers.push(pickle_buffer);
722+
}
723+
724+
let mut dtype_pickle_buffers = Vec::new();
725+
for buf in dtype_buffers.into_iter() {
726+
let py_bytes = PyBytes::new(buf).into_py_any(py)?;
727+
let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?;
728+
dtype_pickle_buffers.push(pickle_buffer);
729+
}
730+
731+
let vortex_module = PyModule::import(py, "vortex")?;
732+
let unpickle_fn = vortex_module.getattr("_unpickle_array_p5")?;
733+
734+
let args = (pickle_buffers, dtype_pickle_buffers).into_pyobject(py)?;
735+
Ok((unpickle_fn, args.into_any()))
736+
}
656737
}

0 commit comments

Comments
 (0)