Skip to content

Create a new tensor arg for stride_per_key_per_rank to facilitate torch.export #2950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,11 +1097,15 @@ def _maybe_compute_stride_kjt(
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> int:
if stride is None:
if len(keys) == 0:
stride = 0
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
# For VBE KJT, use inverse_indices for the batch size of the EBC output KeyedTensor.
if inverse_indices is not None and inverse_indices[1].numel() > 0:
return inverse_indices[1].shape[-1]
stride = max([sum(s) for s in stride_per_key_per_rank])
elif offsets is not None and offsets.numel() > 0:
stride = (offsets.numel() - 1) // len(keys)
Expand Down Expand Up @@ -1775,6 +1779,7 @@ def __init__(
index_per_key: Optional[Dict[str, int]] = None,
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
stride_per_key_per_rank_tensor: Optional[torch.Tensor] = None,
) -> None:
"""
This is the constructor for KeyedJaggedTensor; it is jit.scriptable and PT2 compatible.
Expand All @@ -1791,6 +1796,11 @@ def __init__(
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
stride_per_key_per_rank
)

self._stride_per_key_per_rank_tensor: torch.Tensor = torch.empty(0)
if stride_per_key_per_rank_tensor is not None:
self._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor

self._stride_per_key: Optional[List[int]] = stride_per_key
self._length_per_key: Optional[List[int]] = length_per_key
self._offset_per_key: Optional[List[int]] = offset_per_key
Expand Down Expand Up @@ -2165,6 +2175,7 @@ def stride(self) -> int:
self._lengths,
self._offsets,
self._stride_per_key_per_rank,
self._inverse_indices,
)
self._stride = stride
return stride
Expand All @@ -2179,7 +2190,7 @@ def stride_per_key(self) -> List[int]:
"""
stride_per_key = _maybe_compute_stride_per_key(
self._stride_per_key,
self._stride_per_key_per_rank,
self._stride_per_key_per_rank_optional,
self.stride(),
self._keys,
)
Expand All @@ -2194,7 +2205,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
"""
stride_per_key_per_rank = self._stride_per_key_per_rank
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []

if stride_per_key_per_rank is not None:
return stride_per_key_per_rank

if self._stride_per_key_per_rank_tensor.numel() > 0:
return self._stride_per_key_per_rank_tensor.tolist()

return []

@property
def _stride_per_key_per_rank_optional(self) -> Optional[List[List[int]]]:
    """
    Stride per key per rank as nested lists, or None when it is not available.

    The list held in ``self._stride_per_key_per_rank`` takes precedence whenever
    it is not None. Otherwise a non-empty ``self._stride_per_key_per_rank_tensor``
    is converted with ``tolist()``.

    Returns:
        Optional[List[List[int]]]: the stored list, the tensor converted to
            nested lists, or None when the list is None and the tensor is empty.
    """
    from_list = self._stride_per_key_per_rank
    if from_list is not None:
        return from_list

    from_tensor = self._stride_per_key_per_rank_tensor
    if from_tensor.numel() > 0:
        # Annotated local keeps the nested-list type of the tolist() result explicit.
        converted: List[List[int]] = from_tensor.tolist()
        return converted

    return None

def variable_stride_per_key(self) -> bool:
"""
Expand All @@ -2205,7 +2236,7 @@ def variable_stride_per_key(self) -> bool:
"""
if self._variable_stride_per_key is not None:
return self._variable_stride_per_key
return self._stride_per_key_per_rank is not None
return self._stride_per_key_per_rank_optional is not None

def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
"""
Expand Down Expand Up @@ -2370,6 +2401,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
index_per_key=self._index_per_key,
jt_dict=self._jt_dict,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)
)
elif segment == 0:
Expand Down Expand Up @@ -2406,6 +2438,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
index_per_key=None,
jt_dict=None,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)
)
else:
Expand Down Expand Up @@ -2452,6 +2485,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
index_per_key=None,
jt_dict=None,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)
)
else:
Expand Down Expand Up @@ -2488,6 +2522,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
index_per_key=None,
jt_dict=None,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)
)
start = end
Expand Down Expand Up @@ -2594,12 +2629,15 @@ def permute(
index_per_key=None,
jt_dict=None,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)
return kjt

def flatten_lengths(self) -> "KeyedJaggedTensor":
stride_per_key_per_rank = (
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
self._stride_per_key_per_rank_optional
if self.variable_stride_per_key()
else None
)
return KeyedJaggedTensor(
keys=self._keys,
Expand All @@ -2616,6 +2654,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
index_per_key=None,
jt_dict=None,
inverse_indices=None,
stride_per_key_per_rank_tensor=None,
)

def __getitem__(self, key: str) -> JaggedTensor:
Expand Down Expand Up @@ -2755,7 +2794,9 @@ def to(
lengths = self._lengths
offsets = self._offsets
stride_per_key_per_rank = (
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
self._stride_per_key_per_rank_optional
if self.variable_stride_per_key()
else None
)
length_per_key = self._length_per_key
lengths_offset_per_key = self._lengths_offset_per_key
Expand Down Expand Up @@ -2800,6 +2841,7 @@ def to(
index_per_key=index_per_key,
jt_dict=jt_dict,
inverse_indices=inverse_indices,
stride_per_key_per_rank_tensor=None,
)

def __str__(self) -> str:
Expand Down Expand Up @@ -2831,7 +2873,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
lengths = self._lengths
offsets = self._offsets
stride_per_key_per_rank = (
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
self._stride_per_key_per_rank_optional
if self.variable_stride_per_key()
else None
)
inverse_indices = self._inverse_indices
if inverse_indices is not None:
Expand All @@ -2852,6 +2896,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
index_per_key=self._index_per_key,
jt_dict=None,
inverse_indices=inverse_indices,
stride_per_key_per_rank_tensor=None,
)

def dist_labels(self) -> List[str]:
Expand Down
28 changes: 28 additions & 0 deletions torchrec/sparse/tests/test_keyed_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,34 @@ def test_meta_device_compatibility(self) -> None:
lengths=torch.tensor([], device=torch.device("meta")),
)

def test_vbe_kjt_stride(self) -> None:
    """
    Check the stride accessors of a KJT built with stride-per-key-per-rank data.

    Two KJTs share the same keys, values, lengths and inverse indices; one
    receives the strides as a nested list, the other as a tensor. Both must
    report the same nested-list strides, and the list-based one must report
    a stride equal to the second dimension of the inverse-indices tensor.
    """
    expected_strides = [[2], [1]]
    inverse = torch.tensor([[0, 1, 0], [0, 0, 0]])

    # Strides supplied as a Python list.
    list_backed = KeyedJaggedTensor(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
        lengths=torch.tensor([3, 3, 2]),
        stride_per_key_per_rank=expected_strides,
        inverse_indices=(["f1", "f2"], inverse),
    )
    # Strides supplied only through the tensor argument.
    tensor_backed = KeyedJaggedTensor(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
        lengths=torch.tensor([3, 3, 2]),
        stride_per_key_per_rank_tensor=torch.tensor(expected_strides),
        inverse_indices=(["f1", "f2"], inverse),
    )

    self.assertEqual(list_backed.stride(), inverse.shape[1])

    # Both construction paths must expose identical stride information.
    for kjt in (list_backed, tensor_backed):
        self.assertEqual(kjt.stride_per_key_per_rank(), expected_strides)
        self.assertEqual(kjt._stride_per_key_per_rank_optional, expected_strides)


class TestKeyedJaggedTensorScripting(unittest.TestCase):
def test_scriptable_forward(self) -> None:
Expand Down
Loading