# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import math
import os
import tempfile
import time
from logging import getLogger

import torch
import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

# ANSI escape codes for colored output
YELLOW = "\033[93m"
BOLD = "\033[1m"
RESET = "\033[0m"


def print_yellow(text):
    """Print text in yellow color"""
    print(f"{YELLOW}{BOLD}{text}{RESET}")


init_logging()
logger = getLogger(__name__)


class DTensorActor(Actor):
| 38 | + """Test class used to verify correctness of resharding across different shardings. |
| 39 | + Currently only supports a single tensor |
| 40 | + """ |
| 41 | + |
| 42 | + shared_key = "test_key" |
| 43 | + |
| 44 | + def __init__( |
| 45 | + self, |
| 46 | + mesh_shape, |
| 47 | + original_tensor, |
| 48 | + placements, |
| 49 | + file_store_name, |
| 50 | + visible_devices="0,1,2,3,4,5,6,7", |
| 51 | + ): |
| 52 | + self.rank = current_rank().rank |
| 53 | + self.mesh_shape = mesh_shape |
| 54 | + self.world_size = math.prod(mesh_shape) |
| 55 | + self.original_tensor = original_tensor |
| 56 | + self.placements = placements |
| 57 | + self.file_store_name = file_store_name |
| 58 | + |
| 59 | + # torchstore will fail without this (see LocalRankStrategy) |
| 60 | + os.environ["LOCAL_RANK"] = str(self.rank) |
| 61 | + |
| 62 | + # this is only necessary for nccl, but we're not using it in this test. |
| 63 | + os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices |
| 64 | + |
| 65 | + def rlog(self, msg): |
| 66 | + # TODO: set to 'info' once this is fixed in monarch (which currently is hiding logs :/) |
| 67 | + logger.info(f"rank: {self.rank} {msg}") |
| 68 | + |
| 69 | + def initialize_distributed(self): |
| 70 | + self.rlog(f"Initialize process group using {self.file_store_name=} ") |
| 71 | + torch.distributed.init_process_group( |
| 72 | + backend="gloo", |
| 73 | + rank=self.rank, |
| 74 | + world_size=self.world_size, |
| 75 | + init_method=f"file://{self.file_store_name}", |
| 76 | + ) |
| 77 | + |
| 78 | + # this barrier is more to make sure torch.distibuted is working |
| 79 | + self.rlog("barrrer") |
| 80 | + torch.distributed.barrier() |
| 81 | + |
| 82 | + @endpoint |
| 83 | + async def do_put(self): |
| 84 | + self.initialize_distributed() |
| 85 | + |
| 86 | + self.rlog("Create device mesh") |
| 87 | + device_mesh = init_device_mesh("cpu", self.mesh_shape) |
| 88 | + |
| 89 | + self.rlog("distributing dtensor") |
| 90 | + tensor = self.original_tensor.to("cpu") |
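        # distribute_tensor shards the full tensor across the device mesh according
        # to `placements`, so each rank ends up holding only its local shard.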
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        self.rlog(f"calling put with {dtensor=}")
        await ts.put(self.shared_key, dtensor)

    @endpoint
    async def do_get(self):
        self.initialize_distributed()

        self.rlog("Create device mesh")
        # TODO: NCCL errors on process group split for a 2D mesh, so use gloo on CPU here.
        device_mesh = init_device_mesh("cpu", self.mesh_shape)

        self.rlog("distributing dtensor")
        tensor = self.original_tensor.to("cpu")
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        fetched_tensor = await ts.get(self.shared_key, dtensor)
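        # The local DTensor tells the store which slice this rank needs: here the
        # data was put with one sharding and is fetched with another.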
        assert torch.equal(dtensor, fetched_tensor)

        return fetched_tensor, device_mesh.get_coordinate()

    @endpoint
    async def destroy_process_group(self):
        torch.distributed.destroy_process_group()


async def dtensor_put_get_example():
    """
    Example demonstrating DTensor resharding between different mesh configurations.
    Creates a tensor of shape (size * n_put_actors, size * n_get_actors),
    puts it with Shard(0) and gets it with Shard(1).
    """
    # Configuration variables
    size = 3  # with size=100 (and 8x8 actors), the tensor is ~2.4 MB
    n_put_actors = 8
    n_get_actors = 8

    print("Starting DTensor put/get example with:")
    print(f"  size = {size}")
    print(f"  n_put_actors = {n_put_actors}")
    print(f"  n_get_actors = {n_get_actors}")

    # Initialize TorchStore
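    # Assumption: LocalRankStrategy routes each actor to a storage volume by its
    # LOCAL_RANK (set in DTensorActor.__init__), so create one volume per actor.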
    await ts.initialize(
        num_storage_volumes=max(n_put_actors, n_get_actors),
        strategy=ts.LocalRankStrategy(),
    )

    # Create tensor with shape (size * n_put_actors, size * n_get_actors)
    tensor_shape = (size * n_put_actors, size * n_get_actors)
    original_tensor = (
        torch.arange(tensor_shape[0] * tensor_shape[1]).reshape(tensor_shape).float()
    )
    print(f"Original tensor shape: {tensor_shape}")
    if size == 1:
        print(f"Original tensor:\n{original_tensor}")

    with tempfile.TemporaryDirectory() as filesystem_store_dir:
        put_mesh = None
        get_mesh = None
        try:
            print(
                f"\n--- Phase 1: Putting tensor with Shard(0) in ({n_put_actors},) mesh ---"
            )
            # Create first mesh for putting the tensor with Shard(0)
            put_mesh = await spawn_actors(
                n_put_actors,
                DTensorActor,
                "dtensor_put_mesh",
                mesh_shape=(n_put_actors,),
                original_tensor=original_tensor,
                placements=[Shard(0)],  # Shard along dimension 0
                file_store_name=os.path.join(filesystem_store_dir, "put_test"),
                visible_devices=",".join(str(i) for i in range(n_put_actors)),
            )

            # Put the tensor using the first mesh with timing
            put_start_time = time.perf_counter()
            await put_mesh.do_put.call()
            put_end_time = time.perf_counter()
            put_duration = put_end_time - put_start_time

            print("Successfully put tensor using first mesh")
            print_yellow(f"⏱️ PUT operation took: {put_duration:.4f} seconds")

            print(
                f"\n--- Phase 2: Getting tensor with Shard(1) in ({n_get_actors},) mesh ---"
            )
            # Create second mesh for getting the tensor with Shard(1)
            get_mesh = await spawn_actors(
                n_get_actors,
                DTensorActor,
                "dtensor_get_mesh",
                mesh_shape=(n_get_actors,),
                original_tensor=torch.zeros_like(original_tensor),  # Placeholder
                placements=[Shard(1)],  # Shard along dimension 1
                file_store_name=os.path.join(filesystem_store_dir, "get_test"),
                visible_devices=",".join(
                    str(i) for i in range(n_put_actors, n_put_actors + n_get_actors)
                ),
            )

            # Get the tensor using the second mesh with timing
            get_start_time = time.perf_counter()
            results = await get_mesh.do_get.call()
            get_end_time = time.perf_counter()
            get_duration = get_end_time - get_start_time

            print("Successfully retrieved tensor using second mesh")
            print_yellow(f"⏱️ GET operation took: {get_duration:.4f} seconds")

            # Print results from each rank in the get mesh
            for proc_info, (fetched_tensor, mesh_coord) in results:
                print(
                    f"Get mesh rank {proc_info.rank} (mesh coord {mesh_coord}): "
                    f"Retrieved tensor shape {fetched_tensor.shape}"
                )
                if size == 1:
                    print(f"  with values:\n{fetched_tensor}")

            print("\n--- Phase 3: Verifying full tensor ---")
            # Also verify we can get the full tensor directly
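            # A get without a DTensor template returns the full tensor, reassembled
            # from the shards that were stored.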
            fetched_tensor = await ts.get(DTensorActor.shared_key)
            assert torch.equal(original_tensor, fetched_tensor)
            print(f"Full tensor retrieved directly:\n{fetched_tensor}")
            print("✓ Full tensor matches original!")

            # Calculate tensor size in MB
            total_elements = tensor_shape[0] * tensor_shape[1]
            tensor_size_bytes = total_elements * 4  # float32 = 4 bytes per element
            tensor_size_mb = tensor_size_bytes / (1024 * 1024)

            # Print timing summary
            print("\n" + "=" * 50)
            print_yellow("⏱️ TIMING SUMMARY:")
            print_yellow(f"  Tensor size: {tensor_size_mb:.4f} MB ({tensor_shape})")
            print_yellow(f"  PUT operation: {put_duration:.4f} seconds")
            print_yellow(f"  GET operation: {get_duration:.4f} seconds")
            print("=" * 50)

        finally:
            # Clean up process groups and meshes
            if put_mesh is not None:
                await put_mesh.destroy_process_group.call()
                await put_mesh._proc_mesh.stop()
            if get_mesh is not None:
                await get_mesh.destroy_process_group.call()
                await get_mesh._proc_mesh.stop()
            await ts.shutdown()

    print("\nDTensor put/get example completed successfully!")


if __name__ == "__main__":
    asyncio.run(dtensor_put_get_example())