@@ -1098,13 +1098,15 @@ def _maybe_compute_stride_kjt(
     stride: Optional[int],
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
-        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
-            stride = max([sum(s) for s in stride_per_key_per_rank])
+        elif (
+            stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
+        ):
+            stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
         elif lengths is not None:
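A minimal standalone sketch (illustrative values, not part of the diff) of why the tensor expression above reproduces the old list-based stride for a rectangular [num_keys, num_ranks] layout:

import torch

# strides per key, per rank (made-up values)
stride_lists = [[2, 3], [4, 2], [0, 5]]
old_stride = max([sum(s) for s in stride_lists])  # removed list path -> 6

t = torch.IntTensor(stride_lists)
# sum over ranks (dim=1), then take the max over keys
new_stride = int(t.sum(dim=1).max().item())       # new tensor path -> 6
assert old_stride == new_stride == 6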
@@ -1483,8 +1485,8 @@ def _strides_from_kjt(
 def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     # empty-like function, fx-wrapped; also avoids device hardcoding
     stride, stride_per_key_per_rank = (
-        (None, kjt.stride_per_key_per_rank())
-        if kjt.variable_stride_per_key()
+        (None, kjt._stride_per_key_per_rank)
+        if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
         else (kjt.stride(), None)
     )
 
@@ -1670,14 +1672,20 @@ def _maybe_compute_lengths_offset_per_key(
 
 def _maybe_compute_stride_per_key(
     stride_per_key: Optional[List[int]],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
     stride: Optional[int],
     keys: List[str],
 ) -> Optional[List[int]]:
     if stride_per_key is not None:
         return stride_per_key
     elif stride_per_key_per_rank is not None:
-        return [sum(s) for s in stride_per_key_per_rank]
+        if stride_per_key_per_rank.dim() != 2:
+            # after permute the kjt could be empty
+            return []
+        rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            pt2_checks_all_is_size(rt)
+        return rt
     elif stride is not None:
         return [stride] * len(keys)
     else:
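A sketch (assumed shapes) of the two cases the dim() != 2 guard above distinguishes:

import torch

spkpr = torch.IntTensor([[2, 3], [4, 2]])
print(spkpr.sum(dim=1).tolist())  # [5, 6] -> per-key strides, the normal path

degenerate = torch.IntTensor([])  # e.g. an emptied KJT; dim() == 1, not 2
assert degenerate.dim() != 2      # guard returns [] instead of calling sum(dim=1)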
@@ -1768,7 +1776,9 @@ def __init__(
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
         stride: Optional[int] = None,
-        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        stride_per_key_per_rank: Optional[
+            Union[torch.IntTensor, List[List[int]]]
+        ] = None,
         # Below exposed to ensure torch.script-able
         stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
@@ -1790,8 +1800,14 @@ def __init__(
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
         self._stride: Optional[int] = stride
-        self._stride_per_key_per_rank: Optional[List[List[int]]] = (
-            stride_per_key_per_rank
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            # under pt2 compile, stride_per_key_per_rank has to be a torch.Tensor
+            # or None; it cannot be a List[List[int]]
+            assert not isinstance(stride_per_key_per_rank, list)
+        self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
+            torch.IntTensor(stride_per_key_per_rank, device="cpu")
+            if isinstance(stride_per_key_per_rank, list)
+            else stride_per_key_per_rank
         )
         self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
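A sketch of the normalization the constructor now performs (the helper name is illustrative, not in the source): list input is converted to a CPU IntTensor in eager mode, while tensors pass through unchanged:

import torch

def _normalize(spkpr):
    # mirrors the __init__ logic above: lists become a CPU IntTensor,
    # tensors are used as-is
    return torch.IntTensor(spkpr) if isinstance(spkpr, list) else spkpr

t_from_list = _normalize([[2, 3], [4, 2]])
t_from_tensor = _normalize(torch.IntTensor([[2, 3], [4, 2]]))
assert torch.equal(t_from_list, t_from_tensor)
assert t_from_list.dtype == torch.int32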
@@ -1802,6 +1818,8 @@ def __init__(
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
+        # this is only needed for the torch.compile case
+        self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None
 
         # legacy attribute, for backward compatibility
         self._variable_stride_per_key: Optional[bool] = None
@@ -1817,10 +1835,6 @@ def _init_pt2_checks(self) -> None:
             return
         if self._stride_per_key is not None:
             pt2_checks_all_is_size(self._stride_per_key)
-        if self._stride_per_key_per_rank is not None:
-            # pyre-ignore [16]
-            for s in self._stride_per_key_per_rank:
-                pt2_checks_all_is_size(s)
 
     @staticmethod
     def from_offsets_sync(
@@ -2030,7 +2044,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
         kjt_stride, kjt_stride_per_key_per_rank = (
             (stride_per_key[0], None)
             if all(s == stride_per_key[0] for s in stride_per_key)
-            else (None, [[stride] for stride in stride_per_key])
+            else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
         )
         kjt = KeyedJaggedTensor(
             keys=kjt_keys,
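A small check (illustrative values) that the reshape(-1, 1) above builds the same single-rank layout as the removed list comprehension:

import torch

stride_per_key = [3, 5, 2]
old = [[stride] for stride in stride_per_key]         # [[3], [5], [2]]
new = torch.IntTensor(stride_per_key).reshape(-1, 1)  # shape (3, 1)
assert new.tolist() == old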
@@ -2195,12 +2209,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
         Returns:
             List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
         """
-        stride_per_key_per_rank = self._stride_per_key_per_rank
-        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
+        # make a local reference to the member variable so that jit.script behaves
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if (
+            not torch.jit.is_scripting()
+            and is_torchdynamo_compiling()
+            and _stride_per_key_per_rank is not None
+        ):
+            if self._pt2_stride_per_key_per_rank is not None:
+                return self._pt2_stride_per_key_per_rank
+            stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
+            for stride_per_rank in stride_per_key_per_rank:
+                pt2_checks_all_is_size(stride_per_rank)
+            self._pt2_stride_per_key_per_rank = stride_per_key_per_rank
+            return stride_per_key_per_rank
+        return (
+            []
+            if _stride_per_key_per_rank is None
+            else _stride_per_key_per_rank.tolist()
+        )
 
     def variable_stride_per_key(self) -> bool:
         """
         Returns whether the KeyedJaggedTensor has variable stride per key.
+        NOTE: `self._variable_stride_per_key` can be `False` even when
+        `self._stride_per_key_per_rank` is not `None`; it may be set to `False`
+        intentionally, usually when `self._stride_per_key_per_rank` is trivial.
 
         Returns:
             bool: whether the KeyedJaggedTensor has variable stride per key.
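A reduced sketch (the class here is hypothetical) of the memoization pattern in the accessor above: under torch.compile, the .tolist() materialization and per-rank size checks run once, and later calls return the cached List[List[int]]:

import torch
from typing import List, Optional

class _Holder:
    def __init__(self, spkpr: torch.Tensor) -> None:
        self._spkpr = spkpr
        self._pt2_cache: Optional[List[List[int]]] = None

    def stride_per_key_per_rank(self) -> List[List[int]]:
        if self._pt2_cache is None:
            # materialize once; the real code also runs pt2_checks_all_is_size here
            self._pt2_cache = self._spkpr.tolist()
        return self._pt2_cache

h = _Holder(torch.IntTensor([[2, 3], [4, 2]]))
assert h.stride_per_key_per_rank() is h.stride_per_key_per_rank()  # cached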
@@ -2345,13 +2379,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
         start_offset = 0
         _length_per_key = self.length_per_key()
         _offset_per_key = self.offset_per_key()
+        # use a local copy/ref of self._stride_per_key_per_rank to satisfy jit.script
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
         for segment in segments:
             end = start + segment
             end_offset = _offset_per_key[end]
             keys: List[str] = self._keys[start:end]
             stride_per_key_per_rank = (
-                self.stride_per_key_per_rank()[start:end]
+                _stride_per_key_per_rank[start:end, :]
                 if self.variable_stride_per_key()
+                and _stride_per_key_per_rank is not None
                 else None
             )
             if segment == len(self._keys):
@@ -2499,17 +2536,24 @@ def permute(
 
         length_per_key = self.length_per_key()
         permuted_keys: List[str] = []
-        permuted_stride_per_key_per_rank: List[List[int]] = []
         permuted_length_per_key: List[int] = []
         permuted_length_per_key_sum = 0
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
             permuted_length_per_key.append(length_per_key[index])
-            if self.variable_stride_per_key():
-                permuted_stride_per_key_per_rank.append(
-                    self.stride_per_key_per_rank()[index]
-                )
+
+        stride_per_key = self._stride_per_key
+        permuted_stride_per_key = (
+            [stride_per_key[i] for i in indices] if stride_per_key is not None else None
+        )
+
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        permuted_stride_per_key_per_rank = (
+            _stride_per_key_per_rank[indices, :]
+            if self.variable_stride_per_key() and _stride_per_key_per_rank is not None
+            else None
+        )
 
         permuted_length_per_key_sum = sum(permuted_length_per_key)
         if not torch.jit.is_scripting() and is_non_strict_exporting():
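Row indexing with the permutation list replaces the per-key Python loop; a sketch with illustrative values:

import torch

spkpr = torch.IntTensor([[2, 3], [4, 2], [0, 5]])
indices = [2, 0, 1]
permuted = spkpr[indices, :]  # one advanced-indexing op instead of a loop
assert permuted.tolist() == [[0, 5], [2, 3], [4, 2]]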
@@ -2561,18 +2605,16 @@ def permute(
             self.weights_or_none(),
             permuted_length_per_key_sum,
         )
-        stride_per_key_per_rank = (
-            permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
-        )
+
         kjt = KeyedJaggedTensor(
             keys=permuted_keys,
             values=permuted_values,
             weights=permuted_weights,
             lengths=permuted_lengths.view(-1),
             offsets=None,
             stride=self._stride,
-            stride_per_key_per_rank=stride_per_key_per_rank,
-            stride_per_key=None,
+            stride_per_key_per_rank=permuted_stride_per_key_per_rank,
+            stride_per_key=permuted_stride_per_key,
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
             lengths_offset_per_key=None,
             offset_per_key=None,
@@ -2891,7 +2933,7 @@ def dist_init(
 
     if variable_stride_per_key:
         assert stride_per_rank_per_key is not None
-        stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
+        stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
            num_workers, len(keys)
         ).T.cpu()
 
@@ -2928,23 +2970,18 @@ def dist_init(
             weights,
         )
 
-        stride_per_key_per_rank = torch.jit.annotate(
-            List[List[int]], stride_per_key_per_rank_tensor.tolist()
-        )
+        if stride_per_key_per_rank.numel() == 0:
+            stride_per_key_per_rank = torch.zeros(
+                (len(keys), 1), device="cpu", dtype=torch.int64
+            )
 
-        if not stride_per_key_per_rank:
-            stride_per_key_per_rank = [[0]] * len(keys)
         if stagger > 1:
-            stride_per_key_per_rank_stagger: List[List[int]] = []
             local_world_size = num_workers // stagger
-            for i in range(len(keys)):
-                stride_per_rank_stagger: List[int] = []
-                for j in range(local_world_size):
-                    stride_per_rank_stagger.extend(
-                        stride_per_key_per_rank[i][j::local_world_size]
-                    )
-                stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
-            stride_per_key_per_rank = stride_per_key_per_rank_stagger
+            indices = [
+                list(range(i, num_workers, local_world_size))
+                for i in range(local_world_size)
+            ]
+            stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
 
         kjt = KeyedJaggedTensor(
             keys=keys,
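A sketch (toy sizes) of the stagger reorder above: the nested index list groups ranks by their position within the local world, matching the old j::local_world_size slicing once the result is flattened per key:

import torch

num_workers, stagger = 4, 2
local_world_size = num_workers // stagger    # 2
spkpr = torch.IntTensor([[10, 11, 12, 13]])  # one key, stride per rank

indices = [
    list(range(i, num_workers, local_world_size))
    for i in range(local_world_size)
]                                            # [[0, 2], [1, 3]]
staggered = spkpr[:, indices]                # shape (1, 2, 2)
# flattening recovers the old loop's flat per-key ordering
assert staggered.reshape(1, -1).tolist() == [[10, 12, 11, 13]]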