[reland][dynamo][guards] Consider tensors as immutable for dict tag matches (pytorch#141085)

Reland - pytorch#139560

As mentioned in pytorch#130341, using `static py::object` can lead to segfaults. I suspect this is the reason for the import system error seen internally (https://www.internalfb.com/sevmanager/view/469592). In this PR, I am removing the `static` part. This is fine, and also the right thing to do, because it catches the case where a user changes the flag within the same process while compiling two different functions.

Unfortunately, there is no easy way to trigger this segfault, so I can't write a test.
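
For illustration, here is a minimal sketch (not the actual guards.cpp code; the function names are made up, only the config flag name is real) contrasting the old `static` pattern with the per-call read used in this PR:

```cpp
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Old pattern: the config object is captured once per process. Its destructor
// may run after the Python interpreter has been finalized (one way a segfault
// can arise), and a flag flipped later in the same process is never observed.
bool flag_cached_statically() {
  static py::object config_module = py::module_::import("torch._dynamo.config");
  return config_module.attr("skip_tensor_guards_with_matching_dict_tags")
      .cast<bool>();
}

// New pattern: re-read the flag on every call, so no static Python state is
// kept alive and a change to the flag between two compilations is picked up.
bool flag_read_per_call() {
  py::object config_module = py::module_::import("torch._dynamo.config");
  return config_module.attr("skip_tensor_guards_with_matching_dict_tags")
      .cast<bool>();
}
```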

Pull Request resolved: pytorch#141085
Approved by: https://github.com/jansel

Co-authored-by: William Wen <[email protected]>
2 people authored and pytorchmergebot committed Dec 6, 2024
1 parent ce22a01 commit 8bfc009
Showing 3 changed files with 60 additions and 1 deletion.
48 changes: 48 additions & 0 deletions test/dynamo/test_modules.py
@@ -3166,6 +3166,54 @@ def fn(x):
        res = opt_fn(x)
        self.assertEqual(ref, res)

    @patch.object(
        torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
    )
    def test_param_requires_grad(self):
        def adjust_model(model):
            to_freeze = model.num_iter % 2 == 0
            if to_freeze:
                for param in model.layer2.parameters():
                    param.requires_grad = False
            else:
                for param in model.layer2.parameters():
                    param.requires_grad = True

        class MyModule(torch.nn.Module):
            def __init__(self, input_size, hidden_size, output_size):
                super().__init__()

                self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
                self.layer2 = torch.nn.Linear(hidden_size, hidden_size)

                self.num_iter = 0

            def forward(self, x):
                x = self.layer2(x + self.layer1.bias)

                self.num_iter += 1
                return x

        input_size = 1024
        hidden_size = 1024
        output_size = 1
        num_samples = 2048
        features = torch.randn(num_samples, input_size)

        model = MyModule(input_size, hidden_size, output_size)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch.compile(model, backend=cnt, fullgraph=True)

        for _ in range(3):
            model.zero_grad(True)
            adjust_model(model)
            res = opt_model(features)
            res.sum().backward()

        # Check that we have recompiled twice, which leads to 3 frames
        self.assertEqual(cnt.frame_count, 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
4 changes: 4 additions & 0 deletions torch/_dynamo/config.py
@@ -331,6 +331,10 @@ def _get_optimize_ddp_mode():
# notice and lead to incorrect result.
skip_no_tensor_aliasing_guards_on_parameters = True

# Considers a tensor immutable if it is one of the values of a dictionary, and
# the dictionary tag is the same across invocations.
skip_tensor_guards_with_matching_dict_tags = True

# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True

9 changes: 8 additions & 1 deletion torch/csrc/dynamo/guards.cpp
@@ -903,6 +903,12 @@ std::string get_exception_message() {
}

bool is_immutable_object(py::handle example_value) {
  py::object config_module = py::module_::import("torch._dynamo.config");

  bool is_tensor_immutable =
      config_module.attr("skip_tensor_guards_with_matching_dict_tags")
          .cast<bool>();

  if (PyTuple_Check(example_value.ptr())) {
    // Check that each element is immutable
    for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
@@ -913,10 +919,11 @@
    }
    return true;
  }

  return PyLong_Check(example_value.ptr()) ||
      PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
      PyUnicode_Check(example_value.ptr()) ||
-     THPVariable_Check(example_value.ptr());
+     (is_tensor_immutable && THPVariable_Check(example_value.ptr()));
}

bool is_parameter(py::handle tensor) {
