Skip to content

Commit b65bd99

Browse files
authored
slice scatter support for dynamic cases (#3513)
1 parent 0621cda commit b65bd99

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -202,34 +202,29 @@ def slice_scatter_decomposition(
202202
start = get_positive_dim(start, input_tensor.shape[dim])
203203
if end is None: # Ensure end is int
204204
end = dim_size
205-
end = get_positive_dim(end, input_tensor.shape[dim])
205+
end = (
206+
get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end
207+
)
206208
if step is None:
207209
step = 1
208210

209-
src_dim = src_tensor.shape
210211
# step == 0 is not a valid torch case
211-
# also src_dim should be equal to slice dimension
212-
213212
if start == 0 and end == dim_size and step == 1:
214213
return src_tensor
215214

216215
# Ensure start, end, and step are all integers
217-
assert isinstance(start, int), "start must be an integer"
218-
assert isinstance(end, int), "end must be an integer"
219-
assert isinstance(step, int), "step must be an integer"
220-
221-
cat_tensors = []
222-
index_tensor_shape = []
223-
for i, src_each_dim in enumerate(list(src_dim)):
224-
if i != dim:
225-
index_tensor_shape.append(src_each_dim)
226-
for index in range(start, end, step):
227-
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
228-
index_tensor = torch.stack(cat_tensors, dim)
229-
index_tensor = index_tensor.to(device_input_tensor)
230-
index_tensor_64 = index_tensor.to(torch.int64)
231-
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
232-
return output_tensor
216+
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
217+
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
218+
assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"
219+
220+
indices = torch.arange(
221+
start, end, step, device=device_input_tensor, dtype=torch.int64
222+
)
223+
index_tensor = indices.view(
224+
[-1 if i == dim else 1 for i in range(input_tensor.dim())]
225+
)
226+
index_tensor = index_tensor.expand_as(src_tensor)
227+
return torch.scatter(input_tensor, dim, index_tensor, src_tensor)
233228

234229

235230
@register_torch_trt_decomposition(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,38 @@ def forward(self, x, src, dim, start, end, step):
812812
f"Slice_scatter TRT outputs don't match with the original model.",
813813
)
814814

815+
def test_lowering_slice_scatter_dynamic_module(self):
    """Check that aten.slice_scatter lowers correctly when dim 1 of the
    destination tensor is dynamic (exported with a torch.export.Dim range).

    Compares the TensorRT-compiled module against the exported FX graph on
    identical CUDA inputs.
    """

    class SliceScatterDyn(torch.nn.Module):
        def forward(self, x, src):
            # Scatter `src` into `x` along dim 1 starting at index 6
            # (end=None, step=1), matching the decomposition under test.
            return torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)

    sample_inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())

    # Only the second dimension of `x` is dynamic; everything else is static.
    dyn_dim = torch.export.Dim("dim1", min=8, max=10)
    exported_program = torch.export.export(
        SliceScatterDyn(),
        sample_inputs,
        dynamic_shapes={
            "x": [torch.export.Dim.STATIC, dyn_dim],
            "src": [torch.export.Dim.STATIC, None],
        },
    )
    fx_graph = exported_program.module()

    compile_inputs = [
        torch_tensorrt.Input(min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]),
        torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
    ]
    torch._dynamo.reset()
    trt_model = torch_tensorrt.dynamo.compile(exported_program, compile_inputs)

    check_inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
    torch.testing.assert_close(
        trt_model(*check_inputs), fx_graph(*check_inputs), rtol=RTOL, atol=ATOL
    )
846+
815847
def test_lowering_select_scatter_dimZero_module(self):
816848
class selectScatter(torch.nn.Module):
817849
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)