Commit 871fa11

kaiyuan-li and dcci authored
Throw KeyError when getting partially pushed dtensor (#49)
* raise KeyError at key miss
* just raise key error
* just raise key error
* [torchstore] Rework the readme.
* Update to account for Lucas' comments.
* test
* sync
* test update
* Add partial DTensor commit detection with file-based sync
* Update README.md to match upstream/main and remove PR comment
* Simplify test_partial_put by removing sync primitives and using ranks_to_skip_put
* verify exists in test
* fmt

---------

Co-authored-by: Davide Italiano <[email protected]>
1 parent 49b6b7a commit 871fa11
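
In short: reading a key whose DTensor shards have not all been stored now fails loudly instead of returning an incomplete tensor. A minimal sketch of the user-visible behavior, assuming an initialized store in which only some ranks have finished their ts.put for "my_key" (the key name is illustrative):

    import torchstore as ts

    async def demo() -> None:
        # Assumes ts.initialize(...) has run and only some ranks have
        # completed ts.put("my_key", dtensor) for their shard.
        try:
            await ts.get("my_key")
        except KeyError as err:
            # e.g. "DTensor 'my_key' is only partially committed. ..."
            print(err)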

File tree

  tests/test_tensor_slice.py
  tests/utils.py
  torchstore/controller.py

3 files changed: +115 -4 lines

tests/test_tensor_slice.py

Lines changed: 49 additions & 2 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import tempfile
 
 import pytest
 import torch
@@ -120,8 +121,6 @@ async def test_tensor_slice_inplace():
 @pytest.mark.asyncio
 async def test_put_dtensor_get_full_tensor():
     """Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
-    import tempfile
-
     await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())
 
     original_tensor = torch.arange(16).reshape(4, 4).float()
@@ -151,5 +150,53 @@ async def test_put_dtensor_get_full_tensor():
     await ts.shutdown()
 
 
+@pytest.mark.asyncio
+async def test_partial_put():
+    """
+    Verify the behavior when a dtensor is partially put.
+    1. Create two put actors. Each of them should put half of a DTensor.
+    2. Rank 1 will skip the put operation (using ranks_to_skip_put=[1]).
+    3. After rank 0 completes its put, we call get() which should raise a KeyError
+       because the DTensor is not fully committed (only rank 0's shard is stored).
+    """
+
+    await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())
+
+    original_tensor = torch.arange(16).reshape(4, 4).float()
+
+    with tempfile.TemporaryDirectory() as filesystem_store_dir:
+        try:
+            put_mesh = await spawn_actors(
+                2,
+                DTensorActor,
+                "dtensor_put_mesh",
+                mesh_shape=(2,),
+                original_tensor=original_tensor,
+                placements=[Shard(0)],
+                file_store_name=os.path.join(filesystem_store_dir, "put_test"),
+                visible_devices="0,1",
+                ranks_to_skip_put=[1],  # Rank 1 will skip the put
+            )
+
+            # Execute the put - rank 0 will put, rank 1 will skip
+            await put_mesh.do_put.call()
+
+            assert not await ts.exists("test_key")
+            # Try to get the tensor - should raise KeyError because only rank 0 has committed
+            with pytest.raises(KeyError) as exc_info:
+                await ts.get("test_key")
+
+            # Verify the error message mentions partial commit
+            assert "partially committed" in str(
+                exc_info.value
+            ), f"Error message should mention partial commit: {exc_info.value}"
+
+        finally:
+            # Clean up process groups
+            await put_mesh.destroy_process_group.call()
+            await put_mesh._proc_mesh.stop()
+            await ts.shutdown()
+
+
 if __name__ == "__main__":
     main(__file__)
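
Not part of this diff, but for readers of test_partial_put: because a partial commit now surfaces as a KeyError, a consumer can treat that error as "not ready yet" and poll. A hypothetical helper sketch (read_when_ready is illustrative, not a torchstore API; only ts.get comes from this codebase):

    import asyncio

    import torchstore as ts

    async def read_when_ready(key: str, retries: int = 10, delay: float = 0.5):
        for _ in range(retries):
            try:
                # Raises KeyError while the key is missing or only partially committed
                return await ts.get(key)
            except KeyError:
                await asyncio.sleep(delay)
        raise TimeoutError(f"{key} was never fully committed")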

tests/utils.py

Lines changed: 11 additions & 0 deletions

@@ -9,6 +9,8 @@
 from itertools import product
 from logging import getLogger
 
+from typing import List
+
 import pytest
 import torch
 import torchstore as ts
@@ -53,13 +55,17 @@ def __init__(
         placements,
         file_store_name,
         visible_devices="0,1,2,3,4,5,6,7",
+        ranks_to_skip_put: (
+            List[int] | None
+        ) = None,  # ranks that should skip put operation
     ):
         self.rank = current_rank().rank
         self.mesh_shape = mesh_shape
         self.world_size = math.prod(mesh_shape)
         self.original_tensor = original_tensor
         self.placements = placements
         self.file_store_name = file_store_name
+        self.ranks_to_skip_put = ranks_to_skip_put or []
 
         # torchstore will fail without this (see LocalRankStrategy)
         os.environ["LOCAL_RANK"] = str(self.rank)
@@ -95,6 +101,11 @@ async def do_put(self):
         tensor = self.original_tensor.to("cpu")
         dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)
 
+        # Skip put if this rank is in the skip list
+        if self.rank in self.ranks_to_skip_put:
+            self.rlog(f"Skipping put for rank {self.rank}")
+            return
+
         self.rlog(f"calling put with {dtensor=}")
         await ts.put(self.shared_key, dtensor)
 
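A note on the ranks_to_skip_put default above: the parameter is typed List[int] | None and normalized with `or []` because a mutable default like ranks_to_skip_put=[] would be a single list shared across every instance. A standalone sketch of the idiom (Example is hypothetical, for illustration only):

    from typing import List

    class Example:
        def __init__(self, ranks_to_skip_put: List[int] | None = None):
            # Normalize per instance; a literal [] default would be shared
            self.ranks_to_skip_put = ranks_to_skip_put or []

        def should_skip(self, rank: int) -> bool:
            # Mirrors the guard added to DTensorActor.do_put
            return rank in self.ranks_to_skip_put

    assert Example(ranks_to_skip_put=[1]).should_skip(1)
    assert not Example().should_skip(0)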
torchstore/controller.py

Lines changed: 55 additions & 2 deletions

@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass, field
 from enum import auto, Enum
+from itertools import product
 from typing import Dict, List, Mapping, Optional, Set
 
 from monarch.actor import Actor, endpoint
@@ -61,6 +62,46 @@ def assert_initialized(self) -> None:
             self.is_initialized
         ), "Please call torchstore.initialize before attempting to use store."
 
+    def _is_dtensor_fully_committed(
+        self, key: str, volume_map: Dict[str, StorageInfo]
+    ) -> bool:
+        """
+        Check if all shards of a DTensor have been committed.
+
+        For a DTensor to be fully committed, we need all coordinates in the mesh
+        to have been stored. The mesh_shape tells us the total number of shards,
+        and coordinates tell us which shards we have.
+
+        Args:
+            key (str): The key to check.
+            volume_map (Dict[str, StorageInfo]): Mapping from storage volume IDs to StorageInfo.
+
+        Returns:
+            bool: True if fully committed, False if partial.
+        """
+        # Collect all tensor slices across all storage volumes
+        all_slices = set()
+        mesh_shape = None
+
+        for storage_info in volume_map.values():
+            if storage_info.object_type != ObjectType.TENSOR_SLICE:
+                return True  # Not a DTensor, so it's "fully committed"
+
+            for tensor_slice in storage_info.tensor_slices:
+                all_slices.add(tensor_slice.coordinates)
+                if mesh_shape is None:
+                    mesh_shape = tensor_slice.mesh_shape
+                else:
+                    assert (
+                        mesh_shape == tensor_slice.mesh_shape
+                    ), "Inconsistent mesh shapes in stored slices"
+
+        # Generate all expected coordinates for the mesh
+        expected_coords = set(product(*(range(s) for s in mesh_shape)))
+
+        # Check if we have all coordinates
+        return all_slices == expected_coords
+
     @endpoint
     async def init(
         self,
@@ -116,13 +157,25 @@ def locate_volumes(
            objects containing metadata about the stored data shards.
 
        Raises:
-            KeyError: If the key is not found in any storage volumes.
+            KeyError: If the key is not found in any storage volumes, or if the key
+                is a DTensor that is only partially committed.
        """
        self.assert_initialized()
 
        if key not in self.keys_to_storage_volumes:
            raise KeyError(f"Unable to locate {key} in any storage volumes.")
-        return self.keys_to_storage_volumes[key]
+
+        volume_map = self.keys_to_storage_volumes[key]
+
+        # Check if this is a DTensor and if it's fully committed
+        if not self._is_dtensor_fully_committed(key, volume_map):
+            raise KeyError(
+                f"DTensor '{key}' is only partially committed. "
+                f"Not all shards have been stored yet. "
+                f"Please ensure all ranks complete their put() operations."
+            )
+
+        return volume_map
 
    @endpoint
    def notify_put(self, key: str, request: Request, storage_volume_id: str) -> None:
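
To make the completeness check concrete: _is_dtensor_fully_committed treats a DTensor as fully committed exactly when the stored shard coordinates equal the full Cartesian product of the mesh axes. A self-contained sketch of that accounting (TensorSlice here is a hypothetical stand-in for torchstore's real slice metadata):

    from dataclasses import dataclass
    from itertools import product
    from typing import Set, Tuple

    @dataclass(frozen=True)
    class TensorSlice:
        coordinates: Tuple[int, ...]
        mesh_shape: Tuple[int, ...]

    def is_fully_committed(slices: Set[TensorSlice]) -> bool:
        mesh_shapes = {s.mesh_shape for s in slices}
        assert len(mesh_shapes) == 1, "Inconsistent mesh shapes in stored slices"
        (mesh_shape,) = mesh_shapes
        expected = set(product(*(range(dim) for dim in mesh_shape)))
        return {s.coordinates for s in slices} == expected

    # A 1-D mesh split into two shards: only rank 0's shard present -> partial
    assert not is_fully_committed({TensorSlice((0,), (2,))})
    assert is_fully_committed({TensorSlice((0,), (2,)), TensorSlice((1,), (2,))})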
