@@ -1778,6 +1778,7 @@ def __init__(
1778
1778
index_per_key : Optional [Dict [str , int ]] = None ,
1779
1779
jt_dict : Optional [Dict [str , JaggedTensor ]] = None ,
1780
1780
inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]] = None ,
1781
+ stride_per_key_per_rank_tensor : Optional [torch .Tensor ] = None ,
1781
1782
) -> None :
1782
1783
"""
1783
1784
This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
@@ -1794,6 +1795,11 @@ def __init__(
1794
1795
self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1795
1796
stride_per_key_per_rank
1796
1797
)
1798
+
1799
+ self ._stride_per_key_per_rank_tensor : torch .Tensor = torch .empty (0 )
1800
+ if stride_per_key_per_rank_tensor is not None :
1801
+ self ._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor
1802
+
1797
1803
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1798
1804
self ._length_per_key : Optional [List [int ]] = length_per_key
1799
1805
self ._offset_per_key : Optional [List [int ]] = offset_per_key
@@ -2183,7 +2189,7 @@ def stride_per_key(self) -> List[int]:
2183
2189
"""
2184
2190
stride_per_key = _maybe_compute_stride_per_key (
2185
2191
self ._stride_per_key ,
2186
- self ._stride_per_key_per_rank ,
2192
+ self ._stride_per_key_per_rank_optional ,
2187
2193
self .stride (),
2188
2194
self ._keys ,
2189
2195
)
@@ -2198,7 +2204,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2198
2204
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2199
2205
"""
2200
2206
stride_per_key_per_rank = self ._stride_per_key_per_rank
2201
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2207
+
2208
+ if stride_per_key_per_rank is not None :
2209
+ return stride_per_key_per_rank
2210
+
2211
+ if self ._stride_per_key_per_rank_tensor .numel () > 0 :
2212
+ return self ._stride_per_key_per_rank_tensor .tolist ()
2213
+
2214
+ return []
2215
+
2216
+ @property
2217
+ def _stride_per_key_per_rank_optional (self ) -> Optional [List [List [int ]]]:
2218
+ if self ._stride_per_key_per_rank is not None :
2219
+ return self ._stride_per_key_per_rank
2220
+
2221
+ if self ._stride_per_key_per_rank_tensor .numel () > 0 :
2222
+ stride_per_key_per_rank : List [List [int ]] = (
2223
+ self ._stride_per_key_per_rank_tensor .tolist ()
2224
+ )
2225
+ return stride_per_key_per_rank
2226
+
2227
+ return None
2202
2228
2203
2229
def variable_stride_per_key (self ) -> bool :
2204
2230
"""
@@ -2209,7 +2235,7 @@ def variable_stride_per_key(self) -> bool:
2209
2235
"""
2210
2236
if self ._variable_stride_per_key is not None :
2211
2237
return self ._variable_stride_per_key
2212
- return self ._stride_per_key_per_rank is not None
2238
+ return self ._stride_per_key_per_rank_optional is not None
2213
2239
2214
2240
def inverse_indices (self ) -> Tuple [List [str ], torch .Tensor ]:
2215
2241
"""
@@ -2374,6 +2400,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2374
2400
index_per_key = self ._index_per_key ,
2375
2401
jt_dict = self ._jt_dict ,
2376
2402
inverse_indices = None ,
2403
+ stride_per_key_per_rank_tensor = None ,
2377
2404
)
2378
2405
)
2379
2406
elif segment == 0 :
@@ -2410,6 +2437,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2410
2437
index_per_key = None ,
2411
2438
jt_dict = None ,
2412
2439
inverse_indices = None ,
2440
+ stride_per_key_per_rank_tensor = None ,
2413
2441
)
2414
2442
)
2415
2443
else :
@@ -2456,6 +2484,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2456
2484
index_per_key = None ,
2457
2485
jt_dict = None ,
2458
2486
inverse_indices = None ,
2487
+ stride_per_key_per_rank_tensor = None ,
2459
2488
)
2460
2489
)
2461
2490
else :
@@ -2492,6 +2521,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2492
2521
index_per_key = None ,
2493
2522
jt_dict = None ,
2494
2523
inverse_indices = None ,
2524
+ stride_per_key_per_rank_tensor = None ,
2495
2525
)
2496
2526
)
2497
2527
start = end
@@ -2598,12 +2628,15 @@ def permute(
2598
2628
index_per_key = None ,
2599
2629
jt_dict = None ,
2600
2630
inverse_indices = None ,
2631
+ stride_per_key_per_rank_tensor = None ,
2601
2632
)
2602
2633
return kjt
2603
2634
2604
2635
def flatten_lengths (self ) -> "KeyedJaggedTensor" :
2605
2636
stride_per_key_per_rank = (
2606
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2637
+ self ._stride_per_key_per_rank_optional
2638
+ if self .variable_stride_per_key ()
2639
+ else None
2607
2640
)
2608
2641
return KeyedJaggedTensor (
2609
2642
keys = self ._keys ,
@@ -2620,6 +2653,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
2620
2653
index_per_key = None ,
2621
2654
jt_dict = None ,
2622
2655
inverse_indices = None ,
2656
+ stride_per_key_per_rank_tensor = None ,
2623
2657
)
2624
2658
2625
2659
def __getitem__ (self , key : str ) -> JaggedTensor :
@@ -2759,7 +2793,9 @@ def to(
2759
2793
lengths = self ._lengths
2760
2794
offsets = self ._offsets
2761
2795
stride_per_key_per_rank = (
2762
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2796
+ self ._stride_per_key_per_rank_optional
2797
+ if self .variable_stride_per_key ()
2798
+ else None
2763
2799
)
2764
2800
length_per_key = self ._length_per_key
2765
2801
lengths_offset_per_key = self ._lengths_offset_per_key
@@ -2804,6 +2840,7 @@ def to(
2804
2840
index_per_key = index_per_key ,
2805
2841
jt_dict = jt_dict ,
2806
2842
inverse_indices = inverse_indices ,
2843
+ stride_per_key_per_rank_tensor = None ,
2807
2844
)
2808
2845
2809
2846
def __str__ (self ) -> str :
@@ -2835,7 +2872,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
2835
2872
lengths = self ._lengths
2836
2873
offsets = self ._offsets
2837
2874
stride_per_key_per_rank = (
2838
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2875
+ self ._stride_per_key_per_rank_optional
2876
+ if self .variable_stride_per_key ()
2877
+ else None
2839
2878
)
2840
2879
inverse_indices = self ._inverse_indices
2841
2880
if inverse_indices is not None :
@@ -2856,6 +2895,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
2856
2895
index_per_key = self ._index_per_key ,
2857
2896
jt_dict = None ,
2858
2897
inverse_indices = inverse_indices ,
2898
+ stride_per_key_per_rank_tensor = None ,
2859
2899
)
2860
2900
2861
2901
def dist_labels (self ) -> List [str ]:
0 commit comments