diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 2bbe09149..17d617eeb 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1097,11 +1097,15 @@ def _maybe_compute_stride_kjt( lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[List[List[int]]], + inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> int: if stride is None: if len(keys) == 0: stride = 0 elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: + # For VBE KJT, use inverse_indices for the batch size of the EBC output KeyedTensor. + if inverse_indices is not None and inverse_indices[1].numel() > 0: + return inverse_indices[1].shape[-1] stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) @@ -1775,6 +1779,7 @@ def __init__( index_per_key: Optional[Dict[str, int]] = None, jt_dict: Optional[Dict[str, JaggedTensor]] = None, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, + stride_per_key_per_rank_tensor: Optional[torch.Tensor] = None, ) -> None: """ This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible. @@ -1791,6 +1796,11 @@ def __init__( self._stride_per_key_per_rank: Optional[List[List[int]]] = ( stride_per_key_per_rank ) + + self._stride_per_key_per_rank_tensor: torch.Tensor = torch.empty(0) + if stride_per_key_per_rank_tensor is not None: + self._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor + self._stride_per_key: Optional[List[int]] = stride_per_key self._length_per_key: Optional[List[int]] = length_per_key self._offset_per_key: Optional[List[int]] = offset_per_key @@ -2165,6 +2175,7 @@ def stride(self) -> int: self._lengths, self._offsets, self._stride_per_key_per_rank, + self._inverse_indices, ) self._stride = stride return stride @@ -2179,7 +2190,7 @@ def stride_per_key(self) -> List[int]: """ stride_per_key = _maybe_compute_stride_per_key( self._stride_per_key, - self._stride_per_key_per_rank, + self._stride_per_key_per_rank_optional, self.stride(), self._keys, ) @@ -2194,7 +2205,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]: List[List[int]]: stride per key per rank of the KeyedJaggedTensor. """ stride_per_key_per_rank = self._stride_per_key_per_rank - return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + + if stride_per_key_per_rank is not None: + return stride_per_key_per_rank + + if self._stride_per_key_per_rank_tensor.numel() > 0: + return self._stride_per_key_per_rank_tensor.tolist() + + return [] + + @property + def _stride_per_key_per_rank_optional(self) -> Optional[List[List[int]]]: + if self._stride_per_key_per_rank is not None: + return self._stride_per_key_per_rank + + if self._stride_per_key_per_rank_tensor.numel() > 0: + stride_per_key_per_rank: List[List[int]] = ( + self._stride_per_key_per_rank_tensor.tolist() + ) + return stride_per_key_per_rank + + return None def variable_stride_per_key(self) -> bool: """ @@ -2205,7 +2236,7 @@ def variable_stride_per_key(self) -> bool: """ if self._variable_stride_per_key is not None: return self._variable_stride_per_key - return self._stride_per_key_per_rank is not None + return self._stride_per_key_per_rank_optional is not None def inverse_indices(self) -> Tuple[List[str], torch.Tensor]: """ @@ -2370,6 +2401,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: index_per_key=self._index_per_key, jt_dict=self._jt_dict, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) ) elif segment == 0: @@ -2406,6 +2438,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: index_per_key=None, jt_dict=None, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) ) else: @@ -2452,6 +2485,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: index_per_key=None, jt_dict=None, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) ) else: @@ -2488,6 +2522,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: index_per_key=None, jt_dict=None, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) ) start = end @@ -2594,12 +2629,15 @@ def permute( index_per_key=None, jt_dict=None, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) return kjt def flatten_lengths(self) -> "KeyedJaggedTensor": stride_per_key_per_rank = ( - self._stride_per_key_per_rank if self.variable_stride_per_key() else None + self._stride_per_key_per_rank_optional + if self.variable_stride_per_key() + else None ) return KeyedJaggedTensor( keys=self._keys, @@ -2616,6 +2654,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor": index_per_key=None, jt_dict=None, inverse_indices=None, + stride_per_key_per_rank_tensor=None, ) def __getitem__(self, key: str) -> JaggedTensor: @@ -2755,7 +2794,9 @@ def to( lengths = self._lengths offsets = self._offsets stride_per_key_per_rank = ( - self._stride_per_key_per_rank if self.variable_stride_per_key() else None + self._stride_per_key_per_rank_optional + if self.variable_stride_per_key() + else None ) length_per_key = self._length_per_key lengths_offset_per_key = self._lengths_offset_per_key @@ -2800,6 +2841,7 @@ def to( index_per_key=index_per_key, jt_dict=jt_dict, inverse_indices=inverse_indices, + stride_per_key_per_rank_tensor=None, ) def __str__(self) -> str: @@ -2831,7 +2873,9 @@ def pin_memory(self) -> "KeyedJaggedTensor": lengths = self._lengths offsets = self._offsets stride_per_key_per_rank = ( - self._stride_per_key_per_rank if self.variable_stride_per_key() else None + self._stride_per_key_per_rank_optional + if self.variable_stride_per_key() + else None ) inverse_indices = self._inverse_indices if inverse_indices is not None: @@ -2852,6 +2896,7 @@ def pin_memory(self) -> "KeyedJaggedTensor": index_per_key=self._index_per_key, jt_dict=None, inverse_indices=inverse_indices, + stride_per_key_per_rank_tensor=None, ) def dist_labels(self) -> List[str]: diff --git a/torchrec/sparse/tests/test_keyed_jagged_tensor.py b/torchrec/sparse/tests/test_keyed_jagged_tensor.py index 1636a06bd..581e27208 100644 --- a/torchrec/sparse/tests/test_keyed_jagged_tensor.py +++ b/torchrec/sparse/tests/test_keyed_jagged_tensor.py @@ -1017,6 +1017,34 @@ def test_meta_device_compatibility(self) -> None: lengths=torch.tensor([], device=torch.device("meta")), ) + def test_vbe_kjt_stride(self) -> None: + stride_per_key_per_rank = [[2], [1]] + inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]]) + kjt_1 = KeyedJaggedTensor( + keys=["f1", "f2", "f3"], + values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]), + lengths=torch.tensor([3, 3, 2]), + stride_per_key_per_rank=stride_per_key_per_rank, + inverse_indices=(["f1", "f2"], inverse_indices), + ) + kjt_2 = KeyedJaggedTensor( + keys=["f1", "f2", "f3"], + values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]), + lengths=torch.tensor([3, 3, 2]), + stride_per_key_per_rank_tensor=torch.tensor(stride_per_key_per_rank), + inverse_indices=(["f1", "f2"], inverse_indices), + ) + + self.assertEqual(kjt_1.stride(), inverse_indices.shape[1]) + self.assertEqual(kjt_1.stride_per_key_per_rank(), stride_per_key_per_rank) + self.assertEqual( + kjt_1._stride_per_key_per_rank_optional, stride_per_key_per_rank + ) + self.assertEqual(kjt_2.stride_per_key_per_rank(), stride_per_key_per_rank) + self.assertEqual( + kjt_2._stride_per_key_per_rank_optional, stride_per_key_per_rank + ) + class TestKeyedJaggedTensorScripting(unittest.TestCase): def test_scriptable_forward(self) -> None: