Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,26 @@ def __init__(
self.apply_caching = apply_caching
self.inference_result_mock = inference_result_mock
self.tensor_cache = {}
self.stateful = len(request.query_state()) > 0
self._reset_state_called = False

def collect_inputs(self, inputs):
if self.stateful and not is_nncf_version("<=", "2.19"):
if not isinstance(inputs, dict):
raise NotImplementedError("Processing of non-dict inputs for stateful models is not supported.")
inputs = inputs.copy()
inputs[nncf.Dataset.RESET_STATE_KEY] = self._reset_state_called
self._reset_state_called = False

if not self.apply_caching or not isinstance(inputs, dict):
self.collected_inputs.append(copy.deepcopy(inputs))
return

copied_inputs = {}
for k, v in inputs.items():
if isinstance(v, bool):
copied_inputs[k] = v
continue
data = v
if isinstance(data, openvino.Tensor):
data = data.data
Expand Down Expand Up @@ -223,6 +235,10 @@ def wait(self):
def get_tensor(self, name: str):
    """Return the stored result for ``name`` wrapped in a ``Tensor``.

    Args:
        name: Key used to index ``self.request.results``.

    Returns:
        A ``Tensor`` constructed from ``self.request.results[name]``.
    """
    result = self.request.results[name]
    return Tensor(result)

def reset_state(self):
    """Reset the wrapped request's state and remember that a reset happened.

    Delegates to ``self.request.reset_state()`` and then sets
    ``self._reset_state_called`` to ``True``; ``collect_inputs`` reads that
    flag and clears it again when it records inputs for a stateful model.
    """
    request = self.request
    request.reset_state()
    # Raise the flag only after the delegated reset has returned.
    self._reset_state_called = True

def __getattr__(self, attr):
if attr in self.__dict__:
return getattr(self, attr)
Expand Down