Skip to content

Commit 662299f

Browse files
authored
Monarch V1 Support. Necessary for direct actor-to-actor communications (#56)
* latest rdma updates from monarch
* remove test code
* remove test code
* working v1
* removing test code
* v1
* add v1 gate
* nits
* linter
1 parent 871fa11 commit 662299f

File tree

6 files changed

+50
-12
lines changed

6 files changed

+50
-12
lines changed

tests/test_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from torchstore.constants import MONARCH_HOSTMESH_V1
8+
9+
if MONARCH_HOSTMESH_V1:
10+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
11+
from monarch._rust_bindings.monarch_hyperactor.config import configure
12+
13+
configure(
14+
default_transport=ChannelTransport.MetaTlsWithHostname,
15+
)
16+
717
import math
818
import os
919
import tempfile

torchstore/api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010

11-
from monarch.actor import get_or_spawn_controller
12-
1311
import torchstore.state_dict_utils
1412
from torchstore.client import LocalClient
13+
14+
from torchstore.constants import MONARCH_HOSTMESH_V1
1515
from torchstore.controller import Controller
1616
from torchstore.storage_volume import StorageVolume
1717
from torchstore.strategy import (
@@ -21,6 +21,11 @@
2121
)
2222
from torchstore.transport.pipe import TensorSlice
2323

24+
if MONARCH_HOSTMESH_V1:
25+
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller
26+
else:
27+
from monarch.actor import get_or_spawn_controller
28+
2429

2530
# I need to keep this somewhere, so here we go
2631
DEFAULT_TORCHSTORE_NAME: str = "TorchStore"

torchstore/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
MONARCH_HOSTMESH_V1 = os.environ.get("MONARCH_HOSTMESH_V1", "0").lower() in (
10+
"1",
11+
"true",
12+
)

torchstore/storage_volume.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def _handle_dtensor(
188188
async def put(
189189
self, key: str, transport_buffer: TransportBuffer, request: Request
190190
) -> None:
191+
191192
if request.is_object:
192193
self.kv[key] = {"obj": request.objects}
193194
return

torchstore/transport/buffers.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
1919
raise NotImplementedError(
20-
"RDMABuffer is not available. This environemnt was likely not built with tensor_engine supoprt."
20+
"RDMABuffer is not available. This environemnt was likely not built with rdma support."
2121
)
2222

2323

@@ -27,12 +27,10 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
2727
os.environ.get("TORCHSTORE_RDMA_CHUNK_SIZE_MB", str(1024 * 32))
2828
)
2929

30-
# assert RDMA_CHUNK_SIZE_MB <= 1024, "Monarch does not support 1gb chunks via rdma"
31-
3230

3331
def rdma_available() -> bool:
3432
rdma_enabled = (
35-
os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1"
33+
os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1"
3634
) # TODO: enable on this build
3735
return rdma_enabled and monarch_rdma_available()
3836

@@ -111,11 +109,13 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
111109
return
112110
elif isinstance(tensor_like, Tuple):
113111
# we know the size of the tensor from fetching metadata
114-
tensor = torch.empty(tensor_like[0], dtype=tensor_like[1])
112+
tensor = torch.empty(
113+
tensor_like[0], dtype=tensor_like[1], device=torch.device("cpu")
114+
)
115115
else:
116116
# we have an inplace tensor, allocate a copy
117117
assert isinstance(tensor_like, torch.Tensor)
118-
tensor = torch.empty_like(tensor_like)
118+
tensor = torch.empty_like(tensor_like, device=torch.device("cpu"))
119119

120120
# store tensor meta
121121
self.shape = tensor.shape
@@ -125,7 +125,10 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
125125
self._assert_valid_tensor(tensor)
126126

127127
byte_view_chunks = self._create_byte_views_from_tensor(tensor)
128-
self.tensor_refs = [torch.empty_like(chunk) for chunk in byte_view_chunks]
128+
self.tensor_refs = [
129+
torch.empty_like(chunk, device=torch.device("cpu"))
130+
for chunk in byte_view_chunks
131+
]
129132
self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs]
130133

131134
chunk_sizes = set()
@@ -140,7 +143,9 @@ def update(self, other_buffer: "TransportBuffer") -> None:
140143
async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
141144
if tensor is None:
142145
# allocate a tensor to return
143-
tensor = torch.empty(self.shape, dtype=self.dtype)
146+
tensor = torch.empty(
147+
self.shape, dtype=self.dtype, device=torch.device("cpu")
148+
)
144149

145150
self._assert_valid_tensor(tensor)
146151
assert self.rdma_buffers is not None

torchstore/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
import torch
1212

13-
from monarch.actor import ProcMesh, this_host
13+
from torchstore.constants import MONARCH_HOSTMESH_V1
14+
15+
if MONARCH_HOSTMESH_V1:
16+
from monarch._src.actor.v1.host_mesh import this_host
17+
else:
18+
from monarch.actor import this_host
1419

1520

1621
if TYPE_CHECKING:
@@ -29,7 +34,7 @@ async def spawn_actors(num_processes, actor_cls, name, mesh=None, **init_args):
2934
actors = mesh.spawn(f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args)
3035
return actors
3136

32-
assert isinstance(mesh, ProcMesh)
37+
assert hasattr(mesh, "spawn")
3338
actors = mesh.spawn(f"{name}_{str(uuid.uuid4())[:8]}", actor_cls, **init_args)
3439

3540
return actors

0 commit comments

Comments
 (0)