functional compiled autograd (pytorch#144707)
This PR squashes together the following commits:

pytorch#144115
pytorch#143417
pytorch#143405
pytorch#143387
pytorch#143304
pytorch#143296

This is a refactor of compiled autograd to use "functional autograd". The end goal is for compiled autograd's initial capture to stop specializing on Tensor metadata, which allows compiled autograd to better handle Tensor subclasses.

For more information, please read the commit messages for each PR.

Pull Request resolved: pytorch#144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
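
As a rough illustration of the workflow this refactor targets, here is a minimal sketch adapted from the compiled_autograd._enable and make_compiler_fn usage in test/inductor/test_compiled_autograd.py in this PR. The "aot_eager" backend choice and the toy fn are illustrative assumptions, not part of the commit.

    import torch
    from torch._dynamo import compiled_autograd

    # compiler_fn receives the graph captured by compiled autograd's initial
    # trace and returns a compiled callable. The PR's tests wrap torch.compile
    # with an "aot_eager" backend; that choice is an assumption here.
    def compiler_fn(gm):
        return torch.compile(gm, backend="aot_eager", fullgraph=True)

    @torch.compile(backend="aot_eager", fullgraph=True)
    def fn(x):
        return x * x + 2

    x = torch.randn(4, 4, requires_grad=True)

    # With functional compiled autograd, the captured backward graph routes
    # gradients through ops like validate_outputs and call_aot_bwd_prologue
    # rather than specializing on x's concrete Tensor metadata (see
    # test_tensor_subclass_basic in this diff for the Tensor-subclass version).
    with compiled_autograd._enable(compiler_fn):
        out = fn(x)
        out.sum().backward()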
zou3519 authored and pytorchmergebot committed Jan 27, 2025
1 parent 87fdadd commit ea141d8
Showing 28 changed files with 1,809 additions and 223 deletions.
10 changes: 10 additions & 0 deletions aten/src/ATen/TensorGeometry.h
@@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
has_symbolic_sizes_strides_(
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

explicit TensorGeometry(
std::vector<at::SymInt> sizes,
std::vector<at::SymInt> strides,
at::SymInt storage_offset)
: sizes_(std::move(sizes)),
strides_(std::move(strides)),
storage_offset_(std::move(storage_offset)) {
recompute();
}

// true if the tensor is contiguous
bool is_contiguous() const;

1 change: 1 addition & 0 deletions build_variables.bzl
@@ -138,6 +138,7 @@ core_trainer_sources = [
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/utils/warnings.cpp",
"torch/csrc/autograd/jit_decomp_interface.cpp",
"torch/csrc/dynamo/compiled_autograd.cpp",
"torch/csrc/jit/frontend/name_mangler.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/pickler.cpp",
54 changes: 34 additions & 20 deletions test/dynamo/test_backward_higher_order_ops.py
@@ -121,23 +121,30 @@ def fn(x, y):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(x.grad, grad_out * grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
l_inputs_ = L_inputs_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
result: "f32[s0]" = getitem * getitem; getitem = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad: "f32[2]" = torch.clone(getitem_5)
result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
)

graph = None

@@ -162,7 +169,7 @@ def inner_compiler(gm_, example_inputs_):
gm, backend=inner_compiler, fullgraph=True, dynamic=True
)

for backend in ["eager", "aot_eager", "inductor"]:
for backend in ["inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)
Expand All @@ -187,26 +194,33 @@ def fn(x, y):
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(obj.counter, 1)
self.assertEqual(x.grad, grad_out + grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
def forward(self, L_inputs_ : list, L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s1)"):
l_inputs_ = L_inputs_
l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
result: "f32[s0]" = getitem * getitem; getitem = None
new_grad: "f32[2]" = torch.clone(getitem_5)
new_grad_1: "f32[s0]" = torch.clone(result); result = None
add: "Sym(s1 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1, add)
""",
)
)

out = fn(x, y)
out.backward(grad_out)
135 changes: 125 additions & 10 deletions test/inductor/test_compiled_autograd.py
@@ -22,6 +22,7 @@
from torch._dynamo import compiled_autograd, config
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
@@ -2821,8 +2822,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline):
opt_bwd()

self.assertEqual(counters["compiled_autograd"]["captures"], 1)
# always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
# Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing
# into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
# In the future, we can consider having a cpu scalar movement pass sometime after we trace
# into the custom C++ autograd::Function (like in AOTDispatcher)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

def test_logs(self):
logs, ctx = logs_to_string(
@@ -2941,12 +2945,11 @@ def forward(model, x):

expected_logs = [
"code: CompiledFunctionBackward (NodeCall 2)",
"code: CompiledFunctionBackward0 (NodeCall 2)",
"aot0_primals_3",
"aot0_relu",
"aot0_le",
"aot0_permute_2",
"code: CompiledFunctionBackward0 (NodeCall 2)",
"aot0_tangents_1",
"aot0_full_default",
"aot0_where",
"aot0_mm",
@@ -2996,20 +2999,17 @@ def f(x):

expected_logs = [
"CompiledFunctionBackward1",
"aot1_tangents_1",
"aot1_sin_1",
"aot1_primals_2",
"aot1_neg",
"aot0_tangents_2",
"aot1_cos_1",
"aot1_primals_1",
"aot0_tangents_1",
"CompiledFunctionBackward0",
"aot0_sin_1",
"aot0_neg",
"aot0_sin",
"aot0_mul",
"aot0_cos_1",
"aot0_mul_1",
"aot0_cos",
"aot0_add",
]

@@ -3154,6 +3154,120 @@ def fn():

self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

def test_tensor_subclass_basic(self):
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor")
lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)")

def to_twotensor_backward(ctx, grad):
return torch.ops.mylib.from_twotensor(grad)

def from_twotensor_backward(ctx, grad_a, grad_b):
raise AssertionError("shouldn't get hit")

torch.library.register_autograd(
"mylib::to_twotensor", to_twotensor_backward, lib=lib
)
torch.library.register_autograd(
"mylib::from_twotensor", from_twotensor_backward, lib=lib
)

@torch.library.register_torch_dispatch(
"mylib::to_twotensor", TwoTensorMode, lib=lib
)
def _(_0, _1, _2, args, kwargs):
assert not kwargs
a, b = args
return TwoTensor(a.clone(), b.clone())

@torch.library.register_torch_dispatch(
"mylib::from_twotensor", TwoTensor, lib=lib
)
def _(_0, _1, _2, args, kwargs):
assert not kwargs
(c,) = args
return c.a.clone(), c.b.clone()

@torch.compile(backend="aot_eager", fullgraph=True)
def fn(x):
return x * x + 2

param1 = torch.randn(4, 4, requires_grad=True)
param2 = torch.randn(4, 4, requires_grad=True)
with TwoTensorMode():
x = torch.ops.mylib.to_twotensor(param1, param2)

inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager")
graphs = []

def compiler_fn(gm):
graphs.append(gm)
return inner_compiler_fn(gm)

with compiled_autograd._enable(compiler_fn):
res = fn(x)
res.sum().backward()

self.assertEqual(param1.grad, 2 * param1)
self.assertEqual(param2.grad, 2 * param2)
self.assertEqual(len(graphs), 1)

graph_code = normalize_gm(graphs[0].print_readable(print_output=False))
# The graph should have make_subclass calls in it.
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
getitem = inputs[0]
getitem_1 = inputs[1]
getitem_2 = inputs[2]
getitem_3 = inputs[3]
getitem_4 = inputs[4]; inputs = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
getitem_5 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None
getitem_6 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None
getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_8 = hooks[0]; getitem_8 = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None
aot0_primals_1 = call_aot_bwd_prologue[0]
aot0_primals_2 = call_aot_bwd_prologue[1]
aot0_tangents_1 = call_aot_bwd_prologue[2]
aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None
aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None
aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None
aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None
aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None
make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None
getitem_13 = hooks[1]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None
getitem_16 = call_backward[0]
getitem_17 = call_backward[1]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None
getitem_19 = validate_outputs_2[0]
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None
getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)

# https://github.com/pytorch/pytorch/issues/138920
def test_compiled_autograd_does_not_specialize_on_bw_symints(self):
class Mod(torch.nn.Module):
@@ -3247,7 +3361,7 @@ def inner_compiler(gm_, example_inputs_):
# because we ignore all of these guards anyway in CA.
# Once we stop using make_fx in CA, we won't have to worry about this specialization.
view_nodes = graphs[1].graph.find_nodes(
op="call_function", target=torch.ops.aten.view.default
op="call_function", target=torch.ops.aten.reshape.default
)
# First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph
self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
@@ -3640,6 +3754,7 @@ def wrap_test_class(orig_cls):
"test_tp_compile_comm_reordering",
"test_unwrap_async_collective_tensor_tangent",
# Uncategorized
"test_not_implemented_grad", # Dynamo changes the types of exceptions
}

if not HAS_CUDA:
4 changes: 3 additions & 1 deletion test/inductor/test_distributed_patterns.py
@@ -337,7 +337,9 @@ def test_module_backward_hooks_eager(self):
self.assertEqual(fw_cnt.frame_count, 1)
self.assertEqual(fw_cnt.op_count, 5)
self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None
self.assertEqual(bw_cnt.op_count, 48)
self.assertEqual(
bw_cnt.op_count, 72
) # Number of ops in the Dynamo-produced graphs

def test_module_backward_hooks_aot(self):
m1, inp1 = init_module_bw_hooks(True)