|
6 | 6 |
|
7 | 7 | from dataclasses import dataclass, field |
8 | 8 | from enum import auto, Enum |
| 9 | +from itertools import product |
9 | 10 | from typing import Dict, List, Mapping, Optional, Set |
10 | 11 |
|
11 | 12 | from monarch.actor import Actor, endpoint |
@@ -61,6 +62,46 @@ def assert_initialized(self) -> None: |
61 | 62 | self.is_initialized |
62 | 63 | ), "Please call torchstore.initialize before attempting to use store." |
63 | 64 |
|
| 65 | + def _is_dtensor_fully_committed( |
| 66 | + self, key: str, volume_map: Dict[str, StorageInfo] |
| 67 | + ) -> bool: |
| 68 | + """ |
| 69 | + Check if all shards of a DTensor have been committed. |
| 70 | +
|
| 71 | + For a DTensor to be fully committed, we need all coordinates in the mesh |
| 72 | + to have been stored. The mesh_shape tells us the total number of shards, |
| 73 | + and coordinates tell us which shards we have. |
| 74 | +
|
| 75 | + Args: |
| 76 | + key (str): The key to check. |
| 77 | + volume_map (Dict[str, StorageInfo]): Mapping from storage volume IDs to StorageInfo. |
| 78 | +
|
| 79 | + Returns: |
| 80 | + bool: True if fully committed, False if partial. |
| 81 | + """ |
| 82 | + # Collect all tensor slices across all storage volumes |
| 83 | + all_slices = set() |
| 84 | + mesh_shape = None |
| 85 | + |
| 86 | + for storage_info in volume_map.values(): |
| 87 | + if storage_info.object_type != ObjectType.TENSOR_SLICE: |
| 88 | + return True # Not a DTensor, so it's "fully committed" |
| 89 | + |
| 90 | + for tensor_slice in storage_info.tensor_slices: |
| 91 | + all_slices.add(tensor_slice.coordinates) |
| 92 | + if mesh_shape is None: |
| 93 | + mesh_shape = tensor_slice.mesh_shape |
| 94 | + else: |
| 95 | + assert ( |
| 96 | + mesh_shape == tensor_slice.mesh_shape |
| 97 | + ), "Inconsistent mesh shapes in stored slices" |
| 98 | + |
| 99 | + # Generate all expected coordinates for the mesh |
| 100 | + expected_coords = set(product(*(range(s) for s in mesh_shape))) |
| 101 | + |
| 102 | + # Check if we have all coordinates |
| 103 | + return all_slices == expected_coords |
| 104 | + |
64 | 105 | @endpoint |
65 | 106 | async def init( |
66 | 107 | self, |
@@ -116,13 +157,25 @@ def locate_volumes( |
116 | 157 | objects containing metadata about the stored data shards. |
117 | 158 |
|
118 | 159 | Raises: |
119 | | - KeyError: If the key is not found in any storage volumes. |
| 160 | + KeyError: If the key is not found in any storage volumes, or if the key |
| 161 | + is a DTensor that is only partially committed. |
120 | 162 | """ |
121 | 163 | self.assert_initialized() |
122 | 164 |
|
123 | 165 | if key not in self.keys_to_storage_volumes: |
124 | 166 | raise KeyError(f"Unable to locate {key} in any storage volumes.") |
125 | | - return self.keys_to_storage_volumes[key] |
| 167 | + |
| 168 | + volume_map = self.keys_to_storage_volumes[key] |
| 169 | + |
| 170 | + # Check if this is a DTensor and if it's fully committed |
| 171 | + if not self._is_dtensor_fully_committed(key, volume_map): |
| 172 | + raise KeyError( |
| 173 | + f"DTensor '{key}' is only partially committed. " |
| 174 | + f"Not all shards have been stored yet. " |
| 175 | + f"Please ensure all ranks complete their put() operations." |
| 176 | + ) |
| 177 | + |
| 178 | + return volume_map |
126 | 179 |
|
127 | 180 | @endpoint |
128 | 181 | def notify_put(self, key: str, request: Request, storage_volume_id: str) -> None: |
|
0 commit comments