Skip to content

Commit 0384103

Browse files
TroyGarden authored and facebook-github-bot committed
convert stride_per_key_per_rank to tensor inside KJT (#2959)
Summary: Pull Request resolved: #2959 # context * this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai)) * previously the `stride_per_key_per_rank` variable was `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export) * this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so it would be compatible with PT2 IR. # equivalency * to check if `self._stride_per_key_per_rank` is `None`: this logic is used to differentiate the variable_batch case, and should have the same behavior after this diff * to use `self._stride_per_key_per_rank` as `List[List[int]]`: most of the callsites use the function to get the list: `def stride_per_key_per_rank(self) -> List[List[int]]:`, and this function is modified to convert the `torch.IntTensor` to a list as `_stride_per_key_per_rank.tolist()`; the results should be the same. NOTE: currently this `self._stride_per_key_per_rank` tensor is always on CPU since it's effectively the metadata of a KJT. However, ideally it should be on the GPU side since it's after input_dist, and we should avoid moving it to CPU unless really needed. Reviewed By: jd7-tr Differential Revision: D74366343
1 parent 998f96f commit 0384103

File tree

3 files changed

+86
-47
lines changed

3 files changed

+86
-47
lines changed

torchrec/pt2/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def kjt_for_pt2_tracing(
5454
values=values,
5555
lengths=lengths,
5656
weights=kjt.weights_or_none(),
57-
stride_per_key_per_rank=[[stride]] * n,
57+
stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"),
5858
inverse_indices=(kjt.keys(), inverse_indices_tensor),
5959
)
6060

@@ -85,7 +85,7 @@ def kjt_for_pt2_tracing(
8585
lengths=lengths,
8686
weights=weights,
8787
stride=stride if not is_vb else None,
88-
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
88+
stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None,
8989
inverse_indices=inverse_indices,
9090
)
9191

torchrec/schema/api_tests/test_jagged_tensor_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import unittest
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import Dict, List, Optional, Tuple, Union
1313

1414
import torch
1515
from torchrec.schema.utils import is_signature_compatible
@@ -112,7 +112,9 @@ def __init__(
112112
lengths: Optional[torch.Tensor] = None,
113113
offsets: Optional[torch.Tensor] = None,
114114
stride: Optional[int] = None,
115-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
115+
stride_per_key_per_rank: Optional[
116+
Union[List[List[int]], torch.IntTensor]
117+
] = None,
116118
# Below exposed to ensure torch.script-able
117119
stride_per_key: Optional[List[int]] = None,
118120
length_per_key: Optional[List[int]] = None,

torchrec/sparse/jagged_tensor.py

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,13 +1098,15 @@ def _maybe_compute_stride_kjt(
10981098
stride: Optional[int],
10991099
lengths: Optional[torch.Tensor],
11001100
offsets: Optional[torch.Tensor],
1101-
stride_per_key_per_rank: Optional[List[List[int]]],
1101+
stride_per_key_per_rank: Optional[torch.IntTensor],
11021102
) -> int:
11031103
if stride is None:
11041104
if len(keys) == 0:
11051105
stride = 0
1106-
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
1107-
stride = max([sum(s) for s in stride_per_key_per_rank])
1106+
elif (
1107+
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1108+
):
1109+
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
11081110
elif offsets is not None and offsets.numel() > 0:
11091111
stride = (offsets.numel() - 1) // len(keys)
11101112
elif lengths is not None:
@@ -1483,8 +1485,8 @@ def _strides_from_kjt(
14831485
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
14841486
# empty like function fx wrapped, also avoids device hardcoding
14851487
stride, stride_per_key_per_rank = (
1486-
(None, kjt.stride_per_key_per_rank())
1487-
if kjt.variable_stride_per_key()
1488+
(None, kjt._stride_per_key_per_rank)
1489+
if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
14881490
else (kjt.stride(), None)
14891491
)
14901492

@@ -1670,14 +1672,20 @@ def _maybe_compute_lengths_offset_per_key(
16701672

16711673
def _maybe_compute_stride_per_key(
16721674
stride_per_key: Optional[List[int]],
1673-
stride_per_key_per_rank: Optional[List[List[int]]],
1675+
stride_per_key_per_rank: Optional[torch.IntTensor],
16741676
stride: Optional[int],
16751677
keys: List[str],
16761678
) -> Optional[List[int]]:
16771679
if stride_per_key is not None:
16781680
return stride_per_key
16791681
elif stride_per_key_per_rank is not None:
1680-
return [sum(s) for s in stride_per_key_per_rank]
1682+
if stride_per_key_per_rank.dim() != 2:
1683+
# after permute the kjt could be empty
1684+
return []
1685+
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
1686+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1687+
pt2_checks_all_is_size(rt)
1688+
return rt
16811689
elif stride is not None:
16821690
return [stride] * len(keys)
16831691
else:
@@ -1768,7 +1776,9 @@ def __init__(
17681776
lengths: Optional[torch.Tensor] = None,
17691777
offsets: Optional[torch.Tensor] = None,
17701778
stride: Optional[int] = None,
1771-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
1779+
stride_per_key_per_rank: Optional[
1780+
Union[torch.IntTensor, List[List[int]]]
1781+
] = None,
17721782
# Below exposed to ensure torch.script-able
17731783
stride_per_key: Optional[List[int]] = None,
17741784
length_per_key: Optional[List[int]] = None,
@@ -1790,8 +1800,14 @@ def __init__(
17901800
self._lengths: Optional[torch.Tensor] = lengths
17911801
self._offsets: Optional[torch.Tensor] = offsets
17921802
self._stride: Optional[int] = stride
1793-
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1794-
stride_per_key_per_rank
1803+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1804+
# in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1805+
# does not take List[List[int]]
1806+
assert not isinstance(stride_per_key_per_rank, list)
1807+
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
1808+
torch.IntTensor(stride_per_key_per_rank, device="cpu")
1809+
if isinstance(stride_per_key_per_rank, list)
1810+
else stride_per_key_per_rank
17951811
)
17961812
self._stride_per_key: Optional[List[int]] = stride_per_key
17971813
self._length_per_key: Optional[List[int]] = length_per_key
@@ -1802,6 +1818,8 @@ def __init__(
18021818
self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
18031819
inverse_indices
18041820
)
1821+
# this is only needed for torch.compile case
1822+
self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None
18051823

18061824
# legacy attribute, for backward compatabilibity
18071825
self._variable_stride_per_key: Optional[bool] = None
@@ -1817,10 +1835,6 @@ def _init_pt2_checks(self) -> None:
18171835
return
18181836
if self._stride_per_key is not None:
18191837
pt2_checks_all_is_size(self._stride_per_key)
1820-
if self._stride_per_key_per_rank is not None:
1821-
# pyre-ignore [16]
1822-
for s in self._stride_per_key_per_rank:
1823-
pt2_checks_all_is_size(s)
18241838

18251839
@staticmethod
18261840
def from_offsets_sync(
@@ -2030,7 +2044,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
20302044
kjt_stride, kjt_stride_per_key_per_rank = (
20312045
(stride_per_key[0], None)
20322046
if all(s == stride_per_key[0] for s in stride_per_key)
2033-
else (None, [[stride] for stride in stride_per_key])
2047+
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
20342048
)
20352049
kjt = KeyedJaggedTensor(
20362050
keys=kjt_keys,
@@ -2195,12 +2209,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21952209
Returns:
21962210
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
21972211
"""
2198-
stride_per_key_per_rank = self._stride_per_key_per_rank
2199-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2212+
# making a local reference to the class variable to make jit.script behave
2213+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2214+
if (
2215+
not torch.jit.is_scripting()
2216+
and is_torchdynamo_compiling()
2217+
and _stride_per_key_per_rank is not None
2218+
):
2219+
if self._pt2_stride_per_key_per_rank is not None:
2220+
return self._pt2_stride_per_key_per_rank
2221+
stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
2222+
for stride_per_rank in stride_per_key_per_rank:
2223+
pt2_checks_all_is_size(stride_per_rank)
2224+
self._pt2_stride_per_key_per_rank = stride_per_key_per_rank
2225+
return stride_per_key_per_rank
2226+
return (
2227+
[]
2228+
if _stride_per_key_per_rank is None
2229+
else _stride_per_key_per_rank.tolist()
2230+
)
22002231

22012232
def variable_stride_per_key(self) -> bool:
22022233
"""
22032234
Returns whether the KeyedJaggedTensor has variable stride per key.
2235+
NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
2236+
is not `None`. It might be assigned to False externally/intentionally, usually the
2237+
`self._stride_per_key_per_rank` is trivial.
22042238
22052239
Returns:
22062240
bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2345,13 +2379,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23452379
start_offset = 0
23462380
_length_per_key = self.length_per_key()
23472381
_offset_per_key = self.offset_per_key()
2382+
# use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2383+
_stride_per_key_per_rank = self._stride_per_key_per_rank
23482384
for segment in segments:
23492385
end = start + segment
23502386
end_offset = _offset_per_key[end]
23512387
keys: List[str] = self._keys[start:end]
23522388
stride_per_key_per_rank = (
2353-
self.stride_per_key_per_rank()[start:end]
2389+
_stride_per_key_per_rank[start:end, :]
23542390
if self.variable_stride_per_key()
2391+
and _stride_per_key_per_rank is not None
23552392
else None
23562393
)
23572394
if segment == len(self._keys):
@@ -2499,17 +2536,24 @@ def permute(
24992536

25002537
length_per_key = self.length_per_key()
25012538
permuted_keys: List[str] = []
2502-
permuted_stride_per_key_per_rank: List[List[int]] = []
25032539
permuted_length_per_key: List[int] = []
25042540
permuted_length_per_key_sum = 0
25052541
for index in indices:
25062542
key = self.keys()[index]
25072543
permuted_keys.append(key)
25082544
permuted_length_per_key.append(length_per_key[index])
2509-
if self.variable_stride_per_key():
2510-
permuted_stride_per_key_per_rank.append(
2511-
self.stride_per_key_per_rank()[index]
2512-
)
2545+
2546+
stride_per_key = self._stride_per_key
2547+
permuted_stride_per_key = (
2548+
[stride_per_key[i] for i in indices] if stride_per_key is not None else None
2549+
)
2550+
2551+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2552+
permuted_stride_per_key_per_rank = (
2553+
_stride_per_key_per_rank[indices, :]
2554+
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None
2555+
else None
2556+
)
25132557

25142558
permuted_length_per_key_sum = sum(permuted_length_per_key)
25152559
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2561,18 +2605,16 @@ def permute(
25612605
self.weights_or_none(),
25622606
permuted_length_per_key_sum,
25632607
)
2564-
stride_per_key_per_rank = (
2565-
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2566-
)
2608+
25672609
kjt = KeyedJaggedTensor(
25682610
keys=permuted_keys,
25692611
values=permuted_values,
25702612
weights=permuted_weights,
25712613
lengths=permuted_lengths.view(-1),
25722614
offsets=None,
25732615
stride=self._stride,
2574-
stride_per_key_per_rank=stride_per_key_per_rank,
2575-
stride_per_key=None,
2616+
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
2617+
stride_per_key=permuted_stride_per_key,
25762618
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
25772619
lengths_offset_per_key=None,
25782620
offset_per_key=None,
@@ -2891,7 +2933,7 @@ def dist_init(
28912933

28922934
if variable_stride_per_key:
28932935
assert stride_per_rank_per_key is not None
2894-
stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
2936+
stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
28952937
num_workers, len(keys)
28962938
).T.cpu()
28972939

@@ -2928,23 +2970,18 @@ def dist_init(
29282970
weights,
29292971
)
29302972

2931-
stride_per_key_per_rank = torch.jit.annotate(
2932-
List[List[int]], stride_per_key_per_rank_tensor.tolist()
2933-
)
2973+
if stride_per_key_per_rank.numel() == 0:
2974+
stride_per_key_per_rank = torch.zeros(
2975+
(len(keys), 1), device="cpu", dtype=torch.int64
2976+
)
29342977

2935-
if not stride_per_key_per_rank:
2936-
stride_per_key_per_rank = [[0]] * len(keys)
29372978
if stagger > 1:
2938-
stride_per_key_per_rank_stagger: List[List[int]] = []
29392979
local_world_size = num_workers // stagger
2940-
for i in range(len(keys)):
2941-
stride_per_rank_stagger: List[int] = []
2942-
for j in range(local_world_size):
2943-
stride_per_rank_stagger.extend(
2944-
stride_per_key_per_rank[i][j::local_world_size]
2945-
)
2946-
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
2947-
stride_per_key_per_rank = stride_per_key_per_rank_stagger
2980+
indices = [
2981+
list(range(i, num_workers, local_world_size))
2982+
for i in range(local_world_size)
2983+
]
2984+
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
29482985

29492986
kjt = KeyedJaggedTensor(
29502987
keys=keys,

0 commit comments

Comments
 (0)