Commit 820bc7a

tissue3 authored and facebook-github-bot committed
Put TensorPool OutputDist to Remote RO to avoid wrong rebatching on indexing (#3087)
Summary:
Pull Request resolved: #3087
Reviewed By: yumin829928
Differential Revision: D76324557
fbshipit-source-id: 1ef856b4c9d626d189b30cef9acd5071142cb1b4
1 parent be4e6d7 commit 820bc7a

File tree

1 file changed: +10 -1 lines


torchrec/distributed/tensor_pool.py

Lines changed: 10 additions & 1 deletion
@@ -35,6 +35,15 @@
 from torchrec.modules.utils import deterministic_dedup


+@torch.fx.wrap
+def index_select_view(
+    output: torch.Tensor,
+    unbucketize_permute: Optional[torch.Tensor],
+    dim: int,
+) -> torch.Tensor:
+    return output[unbucketize_permute].view(-1, dim)
+
+
 class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
     def __init__(
         self,
@@ -441,7 +450,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:

         output = self._lookup_values_dist(lookup_list)

-        return output[unbucketize_permute].view(-1, self._dim)
+        return index_select_view(output, unbucketize_permute, self._dim)

     # pyre-ignore
     def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
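
Why wrapping the indexing matters (a minimal sketch, not part of the commit): @torch.fx.wrap makes torch.fx symbolic tracing record index_select_view as a single call_function node instead of tracing through the indexing and view, so downstream graph rewrites see one opaque call rather than a getitem/view pattern they might rebatch incorrectly, which appears to be the issue the commit title refers to. The Lookup module below is a hypothetical stand-in for the sharded TensorPool lookup path, not TorchRec code.

# Minimal sketch; assumes only that torch is installed. Lookup is a
# hypothetical module, not the TorchRec ShardedTensorPool.
from typing import Optional

import torch
import torch.fx


@torch.fx.wrap
def index_select_view(
    output: torch.Tensor,
    unbucketize_permute: Optional[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    # Opaque to symbolic tracing: the indexing + view run only at call time.
    return output[unbucketize_permute].view(-1, dim)


class Lookup(torch.nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self._dim = dim

    def forward(
        self, output: torch.Tensor, unbucketize_permute: torch.Tensor
    ) -> torch.Tensor:
        return index_select_view(output, unbucketize_permute, self._dim)


gm = torch.fx.symbolic_trace(Lookup(dim=4))
# The traced graph contains one call_function node targeting index_select_view
# rather than separate getitem and view nodes.
print(gm.graph)

Printing gm.graph shows the wrapped helper as a single leaf call; without the decorator the same trace would expose the getitem and view as individual nodes that later graph transformations could rewrite.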
