Skip to content

Commit cdf8953

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Create a new tensor arg for stride_per_key_per_rank to facilitate torch.export
Summary: # Context * Currently torchrec IR serializer can't handle variable batch KJT use case. * To support VBE KJT, the `stride_per_key_per_rank` field needs to be flattened as a variable in the pytree flatten spec for a VBE KJT to be unflattened correctly by`torch.export. * Currently `stride_per_key_per_rank` is a List. To flatten the `stride_per_key_per_rank` info as a variable we have to add a new tensor field for it. # Ref Differential Revision: D74207283
1 parent 988496e commit cdf8953

File tree

2 files changed

+65
-9
lines changed

2 files changed

+65
-9
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,6 +1778,7 @@ def __init__(
17781778
index_per_key: Optional[Dict[str, int]] = None,
17791779
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
17801780
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
1781+
stride_per_key_per_rank_tensor: Optional[torch.Tensor] = None,
17811782
) -> None:
17821783
"""
17831784
This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
@@ -1794,6 +1795,11 @@ def __init__(
17941795
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
17951796
stride_per_key_per_rank
17961797
)
1798+
1799+
self._stride_per_key_per_rank_tensor: torch.Tensor = torch.empty(0)
1800+
if stride_per_key_per_rank_tensor is not None:
1801+
self._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor
1802+
17971803
self._stride_per_key: Optional[List[int]] = stride_per_key
17981804
self._length_per_key: Optional[List[int]] = length_per_key
17991805
self._offset_per_key: Optional[List[int]] = offset_per_key
@@ -2183,7 +2189,7 @@ def stride_per_key(self) -> List[int]:
21832189
"""
21842190
stride_per_key = _maybe_compute_stride_per_key(
21852191
self._stride_per_key,
2186-
self._stride_per_key_per_rank,
2192+
self._stride_per_key_per_rank_optional,
21872193
self.stride(),
21882194
self._keys,
21892195
)
@@ -2198,7 +2204,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21982204
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
21992205
"""
22002206
stride_per_key_per_rank = self._stride_per_key_per_rank
2201-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2207+
2208+
if stride_per_key_per_rank is not None:
2209+
return stride_per_key_per_rank
2210+
2211+
if self._stride_per_key_per_rank_tensor.numel() > 0:
2212+
return self._stride_per_key_per_rank_tensor.tolist()
2213+
2214+
return []
2215+
2216+
@property
2217+
def _stride_per_key_per_rank_optional(self) -> Optional[List[List[int]]]:
2218+
if self._stride_per_key_per_rank is not None:
2219+
return self._stride_per_key_per_rank
2220+
2221+
if self._stride_per_key_per_rank_tensor.numel() > 0:
2222+
stride_per_key_per_rank: List[List[int]] = (
2223+
self._stride_per_key_per_rank_tensor.tolist()
2224+
)
2225+
return stride_per_key_per_rank
2226+
2227+
return None
22022228

22032229
def variable_stride_per_key(self) -> bool:
22042230
"""
@@ -2209,7 +2235,7 @@ def variable_stride_per_key(self) -> bool:
22092235
"""
22102236
if self._variable_stride_per_key is not None:
22112237
return self._variable_stride_per_key
2212-
return self._stride_per_key_per_rank is not None
2238+
return self._stride_per_key_per_rank_optional is not None
22132239

22142240
def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
22152241
"""
@@ -2374,6 +2400,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23742400
index_per_key=self._index_per_key,
23752401
jt_dict=self._jt_dict,
23762402
inverse_indices=None,
2403+
stride_per_key_per_rank_tensor=None,
23772404
)
23782405
)
23792406
elif segment == 0:
@@ -2410,6 +2437,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24102437
index_per_key=None,
24112438
jt_dict=None,
24122439
inverse_indices=None,
2440+
stride_per_key_per_rank_tensor=None,
24132441
)
24142442
)
24152443
else:
@@ -2456,6 +2484,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24562484
index_per_key=None,
24572485
jt_dict=None,
24582486
inverse_indices=None,
2487+
stride_per_key_per_rank_tensor=None,
24592488
)
24602489
)
24612490
else:
@@ -2492,6 +2521,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24922521
index_per_key=None,
24932522
jt_dict=None,
24942523
inverse_indices=None,
2524+
stride_per_key_per_rank_tensor=None,
24952525
)
24962526
)
24972527
start = end
@@ -2598,12 +2628,15 @@ def permute(
25982628
index_per_key=None,
25992629
jt_dict=None,
26002630
inverse_indices=None,
2631+
stride_per_key_per_rank_tensor=None,
26012632
)
26022633
return kjt
26032634

