diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index 55accff68..44af5ae1f 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -54,7 +54,7 @@ def kjt_for_pt2_tracing( values=values, lengths=lengths, weights=kjt.weights_or_none(), - stride_per_key_per_rank=[[stride]] * n, + stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"), inverse_indices=(kjt.keys(), inverse_indices_tensor), ) @@ -85,7 +85,7 @@ def kjt_for_pt2_tracing( lengths=lengths, weights=weights, stride=stride if not is_vb else None, - stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, + stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None, inverse_indices=inverse_indices, ) diff --git a/torchrec/schema/api_tests/test_jagged_tensor_schema.py b/torchrec/schema/api_tests/test_jagged_tensor_schema.py index eacb10d9e..d51368b12 100644 --- a/torchrec/schema/api_tests/test_jagged_tensor_schema.py +++ b/torchrec/schema/api_tests/test_jagged_tensor_schema.py @@ -9,7 +9,7 @@ import inspect import unittest -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torchrec.schema.utils import is_signature_compatible @@ -112,7 +112,9 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Optional[ + Union[List[List[int]], torch.IntTensor] + ] = None, # Below exposed to ensure torch.script-able stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index c4d48ec5b..ce98da1ce 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1098,13 +1098,15 @@ def _maybe_compute_stride_kjt( stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], - 
stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Optional[torch.IntTensor], ) -> int: if stride is None: if len(keys) == 0: stride = 0 - elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: - stride = max([sum(s) for s in stride_per_key_per_rank]) + elif ( + stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0 + ): + stride = int(stride_per_key_per_rank.sum(dim=1).max().item()) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: @@ -1483,8 +1485,8 @@ def _strides_from_kjt( def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": # empty like function fx wrapped, also avoids device hardcoding stride, stride_per_key_per_rank = ( - (None, kjt.stride_per_key_per_rank()) - if kjt.variable_stride_per_key() + (None, kjt._stride_per_key_per_rank) + if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key() else (kjt.stride(), None) ) @@ -1670,14 +1672,20 @@ def _maybe_compute_lengths_offset_per_key( def _maybe_compute_stride_per_key( stride_per_key: Optional[List[int]], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Optional[torch.IntTensor], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: if stride_per_key is not None: return stride_per_key elif stride_per_key_per_rank is not None: - return [sum(s) for s in stride_per_key_per_rank] + if stride_per_key_per_rank.dim() != 2: + # after permute the kjt could be empty + return [] + rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist() + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + pt2_checks_all_is_size(rt) + return rt elif stride is not None: return [stride] * len(keys) else: @@ -1768,7 +1776,9 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = 
None, + stride_per_key_per_rank: Optional[ + Union[torch.IntTensor, List[List[int]]] + ] = None, # Below exposed to ensure torch.script-able stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, @@ -1790,8 +1800,14 @@ def __init__( self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets self._stride: Optional[int] = stride - self._stride_per_key_per_rank: Optional[List[List[int]]] = ( - stride_per_key_per_rank + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None + # does not take List[List[int]] + assert not isinstance(stride_per_key_per_rank, list) + self._stride_per_key_per_rank: Optional[torch.IntTensor] = ( + torch.IntTensor(stride_per_key_per_rank, device="cpu") + if isinstance(stride_per_key_per_rank, list) + else stride_per_key_per_rank ) self._stride_per_key: Optional[List[int]] = stride_per_key self._length_per_key: Optional[List[int]] = length_per_key @@ -1802,6 +1818,8 @@ def __init__( self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = ( inverse_indices ) + # this is only needed for torch.compile case + self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None # legacy attribute, for backward compatabilibity self._variable_stride_per_key: Optional[bool] = None @@ -1817,10 +1835,6 @@ def _init_pt2_checks(self) -> None: return if self._stride_per_key is not None: pt2_checks_all_is_size(self._stride_per_key) - if self._stride_per_key_per_rank is not None: - # pyre-ignore [16] - for s in self._stride_per_key_per_rank: - pt2_checks_all_is_size(s) @staticmethod def from_offsets_sync( @@ -2030,7 +2044,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": kjt_stride, kjt_stride_per_key_per_rank = ( (stride_per_key[0], None) if all(s == stride_per_key[0] for s in stride_per_key) - else (None, [[stride] for stride in stride_per_key]) + else (None, 
torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1)) ) kjt = KeyedJaggedTensor( keys=kjt_keys, @@ -2195,12 +2209,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]: Returns: List[List[int]]: stride per key per rank of the KeyedJaggedTensor. """ - stride_per_key_per_rank = self._stride_per_key_per_rank - return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + # making a local reference to the class variable to make jit.script behave + _stride_per_key_per_rank = self._stride_per_key_per_rank + if ( + not torch.jit.is_scripting() + and is_torchdynamo_compiling() + and _stride_per_key_per_rank is not None + ): + if self._pt2_stride_per_key_per_rank is not None: + return self._pt2_stride_per_key_per_rank + stride_per_key_per_rank = _stride_per_key_per_rank.tolist() + for stride_per_rank in stride_per_key_per_rank: + pt2_checks_all_is_size(stride_per_rank) + self._pt2_stride_per_key_per_rank = stride_per_key_per_rank + return stride_per_key_per_rank + return ( + [] + if _stride_per_key_per_rank is None + else _stride_per_key_per_rank.tolist() + ) def variable_stride_per_key(self) -> bool: """ Returns whether the KeyedJaggedTensor has variable stride per key. + NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank` + is not `None`. It might be assigned to False externally/intentionally, usually the + `self._stride_per_key_per_rank` is trivial. Returns: bool: whether the KeyedJaggedTensor has variable stride per key. 
@@ -2345,13 +2379,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: start_offset = 0 _length_per_key = self.length_per_key() _offset_per_key = self.offset_per_key() + # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script + _stride_per_key_per_rank = self._stride_per_key_per_rank for segment in segments: end = start + segment end_offset = _offset_per_key[end] keys: List[str] = self._keys[start:end] stride_per_key_per_rank = ( - self.stride_per_key_per_rank()[start:end] + _stride_per_key_per_rank[start:end, :] if self.variable_stride_per_key() + and _stride_per_key_per_rank is not None else None ) if segment == len(self._keys): @@ -2499,17 +2536,24 @@ def permute( length_per_key = self.length_per_key() permuted_keys: List[str] = [] - permuted_stride_per_key_per_rank: List[List[int]] = [] permuted_length_per_key: List[int] = [] permuted_length_per_key_sum = 0 for index in indices: key = self.keys()[index] permuted_keys.append(key) permuted_length_per_key.append(length_per_key[index]) - if self.variable_stride_per_key(): - permuted_stride_per_key_per_rank.append( - self.stride_per_key_per_rank()[index] - ) + + stride_per_key = self._stride_per_key + permuted_stride_per_key = ( + [stride_per_key[i] for i in indices] if stride_per_key is not None else None + ) + + _stride_per_key_per_rank = self._stride_per_key_per_rank + permuted_stride_per_key_per_rank = ( + _stride_per_key_per_rank[indices, :] + if self.variable_stride_per_key() and _stride_per_key_per_rank is not None + else None + ) permuted_length_per_key_sum = sum(permuted_length_per_key) if not torch.jit.is_scripting() and is_non_strict_exporting(): @@ -2561,9 +2605,7 @@ def permute( self.weights_or_none(), permuted_length_per_key_sum, ) - stride_per_key_per_rank = ( - permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None - ) + kjt = KeyedJaggedTensor( keys=permuted_keys, values=permuted_values, @@ -2571,8 +2613,8 @@ def permute( 
lengths=permuted_lengths.view(-1), offsets=None, stride=self._stride, - stride_per_key_per_rank=stride_per_key_per_rank, - stride_per_key=None, + stride_per_key_per_rank=permuted_stride_per_key_per_rank, + stride_per_key=permuted_stride_per_key, length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None, lengths_offset_per_key=None, offset_per_key=None, @@ -2891,7 +2933,7 @@ def dist_init( if variable_stride_per_key: assert stride_per_rank_per_key is not None - stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view( + stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view( num_workers, len(keys) ).T.cpu() @@ -2928,23 +2970,18 @@ def dist_init( weights, ) - stride_per_key_per_rank = torch.jit.annotate( - List[List[int]], stride_per_key_per_rank_tensor.tolist() - ) + if stride_per_key_per_rank.numel() == 0: + stride_per_key_per_rank = torch.zeros( + (len(keys), 1), device="cpu", dtype=torch.int64 + ) - if not stride_per_key_per_rank: - stride_per_key_per_rank = [[0]] * len(keys) if stagger > 1: - stride_per_key_per_rank_stagger: List[List[int]] = [] local_world_size = num_workers // stagger - for i in range(len(keys)): - stride_per_rank_stagger: List[int] = [] - for j in range(local_world_size): - stride_per_rank_stagger.extend( - stride_per_key_per_rank[i][j::local_world_size] - ) - stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) - stride_per_key_per_rank = stride_per_key_per_rank_stagger + indices = [ + list(range(i, num_workers, local_world_size)) + for i in range(local_world_size) + ] + stride_per_key_per_rank = stride_per_key_per_rank[:, indices].reshape(len(keys), -1) kjt = KeyedJaggedTensor( keys=keys,