Skip to content

Commit f590eb0

Browse files
jerryzh168meta-codesync[bot]
authored andcommitted
Update remaining callsites to not use pt2e quant API from pytorch (#1041)
Summary: X-link: https://github.com/facebookexternal/vizard/pull/18 Pull Request resolved: #1041 X-link: pytorch/executorch#16380 X-link: pytorch/ao#3535 We removed pt2e quant code from D87958849, updating some remaining callsites to use torchao or executorch Reviewed By: jainapurva Differential Revision: D89744472 fbshipit-source-id: c55135dd75e466544509293fabc384cad9bd9a78
1 parent 795a99f commit f590eb0

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

torchtnt/framework/_loop_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
import torch
1313
import torch.nn as nn
14+
import torchao
1415
from torch.nn.parallel.distributed import DistributedDataParallel
1516

1617
_EXPORT_UTILS_AVAIL = True
1718
try:
18-
from torch.ao.quantization.pt2e.export_utils import model_is_exported
19+
from torchao.quantization.pt2e.export_utils import model_is_exported
1920
except Exception:
2021
_EXPORT_UTILS_AVAIL = False
2122

@@ -101,9 +102,9 @@ def _set_module_training_mode(
101102
else module
102103
):
103104
move_fn = (
104-
torch.ao.quantization.move_exported_model_to_train
105+
torchao.quantization.pt2e.move_exported_model_to_train
105106
if mode
106-
else torch.ao.quantization.move_exported_model_to_eval
107+
else torchao.quantization.pt2e.move_exported_model_to_eval
107108
)
108109
# pyre-fixme[6]: For 1st argument expected `GraphModule` but got
109110
# `Union[Module, Tensor]`.
@@ -136,9 +137,9 @@ def _reset_module_training_mode(
136137
else module
137138
):
138139
move_fn = (
139-
torch.ao.quantization.move_exported_model_to_train
140+
torchao.quantization.pt2e.move_exported_model_to_train
140141
if prior_modes[name]
141-
else torch.ao.quantization.move_exported_model_to_eval
142+
else torchao.quantization.pt2e.move_exported_model_to_eval
142143
)
143144
# pyre-fixme[6]: For 1st argument expected `GraphModule` but got
144145
# `Union[Module, Tensor]`.

0 commit comments

Comments
 (0)