Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Utils] Replace preserve_attr with patch_attr #1187

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformers.utils.fx import HFTracer

from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr

__all__ = ["trace_subgraphs", "Subgraph"]

Expand Down Expand Up @@ -135,15 +135,14 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:

def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
if isinstance(root, Module):
with preserve_attr(type(root), "forward"):
# due to a bug in Tracer.create_args_for_root (_patch_function),
# we must unwrap function wrappers prior to tracing, for example
# the `deprecate_kwarg` by transformers which wraps forward

# we override the class method because the
# class method is the one being traced
type(root).forward = inspect.unwrap(type(root).forward)

# due to a bug in Tracer.create_args_for_root (_patch_function),
# we must unwrap function wrappers prior to tracing, for example
# the `deprecate_kwarg` by transformers which wraps forward
unwrapped_forward = inspect.unwrap(type(root).forward)

# we override the class method because the
# class method is the one being traced
with patch_attr(type(root), "forward", unwrapped_forward):
return super().trace(root, *args, **kwargs)

else:
Expand Down
38 changes: 34 additions & 4 deletions src/llmcompressor/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"DisableQuantization",
"eval_context",
"calibration_forward_context",
"preserve_attr",
"patch_attr",
]


Expand Down Expand Up @@ -1132,9 +1132,39 @@ def calibration_forward_context(model: PreTrainedModel):


# sentinel distinguishing "no patch value supplied" from an explicit `None`
_sentinel = object()


@contextlib.contextmanager
def patch_attr(base: object, attr: str, value=_sentinel):
    """
    Context manager which temporarily patches an attribute of an object,
    restoring the original state (including the attribute's absence) on exit.

    :param base: object which has the attribute to patch
    :param attr: name of the attribute to patch
    :param value: optional value used to replace the attribute value. If not
        provided, the attribute is not replaced on entry, but any modification
        made within the context is still reverted upon exit

    Usage:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace()
    >>> with patch_attr(obj, "attribute", "value"):
    ...     assert obj.attribute == "value"
    >>> assert not hasattr(obj, "attribute")
    """
    # a keyword default with a sentinel replaces the previous *args/**kwargs
    # handling: duplicate or unknown arguments now raise TypeError natively,
    # and an explicit `value=None` still counts as a patch
    has_patched_value = value is not _sentinel

    # remember the original value (if any) so absence can be restored too
    has_original_value = hasattr(base, attr)
    if has_original_value:
        original_value = getattr(base, attr)

    # override value only when one was explicitly supplied
    if has_patched_value:
        setattr(base, attr, value)
    try:
        yield
    finally:
        # restore the original value, or delete the attribute if it did not
        # exist before entering the context
        if has_original_value:
            setattr(base, attr, original_value)
        elif hasattr(base, attr):
            delattr(base, attr)
42 changes: 42 additions & 0 deletions tests/llmcompressor/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
getattr_chain,
interpolate,
parse_kwarg_tuples,
patch_attr,
validate_str_iterable,
)


@pytest.mark.unit
@pytest.mark.parametrize(
"test_list,output",
[
Expand All @@ -30,6 +32,7 @@ def test_flatten_iterable(test_list, output):
assert flattened == output


@pytest.mark.unit
@pytest.mark.parametrize(
"test_bool,output",
[
Expand All @@ -54,6 +57,7 @@ def test_convert_to_bool(test_bool, output):
assert converted == output


@pytest.mark.unit
@pytest.mark.parametrize(
"test_list,output",
[
Expand All @@ -69,11 +73,13 @@ def test_validate_str_iterable(test_list, output):
assert validated == output


@pytest.mark.unit
def test_validate_str_iterable_negative():
    # a plain string input (not a list/tuple of strings) should be rejected
    with pytest.raises(ValueError):
        validate_str_iterable("will fail", "")


@pytest.mark.unit
@pytest.mark.parametrize(
"x_cur,x0,x1,y0,y1,inter_func,out",
[
Expand All @@ -93,11 +99,13 @@ def test_interpolate(x_cur, x0, x1, y0, y1, inter_func, out):
assert abs(out - interpolated) < 0.01


@pytest.mark.unit
def test_pass_kwargs_tuples():
    # "--name value" pairs parse into a dict; the numeric string "2" is
    # coerced to the int 2, while "two" stays a string
    kwargs = parse_kwarg_tuples(("--input_1", 1, "--input_2", "two", "--input_3", "2"))
    assert kwargs == dict(input_1=1, input_2="two", input_3=2)


@pytest.mark.unit
def test_getattr_chain():
base = SimpleNamespace()
base.a = None
Expand Down Expand Up @@ -129,13 +137,15 @@ def test_getattr_chain():
getattr_chain(base, "b.d.dne")


@pytest.mark.unit
def test_DisableQuantization():
    # quantization is disabled inside the context and re-enabled on exit
    model = torch.nn.Linear(1, 1)
    with DisableQuantization(model):
        assert not model.quantization_enabled
    assert model.quantization_enabled


@pytest.mark.unit
def test_calibration_forward_context():
model = torch.nn.Linear(1, 1)
model.config = SimpleNamespace()
Expand All @@ -148,3 +158,35 @@ def test_calibration_forward_context():
assert torch.is_grad_enabled()
assert model.quantization_enabled
assert model.config.use_cache


@pytest.mark.unit
def test_patch_attr():
    # case 1: patch value given, attribute already exists -> original restored
    namespace = SimpleNamespace(attribute="original")
    with patch_attr(namespace, "attribute", "patched"):
        assert namespace.attribute == "patched"
        namespace.attribute = "modified"
    assert namespace.attribute == "original"

    # case 2: patch value given, attribute absent -> attribute removed on exit
    namespace = SimpleNamespace()
    with patch_attr(namespace, "attribute", "patched"):
        assert namespace.attribute == "patched"
        namespace.attribute = "modified"
    assert not hasattr(namespace, "attribute")

    # case 3: no patch value, attribute exists -> modifications reverted
    namespace = SimpleNamespace(attribute="original")
    with patch_attr(namespace, "attribute"):
        assert namespace.attribute == "original"
        namespace.attribute = "modified"
    assert namespace.attribute == "original"

    # case 4: no patch value, attribute absent -> additions reverted
    namespace = SimpleNamespace()
    with patch_attr(namespace, "attribute"):
        namespace.attribute = "modified"
    assert not hasattr(namespace, "attribute")