
Commit b6cf8e5

Lluo/cherry pick 3620 (#3658)

Co-authored-by: cehongwang <[email protected]>
Parent: 527880e

4 files changed: +23 additions, -20 deletions


py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 4 deletions

@@ -42,7 +42,6 @@
 from torch_tensorrt.dynamo.utils import (
     CPU_DEVICE,
     check_module_output,
-    deallocate_module,
     get_model_device,
     get_torch_inputs,
     to_torch_device,
@@ -484,7 +483,6 @@ def refit_module_weights(
             settings=settings,
             weight_name_map=None,
         )
-        deallocate_module(new_submodule)
 
         # clear EXCLUDE_WEIGHTS flag
         serialization_config = engine.create_serialization_config()
@@ -507,8 +505,6 @@ def refit_module_weights(
     gc.collect()
     torch.cuda.empty_cache()
 
-    deallocate_module(new_partitioned_module)
-
     if verify_output and arg_inputs is not None:
         new_gm.to(to_torch_device(settings.device))
         if check_module_output(
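
For context, a minimal sketch of the refit flow that refit_module_weights implements, assuming the documented public API (the toy model, the inputs, and the immutable_weights flag are illustrative and may vary by release); after this change the intermediate modules stay allocated rather than being freed partway through the refit:

import torch
import torch_tensorrt

# Toy model and inputs; placeholders for illustration only.
model = torch.nn.Linear(8, 8).eval().cuda()
inputs = [torch.randn(1, 8).cuda()]

# Compile with refittable weights.
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program, inputs=inputs, immutable_weights=False
)

# After updating the weights of `model`, re-export and refit.
new_exp_program = torch.export.export(model, tuple(inputs))
refitted_gm = torch_tensorrt.dynamo.refit_module_weights(
    trt_gm, new_exp_program, arg_inputs=inputs, verify_output=True
)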

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 5 additions & 1 deletion

@@ -10,7 +10,6 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
-from torch.distributed.tensor import DTensor
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
@@ -89,6 +88,11 @@ def aot_torch_tensorrt_aten_backend(
         logger.warning(
             "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
         )
+
+    if settings.offload_module_to_cpu:
+        logger.warning(
+            "The offload_module_to_cpu option is set, but it is being ignored since the torch_compile backend does not support this feature"
+        )
     return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
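
For reference, a minimal sketch of the code path that now emits the new warning, using the documented torch.compile integration (the toy model is a placeholder):

import torch
import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)

model = torch.nn.Linear(8, 8).eval().cuda()
x = torch.randn(1, 8).cuda()

# offload_module_to_cpu is accepted here but logged as ignored,
# because the torch_compile backend does not support offloading.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"offload_module_to_cpu": True},
)
compiled(x)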

tests/py/dynamo/models/test_export_serde.py

Lines changed: 6 additions & 5 deletions

@@ -321,11 +321,12 @@ def test_resnet18_cpu_offload(ir):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    assertions.assertTrue(
-        get_model_device(model).type == "cpu",
-        msg="Model should be offloaded to CPU",
-    )
-    model.cuda()
+    if ir == "dynamo":
+        assertions.assertTrue(
+            get_model_device(model).type == "cpu",
+            msg="Model should be offloaded to CPU",
+        )
+        model.cuda()
     torchtrt.save(trt_module, trt_ep_path)
 
     deser_trt_module = torchtrt.load(trt_ep_path).module()

tests/py/dynamo/models/test_models.py

Lines changed: 12 additions & 10 deletions

@@ -79,11 +79,12 @@ def test_resnet18_cpu_offload(ir):
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    assertions.assertTrue(
-        get_model_device(model).type == "cpu",
-        msg="Model should be offloaded to CPU",
-    )
-    model.cuda()
+    if ir == "dynamo":
+        assertions.assertTrue(
+            get_model_device(model).type == "cpu",
+            msg="Model should be offloaded to CPU",
+        )
+        model.cuda()
     cos_sim = cosine_similarity(model(input), trt_mod(input))
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
@@ -286,11 +287,12 @@ def test_bert_base_uncased_cpu_offload(ir):
         "offload_module_to_cpu": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
-    assertions.assertTrue(
-        get_model_device(model).type == "cpu",
-        msg="Model should be offloaded to CPU",
-    )
-    model.cuda()
+    if ir == "dynamo":
+        assertions.assertTrue(
+            get_model_device(model).type == "cpu",
+            msg="Model should be offloaded to CPU",
+        )
+        model.cuda()
 
     model_outputs = model(input, input2)
     trt_model_outputs = trt_mod(input, input2)
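
The guard these tests now share, shown in isolation as a hypothetical helper (names are illustrative): offload_module_to_cpu only moves the source module to CPU on the dynamo IR path, so the device assertion and the model.cuda() restore must be skipped for other IRs such as torch_compile.

import torch

def assert_offloaded_and_restore(model: torch.nn.Module, ir: str) -> None:
    """Assert the CPU offload only for the dynamo IR, then move the model back."""
    if ir == "dynamo":
        device = next(model.parameters()).device
        assert device.type == "cpu", "Model should be offloaded to CPU"
        model.cuda()  # restore so the subsequent output comparison runs on GPU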
