Skip to content

Commit 6818cd0

Browse files
committed
Revert "[Core] Improve Tensor serialisation (#18774)"
This reverts commit d73a945.
1 parent 643622b commit 6818cd0

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

vllm/v1/serial_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,18 @@ def _encode_tensor(
158158
self, obj: torch.Tensor
159159
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
160160
assert self.aux_buffers is not None
161+
# this creates a copy of the tensor if it's not already contiguous
162+
obj = obj.contiguous()
161163
# view the tensor as a 1D array of bytes
162-
arr = obj.flatten().view(torch.uint8).numpy()
164+
arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
163165
if obj.nbytes < self.size_threshold:
164166
# Smaller tensors are encoded inline, just like ndarrays.
165167
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
166168
else:
167169
# Otherwise encode index of backing buffer to avoid copy.
168170
data = len(self.aux_buffers)
169171
self.aux_buffers.append(arr.data)
170-
dtype = str(obj.dtype).removeprefix("torch.")
172+
dtype = str(obj.dtype)[6:] # remove 'torch.' prefix
171173
return dtype, obj.shape, data
172174

173175
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
@@ -243,7 +245,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
243245
# zero-copy decode. We assume the ndarray will not be kept around,
244246
# as it now locks the whole received message buffer in memory.
245247
buffer = self.aux_buffers[data] if isinstance(data, int) else data
246-
return np.frombuffer(buffer, dtype=dtype).reshape(shape)
248+
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
247249

248250
def _decode_tensor(self, arr: Any) -> torch.Tensor:
249251
dtype, shape, data = arr
@@ -252,15 +254,12 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
252254
# not complain about a readonly memoryview.
253255
buffer = self.aux_buffers[data] if isinstance(data, int) \
254256
else bytearray(data)
257+
# Create numpy wrapper around the bytes
258+
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
255259
torch_dtype = getattr(torch, dtype)
256260
assert isinstance(torch_dtype, torch.dtype)
257-
if not buffer: # torch.frombuffer doesn't like empty buffers
258-
assert 0 in shape
259-
return torch.empty(shape, dtype=torch_dtype)
260-
# Create uint8 array
261-
arr = torch.frombuffer(buffer, dtype=torch.uint8)
262261
# Convert back to proper shape & type
263-
return arr.view(torch_dtype).view(shape)
262+
return torch.from_numpy(arr).view(torch_dtype).view(shape)
264263

265264
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
266265
decoded_items = []

0 commit comments

Comments
 (0)