diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 9cda3f9dd..e3ae21985 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1095,6 +1095,7 @@ def _maybe_compute_stride_kjt( lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[torch.IntTensor], + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> int: if stride is None: if len(keys) == 0: @@ -1102,6 +1103,10 @@ def _maybe_compute_stride_kjt( elif ( stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0 ): + # For VBE KJT, batch size should be based on inverse_indices when set. + if inverse_indices is not None and inverse_indices[1].numel() > 0: + return inverse_indices[1].shape[-1] + stride = int(stride_per_key_per_rank.sum(dim=1).max().item()) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) @@ -2138,6 +2143,7 @@ def stride(self) -> int: self._lengths, self._offsets, self._stride_per_key_per_rank, + self._inverse_indices, ) self._stride = stride return stride diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py index 1636a06bd..bac0a2c52 100644 --- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py +++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py @@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None: lengths=torch.tensor([], device=torch.device("meta")), ) + def test_vbe_kjt_stride(self) -> None: + inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]]) + kjt = KeyedJaggedTensor( + keys=["f1", "f2", "f3"], + values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]), + lengths=torch.tensor([3, 3, 2]), + stride_per_key_per_rank=[[2], [1]], + inverse_indices=(["f1", "f2"], inverse_indices), + ) + + self.assertEqual(kjt.stride(), inverse_indices.shape[-1]) + class TestKeyedJaggedTensorScripting(unittest.TestCase): def test_scriptable_forward(self) -> None: