Skip to content

Commit 8a4378f

Browse files
TroyGarden authored and facebook-github-bot committed
fix stride_per_key_per_rank in stagger scenario in D74366343 (#3112)
Summary: Pull Request resolved: #3112 Pull Request resolved: #3111 # context * original diff D74366343 broke cogwheel test and was reverted * the error stack P1844048578 is shown below: ``` File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__ data = request.wait() File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl kjts.append(w.wait()) File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl return type(self._input).dist_init( File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init return kjt.sync() File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync self.length_per_key() File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key _length_per_key = _maybe_compute_length_per_key( File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key _length_per_key_from_stride_per_key(lengths, stride_per_key) File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key if _use_segment_sum_csr(stride_per_key): File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr elements_per_segment = sum(stride_per_key) / len(stride_per_key) ZeroDivisionError: division by zero ``` * the complaint is `stride_per_key` is an empty list, which comes from the following function call: ``` stride_per_key = _maybe_compute_stride_per_key( self._stride_per_key, self._stride_per_key_per_rank, self.stride(), self._keys, ) ``` * the only place this `stride_per_key` could be empty is when the 
`stride_per_key_per_rank.dim() != 2` ``` def _maybe_compute_stride_per_key( stride_per_key: Optional[List[int]], stride_per_key_per_rank: Optional[torch.IntTensor], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: if stride_per_key is not None: return stride_per_key elif stride_per_key_per_rank is not None: if stride_per_key_per_rank.dim() != 2: # after permute the kjt could be empty return [] rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist() if not torch.jit.is_scripting() and is_torchdynamo_compiling(): pt2_checks_all_is_size(rt) return rt elif stride is not None: return [stride] * len(keys) else: return None ``` # the main change from D74366343 is that the `stride_per_key_per_rank` in `dist_init`: * baseline ``` if stagger > 1: stride_per_key_per_rank_stagger: List[List[int]] = [] local_world_size = num_workers // stagger for i in range(len(keys)): stride_per_rank_stagger: List[int] = [] for j in range(local_world_size): stride_per_rank_stagger.extend( stride_per_key_per_rank[i][j::local_world_size] ) stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) stride_per_key_per_rank = stride_per_key_per_rank_stagger ``` * D76875546 (correct, this diff) ``` if stagger > 1: indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1) stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` * D74366343 (incorrect, reverted) ``` if stagger > 1: local_world_size = num_workers // stagger indices = [ list(range(i, num_workers, local_world_size)) for i in range(local_world_size) ] stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` Reviewed By: jd7-tr Differential Revision: D76903646 fbshipit-source-id: a0dda49dcd1f315012803498e304e4ca7a3d80b8
1 parent 12eb3bf commit 8a4378f

File tree

3 files changed

+83
-48
lines changed

3 files changed

+83
-48
lines changed

torchrec/pt2/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -54,7 +54,7 @@ def kjt_for_pt2_tracing(
5454
values=values,
5555
lengths=lengths,
5656
weights=kjt.weights_or_none(),
57-
stride_per_key_per_rank=[[stride]] * n,
57+
stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"),
5858
inverse_indices=(kjt.keys(), inverse_indices_tensor),
5959
)
6060

@@ -85,7 +85,7 @@ def kjt_for_pt2_tracing(
8585
lengths=lengths,
8686
weights=weights,
8787
stride=stride if not is_vb else None,
88-
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
88+
stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None,
8989
inverse_indices=inverse_indices,
9090
)
9191

torchrec/schema/api_tests/test_jagged_tensor_schema.py

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import unittest
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import Dict, List, Optional, Tuple, Union
1313

1414
import torch
1515
from torchrec.schema.utils import is_signature_compatible
@@ -112,7 +112,9 @@ def __init__(
112112
lengths: Optional[torch.Tensor] = None,
113113
offsets: Optional[torch.Tensor] = None,
114114
stride: Optional[int] = None,
115-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
115+
stride_per_key_per_rank: Optional[
116+
Union[List[List[int]], torch.IntTensor]
117+
] = None,
116118
# Below exposed to ensure torch.script-able
117119
stride_per_key: Optional[List[int]] = None,
118120
length_per_key: Optional[List[int]] = None,

torchrec/sparse/jagged_tensor.py

Lines changed: 77 additions & 44 deletions
Original file line number · Diff line number · Diff line change
@@ -1094,13 +1094,15 @@ def _maybe_compute_stride_kjt(
10941094
stride: Optional[int],
10951095
lengths: Optional[torch.Tensor],
10961096
offsets: Optional[torch.Tensor],
1097-
stride_per_key_per_rank: Optional[List[List[int]]],
1097+
stride_per_key_per_rank: Optional[torch.IntTensor],
10981098
) -> int:
10991099
if stride is None:
11001100
if len(keys) == 0:
11011101
stride = 0
1102-
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
1103-
stride = max([sum(s) for s in stride_per_key_per_rank])
1102+
elif (
1103+
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1104+
):
1105+
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
11041106
elif offsets is not None and offsets.numel() > 0:
11051107
stride = (offsets.numel() - 1) // len(keys)
11061108
elif lengths is not None:
@@ -1452,8 +1454,8 @@ def _maybe_compute_kjt_to_jt_dict(
14521454
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
14531455
# empty like function fx wrapped, also avoids device hardcoding
14541456
stride, stride_per_key_per_rank = (
1455-
(None, kjt.stride_per_key_per_rank())
1456-
if kjt.variable_stride_per_key()
1457+
(None, kjt._stride_per_key_per_rank)
1458+
if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
14571459
else (kjt.stride(), None)
14581460
)
14591461

@@ -1639,14 +1641,20 @@ def _maybe_compute_lengths_offset_per_key(
16391641

16401642
def _maybe_compute_stride_per_key(
16411643
stride_per_key: Optional[List[int]],
1642-
stride_per_key_per_rank: Optional[List[List[int]]],
1644+
stride_per_key_per_rank: Optional[torch.IntTensor],
16431645
stride: Optional[int],
16441646
keys: List[str],
16451647
) -> Optional[List[int]]:
16461648
if stride_per_key is not None:
16471649
return stride_per_key
16481650
elif stride_per_key_per_rank is not None:
1649-
return [sum(s) for s in stride_per_key_per_rank]
1651+
if stride_per_key_per_rank.dim() != 2:
1652+
# after permute the kjt could be empty
1653+
return []
1654+
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
1655+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1656+
pt2_checks_all_is_size(rt)
1657+
return rt
16501658
elif stride is not None:
16511659
return [stride] * len(keys)
16521660
else:
@@ -1725,7 +1733,9 @@ def __init__(
17251733
lengths: Optional[torch.Tensor] = None,
17261734
offsets: Optional[torch.Tensor] = None,
17271735
stride: Optional[int] = None,
1728-
stride_per_key_per_rank: Optional[List[List[int]]] = None,
1736+
stride_per_key_per_rank: Optional[
1737+
Union[torch.IntTensor, List[List[int]]]
1738+
] = None,
17291739
# Below exposed to ensure torch.script-able
17301740
stride_per_key: Optional[List[int]] = None,
17311741
length_per_key: Optional[List[int]] = None,
@@ -1747,8 +1757,14 @@ def __init__(
17471757
self._lengths: Optional[torch.Tensor] = lengths
17481758
self._offsets: Optional[torch.Tensor] = offsets
17491759
self._stride: Optional[int] = stride
1750-
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1751-
stride_per_key_per_rank
1760+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1761+
# in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1762+
# does not take List[List[int]]
1763+
assert not isinstance(stride_per_key_per_rank, list)
1764+
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
1765+
torch.IntTensor(stride_per_key_per_rank, device="cpu")
1766+
if isinstance(stride_per_key_per_rank, list)
1767+
else stride_per_key_per_rank
17521768
)
17531769
self._stride_per_key: Optional[List[int]] = stride_per_key
17541770
self._length_per_key: Optional[List[int]] = length_per_key
@@ -1759,6 +1775,8 @@ def __init__(
17591775
self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
17601776
inverse_indices
17611777
)
1778+
# this is only needed for torch.compile case
1779+
self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None
17621780

17631781
# legacy attribute, for backward compatabilibity
17641782
self._variable_stride_per_key: Optional[bool] = None
@@ -1774,10 +1792,6 @@ def _init_pt2_checks(self) -> None:
17741792
return
17751793
if self._stride_per_key is not None:
17761794
pt2_checks_all_is_size(self._stride_per_key)
1777-
if self._stride_per_key_per_rank is not None:
1778-
# pyre-ignore [16]
1779-
for s in self._stride_per_key_per_rank:
1780-
pt2_checks_all_is_size(s)
17811795

17821796
@staticmethod
17831797
def from_offsets_sync(
@@ -1987,7 +2001,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
19872001
kjt_stride, kjt_stride_per_key_per_rank = (
19882002
(stride_per_key[0], None)
19892003
if all(s == stride_per_key[0] for s in stride_per_key)
1990-
else (None, [[stride] for stride in stride_per_key])
2004+
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
19912005
)
19922006
kjt = KeyedJaggedTensor(
19932007
keys=kjt_keys,
@@ -2152,12 +2166,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21522166
Returns:
21532167
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
21542168
"""
2155-
stride_per_key_per_rank = self._stride_per_key_per_rank
2156-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2169+
# making a local reference to the class variable to make jit.script behave
2170+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2171+
if (
2172+
not torch.jit.is_scripting()
2173+
and is_torchdynamo_compiling()
2174+
and _stride_per_key_per_rank is not None
2175+
):
2176+
if self._pt2_stride_per_key_per_rank is not None:
2177+
return self._pt2_stride_per_key_per_rank
2178+
stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
2179+
for stride_per_rank in stride_per_key_per_rank:
2180+
pt2_checks_all_is_size(stride_per_rank)
2181+
self._pt2_stride_per_key_per_rank = stride_per_key_per_rank
2182+
return stride_per_key_per_rank
2183+
return (
2184+
[]
2185+
if _stride_per_key_per_rank is None
2186+
else _stride_per_key_per_rank.tolist()
2187+
)
21572188

21582189
def variable_stride_per_key(self) -> bool:
21592190
"""
21602191
Returns whether the KeyedJaggedTensor has variable stride per key.
2192+
NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
2193+
is not `None`. It might be assigned to False externally/intentionally, usually the
2194+
`self._stride_per_key_per_rank` is trivial.
21612195
21622196
Returns:
21632197
bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2302,13 +2336,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23022336
start_offset = 0
23032337
_length_per_key = self.length_per_key()
23042338
_offset_per_key = self.offset_per_key()
2339+
# use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2340+
_stride_per_key_per_rank = self._stride_per_key_per_rank
23052341
for segment in segments:
23062342
end = start + segment
23072343
end_offset = _offset_per_key[end]
23082344
keys: List[str] = self._keys[start:end]
23092345
stride_per_key_per_rank = (
2310-
self.stride_per_key_per_rank()[start:end]
2346+
_stride_per_key_per_rank[start:end, :]
23112347
if self.variable_stride_per_key()
2348+
and _stride_per_key_per_rank is not None
23122349
else None
23132350
)
23142351
if segment == len(self._keys):
@@ -2456,17 +2493,24 @@ def permute(
24562493

24572494
length_per_key = self.length_per_key()
24582495
permuted_keys: List[str] = []
2459-
permuted_stride_per_key_per_rank: List[List[int]] = []
24602496
permuted_length_per_key: List[int] = []
24612497
permuted_length_per_key_sum = 0
24622498
for index in indices:
24632499
key = self.keys()[index]
24642500
permuted_keys.append(key)
24652501
permuted_length_per_key.append(length_per_key[index])
2466-
if self.variable_stride_per_key():
2467-
permuted_stride_per_key_per_rank.append(
2468-
self.stride_per_key_per_rank()[index]
2469-
)
2502+
2503+
stride_per_key = self._stride_per_key
2504+
permuted_stride_per_key = (
2505+
[stride_per_key[i] for i in indices] if stride_per_key is not None else None
2506+
)
2507+
2508+
_stride_per_key_per_rank = self._stride_per_key_per_rank
2509+
permuted_stride_per_key_per_rank = (
2510+
_stride_per_key_per_rank[indices, :]
2511+
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None
2512+
else None
2513+
)
24702514

24712515
permuted_length_per_key_sum = sum(permuted_length_per_key)
24722516
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2518,18 +2562,16 @@ def permute(
25182562
self.weights_or_none(),
25192563
permuted_length_per_key_sum,
25202564
)
2521-
stride_per_key_per_rank = (
2522-
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2523-
)
2565+
25242566
kjt = KeyedJaggedTensor(
25252567
keys=permuted_keys,
25262568
values=permuted_values,
25272569
weights=permuted_weights,
25282570
lengths=permuted_lengths.view(-1),
25292571
offsets=None,
25302572
stride=self._stride,
2531-
stride_per_key_per_rank=stride_per_key_per_rank,
2532-
stride_per_key=None,
2573+
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
2574+
stride_per_key=permuted_stride_per_key,
25332575
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
25342576
lengths_offset_per_key=None,
25352577
offset_per_key=None,
@@ -2848,7 +2890,7 @@ def dist_init(
28482890

28492891
if variable_stride_per_key:
28502892
assert stride_per_rank_per_key is not None
2851-
stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
2893+
stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
28522894
num_workers, len(keys)
28532895
).T.cpu()
28542896

@@ -2885,23 +2927,14 @@ def dist_init(
28852927
weights,
28862928
)
28872929

2888-
stride_per_key_per_rank = torch.jit.annotate(
2889-
List[List[int]], stride_per_key_per_rank_tensor.tolist()
2890-
)
2930+
if stride_per_key_per_rank.numel() == 0:
2931+
stride_per_key_per_rank = torch.zeros(
2932+
(len(keys), 1), device="cpu", dtype=torch.int64
2933+
)
28912934

2892-
if not stride_per_key_per_rank:
2893-
stride_per_key_per_rank = [[0]] * len(keys)
28942935
if stagger > 1:
2895-
stride_per_key_per_rank_stagger: List[List[int]] = []
2896-
local_world_size = num_workers // stagger
2897-
for i in range(len(keys)):
2898-
stride_per_rank_stagger: List[int] = []
2899-
for j in range(local_world_size):
2900-
stride_per_rank_stagger.extend(
2901-
stride_per_key_per_rank[i][j::local_world_size]
2902-
)
2903-
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
2904-
stride_per_key_per_rank = stride_per_key_per_rank_stagger
2936+
indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
2937+
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
29052938

29062939
kjt = KeyedJaggedTensor(
29072940
keys=keys,

0 commit comments

Comments
 (0)