@@ -10,15 +10,18 @@ pub(crate) mod py;
1010mod range_to_sequence;
1111
1212use arrow_array:: { Array as ArrowArray , ArrayRef as ArrowArrayRef } ;
13+ use pyo3:: IntoPyObjectExt ;
1314use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
1415use pyo3:: prelude:: * ;
1516use pyo3:: types:: { PyDict , PyList , PyRange , PyRangeMethods } ;
17+ use pyo3_bytes:: PyBytes ;
1618use vortex:: arrays:: ChunkedVTable ;
1719use vortex:: arrow:: IntoArrowArray ;
1820use vortex:: compute:: { Operator , compare, take} ;
1921use vortex:: dtype:: { DType , Nullability , PType , match_each_integer_ptype} ;
2022use vortex:: error:: VortexError ;
2123use vortex:: { Array , ArrayRef , ToCanonical } ;
24+ use vortex_ipc:: messages:: { EncoderMessage , MessageEncoder } ;
2225
2326use crate :: arrays:: native:: PyNativeArray ;
2427use crate :: arrays:: py:: { PyPythonArray , PythonArray } ;
@@ -653,4 +656,82 @@ impl PyArray {
653656 . map ( |buffer| buffer. to_vec ( ) )
654657 . collect ( ) )
655658 }
659+
660+ /// Support for Python's pickle protocol.
661+ ///
662+ /// This method serializes the array using Vortex IPC format and returns
663+ /// the data needed for pickle to reconstruct the array.
664+ fn __reduce__ < ' py > (
665+ slf : & ' py Bound < ' py , Self > ,
666+ ) -> PyResult < ( Bound < ' py , PyAny > , Bound < ' py , PyAny > ) > {
667+ let py = slf. py ( ) ;
668+ let array = PyArrayRef :: extract_bound ( slf. as_any ( ) ) ?. into_inner ( ) ;
669+
670+ let mut encoder = MessageEncoder :: default ( ) ;
671+ let buffers = encoder. encode ( EncoderMessage :: Array ( & * array) ) ;
672+
673+ // concat all buffers
674+ let mut serialized = Vec :: new ( ) ;
675+ for buf in buffers. iter ( ) {
676+ serialized. extend_from_slice ( buf) ;
677+ }
678+
679+ let dtype_buffers = encoder. encode ( EncoderMessage :: DType ( array. dtype ( ) ) ) ;
680+ let mut dtype_bytes = Vec :: new ( ) ;
681+ for buf in dtype_buffers. iter ( ) {
682+ dtype_bytes. extend_from_slice ( buf) ;
683+ }
684+
685+ let vortex_module = PyModule :: import ( py, "vortex" ) ?;
686+ let unpickle_fn = vortex_module. getattr ( "_unpickle_array" ) ?;
687+
688+ let args = ( serialized, dtype_bytes) . into_pyobject ( py) ?;
689+ Ok ( ( unpickle_fn, args. into_any ( ) ) )
690+ }
691+
692+ /// Support for Python's pickle protocol with protocol version awareness.
693+ ///
694+ /// When protocol >= 5, this uses PickleBuffer for out-of-band buffer transfer,
695+ /// which avoids copying large data buffers.
696+ fn __reduce_ex__ < ' py > (
697+ slf : & ' py Bound < ' py , Self > ,
698+ protocol : i32 ,
699+ ) -> PyResult < ( Bound < ' py , PyAny > , Bound < ' py , PyAny > ) > {
700+ let py = slf. py ( ) ;
701+
702+ if protocol < 5 {
703+ return Self :: __reduce__ ( slf) ;
704+ }
705+
706+ let array = PyArrayRef :: extract_bound ( slf. as_any ( ) ) ?. into_inner ( ) ;
707+
708+ let mut encoder = MessageEncoder :: default ( ) ;
709+ let array_buffers = encoder. encode ( EncoderMessage :: Array ( & * array) ) ;
710+ let dtype_buffers = encoder. encode ( EncoderMessage :: DType ( array. dtype ( ) ) ) ;
711+
712+ let pickle_module = PyModule :: import ( py, "pickle" ) ?;
713+ let pickle_buffer_class = pickle_module. getattr ( "PickleBuffer" ) ?;
714+
715+ let mut pickle_buffers = Vec :: new ( ) ;
716+ for buf in array_buffers. into_iter ( ) {
717+ // PyBytes wraps bytes::Bytes and implements the buffer protocol
718+ // This allows PickleBuffer to reference the data without copying
719+ let py_bytes = PyBytes :: new ( buf) . into_py_any ( py) ?;
720+ let pickle_buffer = pickle_buffer_class. call1 ( ( py_bytes, ) ) ?;
721+ pickle_buffers. push ( pickle_buffer) ;
722+ }
723+
724+ let mut dtype_pickle_buffers = Vec :: new ( ) ;
725+ for buf in dtype_buffers. into_iter ( ) {
726+ let py_bytes = PyBytes :: new ( buf) . into_py_any ( py) ?;
727+ let pickle_buffer = pickle_buffer_class. call1 ( ( py_bytes, ) ) ?;
728+ dtype_pickle_buffers. push ( pickle_buffer) ;
729+ }
730+
731+ let vortex_module = PyModule :: import ( py, "vortex" ) ?;
732+ let unpickle_fn = vortex_module. getattr ( "_unpickle_array_p5" ) ?;
733+
734+ let args = ( pickle_buffers, dtype_pickle_buffers) . into_pyobject ( py) ?;
735+ Ok ( ( unpickle_fn, args. into_any ( ) ) )
736+ }
656737}
0 commit comments