26042635
def flatten_lengths(self) -> "KeyedJaggedTensor":
26052636
stride_per_key_per_rank = (
2606-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2637+
self._stride_per_key_per_rank_optional
2638+
if self.variable_stride_per_key()
2639+
else None
26072640
)
26082641
return KeyedJaggedTensor(
26092642
keys=self._keys,
@@ -2620,6 +2653,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
26202653
index_per_key=None,
26212654
jt_dict=None,
26222655
inverse_indices=None,
2656+
stride_per_key_per_rank_tensor=None,
26232657
)
26242658

26252659
def __getitem__(self, key: str) -> JaggedTensor:
@@ -2759,7 +2793,9 @@ def to(
27592793
lengths = self._lengths
27602794
offsets = self._offsets
27612795
stride_per_key_per_rank = (
2762-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2796+
self._stride_per_key_per_rank_optional
2797+
if self.variable_stride_per_key()
2798+
else None
27632799
)
27642800
length_per_key = self._length_per_key
27652801
lengths_offset_per_key = self._lengths_offset_per_key
@@ -2804,6 +2840,7 @@ def to(
28042840
index_per_key=index_per_key,
28052841
jt_dict=jt_dict,
28062842
inverse_indices=inverse_indices,
2843+
stride_per_key_per_rank_tensor=None,
28072844
)
28082845

28092846
def __str__(self) -> str:
@@ -2835,7 +2872,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
28352872
lengths = self._lengths
28362873
offsets = self._offsets
28372874
stride_per_key_per_rank = (
2838-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2875+
self._stride_per_key_per_rank_optional
2876+
if self.variable_stride_per_key()
2877+
else None
28392878
)
28402879
inverse_indices = self._inverse_indices
28412880
if inverse_indices is not None:
@@ -2856,6 +2895,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
28562895
index_per_key=self._index_per_key,
28572896
jt_dict=None,
28582897
inverse_indices=inverse_indices,
2898+
stride_per_key_per_rank_tensor=None,
28592899
)
28602900

28612901
def dist_labels(self) -> List[str]:

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,32 @@ def test_meta_device_compatibility(self) -> None:
10181018
)
10191019

10201020
def test_vbe_kjt_stride(self) -> None:
1021+
stride_per_key_per_rank = [[2], [1]]
10211022
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
1022-
kjt = KeyedJaggedTensor(
1023+
kjt_1 = KeyedJaggedTensor(
10231024
keys=["f1", "f2", "f3"],
10241025
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
10251026
lengths=torch.tensor([3, 3, 2]),
1026-
stride_per_key_per_rank=[[2], [1]],
1027+
stride_per_key_per_rank=stride_per_key_per_rank,
1028+
inverse_indices=(["f1", "f2"], inverse_indices),
1029+
)
1030+
kjt_2 = KeyedJaggedTensor(
1031+
keys=["f1", "f2", "f3"],
1032+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1033+
lengths=torch.tensor([3, 3, 2]),
1034+
stride_per_key_per_rank_tensor=torch.tensor(stride_per_key_per_rank),
10271035
inverse_indices=(["f1", "f2"], inverse_indices),
10281036
)
10291037

1030-
self.assertEqual(kjt.stride(), inverse_indices.shape[1])
1038+
self.assertEqual(kjt_1.stride(), inverse_indices.shape[1])
1039+
self.assertEqual(kjt_1.stride_per_key_per_rank(), stride_per_key_per_rank)
1040+
self.assertEqual(
1041+
kjt_1._stride_per_key_per_rank_optional, stride_per_key_per_rank
1042+
)
1043+
self.assertEqual(kjt_2.stride_per_key_per_rank(), stride_per_key_per_rank)
1044+
self.assertEqual(
1045+
kjt_2._stride_per_key_per_rank_optional, stride_per_key_per_rank
1046+
)
10311047

10321048

10331049
class TestKeyedJaggedTensorScripting(unittest.TestCase):

0 commit comments

Comments
 (0)