
Commit 2f2902c

kaiyuan-li and LucasLLC authored
optimize get() to only fetch necessary tensor slice (#36)
* basic test for get
* parametrized get set perf test
* sync
* sync
* cleanup
* fmt
* sync
* assemble tensor slice
* test
* sync
* partial assemble works
* cleanup
* fmt
* sync
* more assemble tests
* fix resharding tests and try on ci
* simplify storage volume get
* enable basic resharding tests
* sync
* allow overlapped tensor
* Updates for latest rdma + monarch (#50)
* latest rdma updates from monarch
* remove test code
* remove test code
* sync
* sync
* Monarch V1 Support. Necessary for direct actor to actor communications (#56)
* latest rdma updates from monarch
* remove test code
* remove test code
* working v1
* removing test code
* v1
* add v1 gate
* nits
* linter
* remove color_print

---------

Co-authored-by: Lucas Pasqualin <[email protected]>
1 parent 662299f commit 2f2902c
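In practice, the optimization in this commit means that a get() call which passes a sharded DTensor only transfers the slice owned by the calling rank, rather than the full tensor. The sketch below mirrors the put/get pattern used by the new example/dtensor.py (mesh setup and process-group initialization omitted); the key name "demo_key" and the helper name put_then_get_slice are illustrative, not part of the API.

import torch
import torchstore as ts
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh


async def put_then_get_slice(full_tensor: torch.Tensor, mesh_shape=(4,)):
    # Assumes torch.distributed is already initialized on every rank and the
    # store was set up with ts.initialize() (see example/dtensor.py below).
    mesh = init_device_mesh("cpu", mesh_shape)

    # Writer side: each rank stores its Shard(0) slice under a shared key.
    dtensor = distribute_tensor(full_tensor, mesh, placements=[Shard(0)])
    await ts.put("demo_key", dtensor)

    # Reader side: passing a DTensor lets get() resolve which slice this rank
    # owns and fetch only that slice from the storage volumes.
    fetched = await ts.get("demo_key", dtensor)
    assert torch.equal(dtensor, fetched)
    return fetched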

File tree

10 files changed: +831 -99 lines changed


.github/workflows/unit_test.yaml

Lines changed: 10 additions & 1 deletion
@@ -35,14 +35,23 @@ jobs:
           TORCHSTORE_RDMA_ENABLED=0 \
           pytest tests/test_tensor_slice.py \
             --cov=. --cov-report=xml -vv -s
+      - name: Run test_resharding_basic tests with coverage
+        # TorchStore RDMA will not run on CPU-only machines
+        # resharding tests runs for too long.
+        # test_large_tensors.py can OOM.
+        run: |
+          TORCHSTORE_RDMA_ENABLED=0 \
+          pytest tests/test_resharding_basic.py \
+            --cov=. --cov-report=xml -vv -s
       - name: Run remaining tests with coverage
         # TorchStore RDMA will not run on CPU-only machines
         # resharding tests runs for too long.
         # test_large_tensors.py can OOM.
         run: |
           TORCHSTORE_RDMA_ENABLED=0 \
           pytest tests/ \
-            --ignore=tests/test_resharding.py \
+            --ignore=tests/test_resharding_basic.py \
+            --ignore=tests/test_resharding_ext.py \
             --ignore=tests/test_tensor_slice.py \
             --ignore=tests/test_large_tensors.py \
             --cov=. --cov-report=xml --cov-append --durations=20 -vv -s

example/dtensor.py

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import math
import os
import tempfile
import time
from logging import getLogger

import torch
import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

# ANSI escape codes for colored output
YELLOW = "\033[93m"
BOLD = "\033[1m"
RESET = "\033[0m"


def print_yellow(text):
    """Print text in yellow color"""
    print(f"{YELLOW}{BOLD}{text}{RESET}")


init_logging()
logger = getLogger(__name__)


class DTensorActor(Actor):
    """Test class used to verify correctness of resharding across different shardings.
    Currently only supports a single tensor
    """

    shared_key = "test_key"

    def __init__(
        self,
        mesh_shape,
        original_tensor,
        placements,
        file_store_name,
        visible_devices="0,1,2,3,4,5,6,7",
    ):
        self.rank = current_rank().rank
        self.mesh_shape = mesh_shape
        self.world_size = math.prod(mesh_shape)
        self.original_tensor = original_tensor
        self.placements = placements
        self.file_store_name = file_store_name

        # torchstore will fail without this (see LocalRankStrategy)
        os.environ["LOCAL_RANK"] = str(self.rank)

        # this is only necessary for nccl, but we're not using it in this test.
        os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    def rlog(self, msg):
        # TODO: set to 'info' once this is fixed in monarch (which currently is hiding logs :/)
        logger.info(f"rank: {self.rank} {msg}")

    def initialize_distributed(self):
        self.rlog(f"Initialize process group using {self.file_store_name=} ")
        torch.distributed.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"file://{self.file_store_name}",
        )

        # this barrier is more to make sure torch.distributed is working
        self.rlog("barrier")
        torch.distributed.barrier()

    @endpoint
    async def do_put(self):
        self.initialize_distributed()

        self.rlog("Create device mesh")
        device_mesh = init_device_mesh("cpu", self.mesh_shape)

        self.rlog("distributing dtensor")
        tensor = self.original_tensor.to("cpu")
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        self.rlog(f"calling put with {dtensor=}")
        await ts.put(self.shared_key, dtensor)

    @endpoint
    async def do_get(self):
        self.initialize_distributed()

        self.rlog("Create device mesh")
        # TODO: nccl is giving me a weird error on process group split for 2d mesh
        device_mesh = init_device_mesh("cpu", self.mesh_shape)

        self.rlog("distributing dtensor")
        tensor = self.original_tensor.to("cpu")
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        fetched_tensor = await ts.get(self.shared_key, dtensor)
        assert torch.equal(dtensor, fetched_tensor)

        return fetched_tensor, device_mesh.get_coordinate()

    @endpoint
    async def destroy_process_group(self):
        torch.distributed.destroy_process_group()


async def dtensor_put_get_example():
    """
    Example demonstrating DTensor resharding between different mesh configurations.
    Creates a tensor of shape (size * n_put_actors, size * n_get_actors),
    puts it with Shard(0) and gets it with Shard(1).
    """
    # Configuration variables
    size = 3  # 100 unit size => 2.4 MB Tensor Size
    n_put_actors = 8
    n_get_actors = 8

    print("Starting DTensor put/get example with:")
    print(f"  size = {size}")
    print(f"  n_put_actors = {n_put_actors}")
    print(f"  n_get_actors = {n_get_actors}")

    # Initialize TorchStore
    await ts.initialize(
        num_storage_volumes=max(n_put_actors, n_get_actors),
        strategy=ts.LocalRankStrategy(),
    )

    # Create tensor with shape (size * n_put_actors, size * n_get_actors)
    tensor_shape = (size * n_put_actors, size * n_get_actors)
    original_tensor = (
        torch.arange(tensor_shape[0] * tensor_shape[1]).reshape(tensor_shape).float()
    )
    print(f"Original tensor shape: {tensor_shape}")
    print(f"Original tensor:\n{original_tensor}") if size == 1 else None

    with tempfile.TemporaryDirectory() as filesystem_store_dir:
        put_mesh = None
        get_mesh = None
        try:
            print(
                f"\n--- Phase 1: Putting tensor with Shard(0) in ({n_put_actors},) mesh ---"
            )
            # Create first mesh for putting the tensor with Shard(0)
            put_mesh = await spawn_actors(
                n_put_actors,
                DTensorActor,
                "dtensor_put_mesh",
                mesh_shape=(n_put_actors,),
                original_tensor=original_tensor,
                placements=[Shard(0)],  # Shard along dimension 0
                file_store_name=os.path.join(filesystem_store_dir, "put_test"),
                visible_devices=",".join(str(i) for i in range(n_put_actors)),
            )

            # Put the tensor using the first mesh with timing
            put_start_time = time.perf_counter()
            await put_mesh.do_put.call()
            put_end_time = time.perf_counter()
            put_duration = put_end_time - put_start_time

            print("Successfully put tensor using first mesh")
            print_yellow(f"⏱️ PUT operation took: {put_duration:.4f} seconds")

            print(
                f"\n--- Phase 2: Getting tensor with Shard(1) in ({n_get_actors},) mesh ---"
            )
            # Create second mesh for getting the tensor with Shard(1)
            get_mesh = await spawn_actors(
                n_get_actors,
                DTensorActor,
                "dtensor_get_mesh",
                mesh_shape=(n_get_actors,),
                original_tensor=torch.zeros_like(original_tensor),  # Placeholder
                placements=[Shard(1)],  # Shard along dimension 1
                file_store_name=os.path.join(filesystem_store_dir, "get_test"),
                visible_devices=",".join(
                    str(i) for i in range(n_put_actors, n_put_actors + n_get_actors)
                ),
            )

            # Get the tensor using the second mesh with timing
            get_start_time = time.perf_counter()
            results = await get_mesh.do_get.call()
            get_end_time = time.perf_counter()
            get_duration = get_end_time - get_start_time

            print("Successfully retrieved tensor using second mesh")
            print_yellow(f"⏱️ GET operation took: {get_duration:.4f} seconds")

            # Print results from each rank in the get mesh
            for proc_info, (fetched_tensor, mesh_coord) in results:
                print(
                    f"Get mesh rank {proc_info.rank} (mesh coord {mesh_coord}): "
                    f"Retrieved tensor shape {fetched_tensor.shape}"
                )
                print(f"  with values:\n{fetched_tensor}") if size == 1 else None

            print("\n--- Phase 3: Verifying full tensor ---")
            # Also verify we can get the full tensor directly
            fetched_tensor = await ts.get("test_key")
            assert torch.equal(original_tensor, fetched_tensor)
            print(f"Full tensor retrieved directly:\n{fetched_tensor}")
            print("✓ Full tensor matches original!")

            # Calculate tensor size in MB
            total_elements = tensor_shape[0] * tensor_shape[1]
            tensor_size_bytes = total_elements * 4  # float32 = 4 bytes per element
            tensor_size_mb = tensor_size_bytes / (1024 * 1024)

            # Print timing summary
            print("\n" + "=" * 50)
            print_yellow("⏱️ TIMING SUMMARY:")
            print_yellow(f"  Tensor size: {tensor_size_mb:.4f} MB ({tensor_shape})")
            print_yellow(f"  PUT operation: {put_duration:.4f} seconds")
            print_yellow(f"  GET operation: {get_duration:.4f} seconds")
            print("=" * 50)

        finally:
            # Clean up process groups and meshes
            if put_mesh is not None:
                await put_mesh.destroy_process_group.call()
                await put_mesh._proc_mesh.stop()
            if get_mesh is not None:
                await get_mesh.destroy_process_group.call()
                await get_mesh._proc_mesh.stop()
            await ts.shutdown()

    print("\nDTensor put/get example completed successfully!")


if __name__ == "__main__":
    asyncio.run(dtensor_put_get_example())
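Since the file guards its entry point with if __name__ == "__main__", it can be run directly as a script (e.g. with TORCHSTORE_RDMA_ENABLED=0 set on a CPU-only machine, mirroring the environment variable used in the CI workflow above). The put mesh writes the tensor sharded along dimension 0, the get mesh reads it back sharded along dimension 1, and Phase 3 confirms that a plain ts.get("test_key") with no DTensor argument still assembles and returns the full tensor.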

tests/test_resharding.py renamed to tests/test_resharding_basic.py

Lines changed: 26 additions & 33 deletions
@@ -26,29 +26,35 @@


 @pytest.mark.parametrize(*transport_plus_strategy_params())
+@pytest.mark.parametrize(
+    "put_mesh_shape,get_mesh_shape,put_sharding_dim,get_sharding_dim",
+    [
+        # shrink
+        ((4,), (2,), 0, 0),
+        # grow
+        ((2,), (4,), 0, 0),
+    ],
+)
 @pytest.mark.asyncio
-async def test_1d_resharding(strategy_params, use_rdma):
+async def test_1d_resharding(
+    strategy_params,
+    use_rdma,
+    put_mesh_shape,
+    get_mesh_shape,
+    put_sharding_dim,
+    get_sharding_dim,
+):
     _, strategy = strategy_params

-    for put_mesh_shape, get_mesh_shape in [
-        ((4,), (2,)),  # shrink
-        ((2,), (4,)),  # grow
-    ]:
-        for put_sharding_dim, get_sharding_dim in [
-            (0, 0),
-            (0, 1),
-            (1, 0),
-            (1, 1),
-        ]:
-            # TODO: test Replicate as well, which is likely not working
-            await _test_resharding(
-                put_mesh_shape=put_mesh_shape,
-                put_placements=[Shard(put_sharding_dim)],
-                get_mesh_shape=get_mesh_shape,
-                get_placements=[Shard(get_sharding_dim)],
-                strategy=strategy,
-                use_rdma=use_rdma,
-            )
+    # TODO: test Replicate as well, which is likely not working
+    await _test_resharding(
+        put_mesh_shape=put_mesh_shape,
+        put_placements=[Shard(put_sharding_dim)],
+        get_mesh_shape=get_mesh_shape,
+        get_placements=[Shard(get_sharding_dim)],
+        strategy=strategy,
+        use_rdma=use_rdma,
+    )


 @pytest.mark.parametrize(*transport_plus_strategy_params())
@@ -59,9 +65,6 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
     put_mesh_shape = get_mesh_shape = (2, 2)
     for put_sharding_dims, get_sharding_dims in [
         ((1, 1), (0, 1)),
-        ((1, 0), (1, 0)),
-        ((0, 0), (0, 1)),
-        ((1, 1), (0, 0)),
     ]:
         await _test_resharding(
             put_mesh_shape=put_mesh_shape,
@@ -81,10 +84,7 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
     put_mesh_shape = (4,)
     get_mesh_shape = (2, 2)
     for put_sharding_dims, get_sharding_dims in [
-        ((0,), (0, 1)),
-        ((1,), (1, 0)),
         ((0,), (0, 0)),
-        ((1,), (1, 1)),
     ]:
         await _test_resharding(
             put_mesh_shape=put_mesh_shape,
@@ -105,9 +105,6 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
     get_mesh_shape = (4,)
     for put_sharding_dims, get_sharding_dims in [
         ((0, 0), (0,)),
-        ((1, 0), (1,)),
-        ((0, 1), (0,)),
-        ((1, 1), (1,)),
     ]:
         await _test_resharding(
             put_mesh_shape=put_mesh_shape,
@@ -197,10 +194,6 @@ async def _test_resharding(
         get_placements
     ), f"{get_mesh_shape=}, {get_placements=}"

-    logger.warn(
-        f"Testing {put_mesh_shape=} {put_placements=} {get_mesh_shape=} {get_placements=}"
-    )
-
     original_tensor = torch.arange(8**2).reshape(
         8, 8
     )  # 8x8 square, with ([[0...7],[8...15],[...]])
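For context on the restructuring above: stacking pytest.mark.parametrize decorators makes pytest generate the cross product of their argument sets, so each transport/strategy combination now runs as its own collected test case for both the shrink and grow shapes instead of looping inside one test body. A minimal standalone illustration (hypothetical test, not from this repository):

import pytest


@pytest.mark.parametrize("strategy", ["local_rank"])  # stand-in for transport_plus_strategy_params()
@pytest.mark.parametrize(
    "put_mesh_shape,get_mesh_shape",
    [
        ((4,), (2,)),  # shrink
        ((2,), (4,)),  # grow
    ],
)
def test_cross_product(strategy, put_mesh_shape, get_mesh_shape):
    # pytest collects one case per combination: 1 strategy x 2 shape pairs = 2 tests,
    # each reported and failing independently.
    assert isinstance(put_mesh_shape, tuple) and isinstance(get_mesh_shape, tuple)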
