diff --git a/neural_compressor/utils/utility.py b/neural_compressor/utils/utility.py index e934d580474..6b3245b6659 100644 --- a/neural_compressor/utils/utility.py +++ b/neural_compressor/utils/utility.py @@ -23,6 +23,7 @@ """ import _thread import ast +import builtins import importlib import logging import os @@ -33,6 +34,7 @@ import sys import threading import time +from collections import OrderedDict from contextlib import contextmanager from enum import Enum from functools import wraps @@ -390,16 +392,74 @@ def get_all_fp32_data(data): return [float(i) for i in data.replace("[", " ").replace("]", " ").split(" ") if i.strip() and len(i) < 32] -def get_tuning_history(tuning_history_path): +def get_tuning_history(history_path): """Get tuning history. Args: - tuning_history_path: The tuning history path, which need users to assign + history_path: The tuning history path, which need users to assign """ - with open(tuning_history_path, "rb") as f: - strategy_object = pickle.load(f) - tuning_history = strategy_object.tuning_history - return tuning_history + + class SafeUnpickler(pickle.Unpickler): + def find_class(self, module, name): + # Allowed built-in types + allowed_builtins = { + "dict", + "list", + "tuple", + "set", + "frozenset", + "str", + "bytes", + "int", + "float", + "complex", + "bool", + "NoneType", + "slice", + "type", + "object", + "bytearray", + "ellipsis", + "filter", + "map", + "range", + "reversed", + "zip", + } + if module == "builtins" and name in allowed_builtins: + return getattr(builtins, name) + + # Allow collections.OrderedDict + if module == "collections" and name == "OrderedDict": + return OrderedDict + + # Allow specific neural_compressor classes + if module.startswith("neural_compressor"): + # Validate class name exists in module + mod_path = module.replace(".__", " ") # Handle submodules + for part in mod_path.split(): + try: + __import__(part) + except ImportError: + continue + mod = sys.modules.get(module) + if mod and hasattr(mod, name): + return getattr(mod, name) + + # Allow all numpy classes + if module.startswith("numpy"): + + mod = sys.modules.get(module) + if mod and hasattr(mod, name): + return getattr(mod, name) + + # Block all other classes + raise pickle.UnpicklingError(f"Unsafe class: {module}.{name}") + + with open(history_path, "rb") as f: + strategy_object = SafeUnpickler(f).load() + tuning_history = strategy_object.tuning_history + return tuning_history def recover(fp32_model, tuning_history_path, num, **kwargs):