@@ -202,34 +202,29 @@ def slice_scatter_decomposition(
202
202
start = get_positive_dim (start , input_tensor .shape [dim ])
203
203
if end is None : # Ensure end is int
204
204
end = dim_size
205
- end = get_positive_dim (end , input_tensor .shape [dim ])
205
+ end = (
206
+ get_positive_dim (end , input_tensor .shape [dim ]) if isinstance (end , int ) else end
207
+ )
206
208
if step is None :
207
209
step = 1
208
210
209
- src_dim = src_tensor .shape
210
211
# step == 0 is not a valid torch case
211
- # also src_dim should be equal to slice dimension
212
-
213
212
if start == 0 and end == dim_size and step == 1 :
214
213
return src_tensor
215
214
216
215
# Ensure start, end, and step are all integers
217
- assert isinstance (start , int ), "start must be an integer"
218
- assert isinstance (end , int ), "end must be an integer"
219
- assert isinstance (step , int ), "step must be an integer"
220
-
221
- cat_tensors = []
222
- index_tensor_shape = []
223
- for i , src_each_dim in enumerate (list (src_dim )):
224
- if i != dim :
225
- index_tensor_shape .append (src_each_dim )
226
- for index in range (start , end , step ):
227
- cat_tensors .append (index * torch .ones (index_tensor_shape , dtype = torch .int64 ))
228
- index_tensor = torch .stack (cat_tensors , dim )
229
- index_tensor = index_tensor .to (device_input_tensor )
230
- index_tensor_64 = index_tensor .to (torch .int64 )
231
- output_tensor = torch .scatter (input_tensor , dim , index_tensor_64 , src_tensor )
232
- return output_tensor
216
+ assert isinstance (start , (int , torch .SymInt )), "start must be an int or SymInt"
217
+ assert isinstance (end , (int , torch .SymInt )), "end must be an int or SymInt"
218
+ assert isinstance (step , (int , torch .SymInt )), "step must be an int or SymInt"
219
+
220
+ indices = torch .arange (
221
+ start , end , step , device = device_input_tensor , dtype = torch .int64
222
+ )
223
+ index_tensor = indices .view (
224
+ [- 1 if i == dim else 1 for i in range (input_tensor .dim ())]
225
+ )
226
+ index_tensor = index_tensor .expand_as (src_tensor )
227
+ return torch .scatter (input_tensor , dim , index_tensor , src_tensor )
233
228
234
229
235
230
@register_torch_trt_decomposition (
0 commit comments