diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 629990be9..eeac7edc8 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -13,7 +13,7 @@ from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr +from llmcompressor.utils.helpers import calibration_forward_context, patch_attr __all__ = ["trace_subgraphs", "Subgraph"] @@ -135,15 +135,14 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph: if isinstance(root, Module): - with preserve_attr(type(root), "forward"): - # due to a bug in Tracer.create_args_for_root (_patch_function), - # we must unwrap function wrappers prior to tracing, for example - # the `deprecate_kwarg` by transformers which wraps forward - - # we override the class method because the - # class method is the one being traced - type(root).forward = inspect.unwrap(type(root).forward) - + # due to a bug in Tracer.create_args_for_root (_patch_function), + # we must unwrap function wrappers prior to tracing, for example + # the `deprecate_kwarg` by transformers which wraps forward + unwrapped_forward = inspect.unwrap(type(root).forward) + + # we override the class method because the + # class method is the one being traced + with patch_attr(type(root), "forward", unwrapped_forward): return super().trace(root, *args, **kwargs) else: diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 75fad8311..191306c32 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -67,7 +67,7 @@ "DisableQuantization", "eval_context", "calibration_forward_context", - "preserve_attr", + "patch_attr", ] @@ -1132,9 +1132,39 @@ def calibration_forward_context(model: PreTrainedModel): @contextlib.contextmanager -def preserve_attr(base: object, attr: str): - value = getattr(base, attr) +def patch_attr(base: object, attr: str, *args, **kwargs): + """ + :param base: object which has the attribute to patch + :param attr: name of the the attribute to patch + :param value: optional value used to replace attribute value. If none is provided, + then the value is not replaced and the original value is restored upon exit + + Usage: + >>> from types import SimpleNamespace + >>> obj = SimpleNamespace() + >>> with patch_attr(obj, "attribute", "value"): + ... assert obj.attribute == "value" + >>> assert not hasattr(obj, "attribute") + """ + # get value to patch with (handles `None` case) + has_patched_value = len(args) > 0 or "value" in kwargs + if len(args) > 0 and "value" in kwargs: + raise TypeError("Given two arguments for `value`") + patched_value = args[0] if len(args) > 0 else kwargs.pop("value", None) + + # check if attribute already exists + has_original_value = hasattr(base, attr) + if has_original_value: + value = getattr(base, attr) + + # override value + if has_patched_value: + setattr(base, attr, patched_value) try: yield finally: - setattr(base, attr, value) + # restore value + if has_original_value: + setattr(base, attr, value) + elif hasattr(base, attr): + delattr(base, attr) diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index c1975f9d5..013a35817 100644 --- a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -12,10 +12,12 @@ getattr_chain, interpolate, parse_kwarg_tuples, + patch_attr, validate_str_iterable, ) +@pytest.mark.unit @pytest.mark.parametrize( "test_list,output", [ @@ -30,6 +32,7 @@ def test_flatten_iterable(test_list, output): assert flattened == output +@pytest.mark.unit @pytest.mark.parametrize( "test_bool,output", [ @@ -54,6 +57,7 @@ def test_convert_to_bool(test_bool, output): assert converted == output +@pytest.mark.unit @pytest.mark.parametrize( "test_list,output", [ @@ -69,11 +73,13 @@ def test_validate_str_iterable(test_list, output): assert validated == output +@pytest.mark.unit def test_validate_str_iterable_negative(): with pytest.raises(ValueError): validate_str_iterable("will fail", "") +@pytest.mark.unit @pytest.mark.parametrize( "x_cur,x0,x1,y0,y1,inter_func,out", [ @@ -93,11 +99,13 @@ def test_interpolate(x_cur, x0, x1, y0, y1, inter_func, out): assert abs(out - interpolated) < 0.01 +@pytest.mark.unit def test_pass_kwargs_tuples(): kwargs = parse_kwarg_tuples(("--input_1", 1, "--input_2", "two", "--input_3", "2")) assert kwargs == dict(input_1=1, input_2="two", input_3=2) +@pytest.mark.unit def test_getattr_chain(): base = SimpleNamespace() base.a = None @@ -129,6 +137,7 @@ def test_getattr_chain(): getattr_chain(base, "b.d.dne") +@pytest.mark.unit def test_DisableQuantization(): model = torch.nn.Linear(1, 1) with DisableQuantization(model): @@ -136,6 +145,7 @@ def test_DisableQuantization(): assert model.quantization_enabled +@pytest.mark.unit def test_calibration_forward_context(): model = torch.nn.Linear(1, 1) model.config = SimpleNamespace() @@ -148,3 +158,35 @@ def test_calibration_forward_context(): assert torch.is_grad_enabled() assert model.quantization_enabled assert model.config.use_cache + + +@pytest.mark.unit +def test_patch_attr(): + # patch, original value + obj = SimpleNamespace() + obj.attribute = "original" + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert obj.attribute == "original" + + # patch, no original attribute + obj = SimpleNamespace() + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert not hasattr(obj, "attribute") + + # no patch, original value + obj = SimpleNamespace() + obj.attribute = "original" + with patch_attr(obj, "attribute"): + assert obj.attribute == "original" + obj.attribute = "modified" + assert obj.attribute == "original" + + # no patch, no original attribute + obj = SimpleNamespace() + with patch_attr(obj, "attribute"): + obj.attribute = "modified" + assert not hasattr(obj, "attribute")