1717
def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
    """Fallback stub used when the runtime has no RDMA support.

    Accepts any arguments so it is signature-compatible with the real
    RDMABuffer, but unconditionally raises.

    Raises:
        NotImplementedError: always — this build lacks rdma support.
    """
    # Typo fix: "environemnt" -> "environment" in the user-facing message.
    raise NotImplementedError(
        "RDMABuffer is not available. This environment was likely not built with rdma support."
    )
2222
2323
@@ -27,12 +27,10 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any:
2727 os .environ .get ("TORCHSTORE_RDMA_CHUNK_SIZE_MB" , str (1024 * 32 ))
2828)
2929
30- # assert RDMA_CHUNK_SIZE_MB <= 1024, "Monarch does not support 1gb chunks via rdma"
31-
3230
def rdma_available() -> bool:
    """Return True when RDMA transport can actually be used.

    RDMA is considered enabled unless TORCHSTORE_RDMA_ENABLED is set to
    something other than "1"; availability additionally requires the
    monarch runtime to report rdma support.
    """
    # TODO: enable on this build
    flag = os.environ.get("TORCHSTORE_RDMA_ENABLED", "1")
    if flag != "1":
        return False
    return monarch_rdma_available()
3836
@@ -111,11 +109,13 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
111109 return
112110 elif isinstance (tensor_like , Tuple ):
113111 # we know the size of the tensor from fetching metadata
114- tensor = torch .empty (tensor_like [0 ], dtype = tensor_like [1 ])
112+ tensor = torch .empty (
113+ tensor_like [0 ], dtype = tensor_like [1 ], device = torch .device ("cpu" )
114+ )
115115 else :
116116 # we have an inplace tensor, allocate a copy
117117 assert isinstance (tensor_like , torch .Tensor )
118- tensor = torch .empty_like (tensor_like )
118+ tensor = torch .empty_like (tensor_like , device = torch . device ( "cpu" ) )
119119
120120 # store tensor meta
121121 self .shape = tensor .shape
@@ -125,7 +125,10 @@ def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None:
125125 self ._assert_valid_tensor (tensor )
126126
127127 byte_view_chunks = self ._create_byte_views_from_tensor (tensor )
128- self .tensor_refs = [torch .empty_like (chunk ) for chunk in byte_view_chunks ]
128+ self .tensor_refs = [
129+ torch .empty_like (chunk , device = torch .device ("cpu" ))
130+ for chunk in byte_view_chunks
131+ ]
129132 self .rdma_buffers = [RDMABuffer (chunk ) for chunk in self .tensor_refs ]
130133
131134 chunk_sizes = set ()
@@ -140,7 +143,9 @@ def update(self, other_buffer: "TransportBuffer") -> None:
140143 async def read_into (self , tensor : Optional [torch .Tensor ] = None ) -> torch .Tensor :
141144 if tensor is None :
142145 # allocate a tensor to return
143- tensor = torch .empty (self .shape , dtype = self .dtype )
146+ tensor = torch .empty (
147+ self .shape , dtype = self .dtype , device = torch .device ("cpu" )
148+ )
144149
145150 self ._assert_valid_tensor (tensor )
146151 assert self .rdma_buffers is not None
0 commit comments