@@ -158,16 +158,18 @@ def _encode_tensor(
158
158
self , obj : torch .Tensor
159
159
) -> tuple [str , tuple [int , ...], Union [int , memoryview ]]:
160
160
assert self .aux_buffers is not None
161
+ # this creates a copy of the tensor if it's not already contiguous
162
+ obj = obj .contiguous ()
161
163
# view the tensor as a 1D array of bytes
162
- arr = obj .flatten ( ).view (torch .uint8 ).numpy ()
164
+ arr = obj .view (( obj . numel (), ) ).view (torch .uint8 ).numpy ()
163
165
if obj .nbytes < self .size_threshold :
164
166
# Smaller tensors are encoded inline, just like ndarrays.
165
167
data = msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr .data )
166
168
else :
167
169
# Otherwise encode index of backing buffer to avoid copy.
168
170
data = len (self .aux_buffers )
169
171
self .aux_buffers .append (arr .data )
170
- dtype = str (obj .dtype ). removeprefix ( " torch." )
172
+ dtype = str (obj .dtype )[ 6 :] # remove ' torch.' prefix
171
173
return dtype , obj .shape , data
172
174
173
175
def _encode_nested_tensors (self , nt : NestedTensors ) -> Any :
@@ -243,7 +245,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
243
245
# zero-copy decode. We assume the ndarray will not be kept around,
244
246
# as it now locks the whole received message buffer in memory.
245
247
buffer = self .aux_buffers [data ] if isinstance (data , int ) else data
246
- return np .frombuffer (buffer , dtype = dtype ). reshape ( shape )
248
+ return np .ndarray (buffer = buffer , dtype = np . dtype ( dtype ), shape = shape )
247
249
248
250
def _decode_tensor (self , arr : Any ) -> torch .Tensor :
249
251
dtype , shape , data = arr
@@ -252,15 +254,12 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
252
254
# not complain about a readonly memoryview.
253
255
buffer = self .aux_buffers [data ] if isinstance (data , int ) \
254
256
else bytearray (data )
257
+ # Create numpy wrapper around the bytes
258
+ arr = np .ndarray (buffer = buffer , dtype = np .uint8 , shape = (len (buffer ), ))
255
259
torch_dtype = getattr (torch , dtype )
256
260
assert isinstance (torch_dtype , torch .dtype )
257
- if not buffer : # torch.frombuffer doesn't like empty buffers
258
- assert 0 in shape
259
- return torch .empty (shape , dtype = torch_dtype )
260
- # Create uint8 array
261
- arr = torch .frombuffer (buffer , dtype = torch .uint8 )
262
261
# Convert back to proper shape & type
263
- return arr .view (torch_dtype ).view (shape )
262
+ return torch . from_numpy ( arr ) .view (torch_dtype ).view (shape )
264
263
265
264
def _decode_mm_items (self , obj : list ) -> list [MultiModalKwargsItem ]:
266
265
decoded_items = []
0 commit comments