From ea141d8134ad87c9ed37642f61d7254df604873e Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 27 Jan 2025 05:20:56 +0000 Subject: [PATCH] functional compiled autograd (#144707) This PR squashes together the following commits: https://github.com/pytorch/pytorch/pull/144115 https://github.com/pytorch/pytorch/pull/143417 https://github.com/pytorch/pytorch/pull/143405 https://github.com/pytorch/pytorch/pull/143387 https://github.com/pytorch/pytorch/pull/143304 https://github.com/pytorch/pytorch/pull/143296 This is a refactor of compiled autograd to use "functional autograd". The end goal is that it gets compiled autograd's initial capture to stop specializing on Tensor metadata, therefore allowing compiled autograd to better handle Tensor subclasses. For more information, please read the commit messages for each PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707 Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel --- aten/src/ATen/TensorGeometry.h | 10 + build_variables.bzl | 1 + test/dynamo/test_backward_higher_order_ops.py | 54 +- test/inductor/test_compiled_autograd.py | 135 ++++- test/inductor/test_distributed_patterns.py | 4 +- tools/autograd/gen_autograd_functions.py | 68 ++- tools/autograd/templates/Functions.cpp | 24 + torch/_dynamo/compiled_autograd.py | 395 ++++++++++++- torch/_dynamo/external_utils.py | 8 + torch/_dynamo/polyfills/__init__.py | 2 + torch/_dynamo/trace_rules.py | 1 + .../_aot_autograd/runtime_wrappers.py | 33 +- .../_aot_autograd/subclass_utils.py | 15 +- torch/autograd/function.py | 3 + torch/csrc/autograd/custom_function.cpp | 25 + torch/csrc/autograd/custom_function.h | 211 +++++-- torch/csrc/autograd/engine.cpp | 13 + torch/csrc/autograd/engine.h | 2 + torch/csrc/autograd/function.h | 10 + torch/csrc/autograd/function_hook.h | 1 + torch/csrc/autograd/functions/tensor.cpp | 153 +++-- torch/csrc/autograd/functions/tensor.h | 1 + torch/csrc/autograd/init.cpp | 5 + torch/csrc/autograd/python_function.cpp | 51 +- torch/csrc/autograd/python_function.h | 2 + torch/csrc/dynamo/compiled_autograd.cpp | 27 + torch/csrc/dynamo/compiled_autograd.h | 536 ++++++++++++++++++ .../csrc/dynamo/python_compiled_autograd.cpp | 242 +++++++- 28 files changed, 1809 insertions(+), 223 deletions(-) create mode 100644 torch/csrc/dynamo/compiled_autograd.cpp diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h index 41f14a15ba99c..06a064063c4e2 100644 --- a/aten/src/ATen/TensorGeometry.h +++ b/aten/src/ATen/TensorGeometry.h @@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry { has_symbolic_sizes_strides_( t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + // true if the tensor is contiguous bool is_contiguous() const; diff --git a/build_variables.bzl b/build_variables.bzl index 8bd8ad3a8df0a..a95c03cd0b345 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -138,6 +138,7 @@ core_trainer_sources = [ "torch/csrc/autograd/variable.cpp", "torch/csrc/autograd/utils/warnings.cpp", "torch/csrc/autograd/jit_decomp_interface.cpp", + "torch/csrc/dynamo/compiled_autograd.cpp", "torch/csrc/jit/frontend/name_mangler.cpp", "torch/csrc/jit/ir/type_hashing.cpp", "torch/csrc/jit/serialization/pickler.cpp", diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 14e3f2e044c10..2aa5ee3e71894 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -121,23 +121,30 @@ def fn(x, y): out.backward(grad_out) actual = normalize_gm(graph.print_readable(False)) self.assertEqual(x.grad, grad_out * grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list): l_inputs_ = L_inputs_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - result: "f32[s0]" = getitem * getitem; getitem = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None + getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad: "f32[2]" = torch.clone(getitem_5) + + result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, - ) + ) graph = None @@ -162,7 +169,7 @@ def inner_compiler(gm_, example_inputs_): gm, backend=inner_compiler, fullgraph=True, dynamic=True ) - for backend in ["eager", "aot_eager", "inductor"]: + for backend in ["inductor"]: torch._dynamo.reset() x = torch.tensor([0.5, 0.5], requires_grad=True) y = torch.tensor([0.5, 0.5], requires_grad=True) @@ -187,26 +194,33 @@ def fn(x, y): actual = normalize_gm(graph.print_readable(False)) self.assertEqual(obj.counter, 1) self.assertEqual(x.grad, grad_out + grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"): + def forward(self, L_inputs_ : list, L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s1)"): l_inputs_ = L_inputs_ - l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter + l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None + getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - result: "f32[s0]" = getitem * getitem; getitem = None + new_grad: "f32[2]" = torch.clone(getitem_5) - new_grad_1: "f32[s0]" = torch.clone(result); result = None + add: "Sym(s1 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None + + result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1, add) """, - ) + ) out = fn(x, y) out.backward(grad_out) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index e08b19b3f2b7f..87fceb9bf624e 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -22,6 +22,7 @@ from torch._dynamo import compiled_autograd, config from torch._dynamo.backends.debugging import aot_eager from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import normalize_gm from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase @@ -2821,8 +2822,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): opt_bwd() self.assertEqual(counters["compiled_autograd"]["captures"], 1) - # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops - self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + # Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing + # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. + # In the future, we can consider having a cpu scalar movement pass sometime after we trace + # into the custom C++ autograd::Function (like in AOTDispatcher) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) def test_logs(self): logs, ctx = logs_to_string( @@ -2941,12 +2945,11 @@ def forward(model, x): expected_logs = [ "code: CompiledFunctionBackward (NodeCall 2)", + "code: CompiledFunctionBackward0 (NodeCall 2)", "aot0_primals_3", "aot0_relu", "aot0_le", "aot0_permute_2", - "code: CompiledFunctionBackward0 (NodeCall 2)", - "aot0_tangents_1", "aot0_full_default", "aot0_where", "aot0_mm", @@ -2996,20 +2999,17 @@ def f(x): expected_logs = [ "CompiledFunctionBackward1", - "aot1_tangents_1", "aot1_sin_1", - "aot1_primals_2", "aot1_neg", "aot0_tangents_2", "aot1_cos_1", - "aot1_primals_1", "aot0_tangents_1", "CompiledFunctionBackward0", + "aot0_sin_1", "aot0_neg", - "aot0_sin", "aot0_mul", + "aot0_cos_1", "aot0_mul_1", - "aot0_cos", "aot0_add", ] @@ -3154,6 +3154,120 @@ def fn(): self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0) + def test_tensor_subclass_basic(self): + from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor") + lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)") + + def to_twotensor_backward(ctx, grad): + return torch.ops.mylib.from_twotensor(grad) + + def from_twotensor_backward(ctx, grad_a, grad_b): + raise AssertionError("shouldn't get hit") + + torch.library.register_autograd( + "mylib::to_twotensor", to_twotensor_backward, lib=lib + ) + torch.library.register_autograd( + "mylib::from_twotensor", from_twotensor_backward, lib=lib + ) + + @torch.library.register_torch_dispatch( + "mylib::to_twotensor", TwoTensorMode, lib=lib + ) + def _(_0, _1, _2, args, kwargs): + assert not kwargs + a, b = args + return TwoTensor(a.clone(), b.clone()) + + @torch.library.register_torch_dispatch( + "mylib::from_twotensor", TwoTensor, lib=lib + ) + def _(_0, _1, _2, args, kwargs): + assert not kwargs + (c,) = args + return c.a.clone(), c.b.clone() + + @torch.compile(backend="aot_eager", fullgraph=True) + def fn(x): + return x * x + 2 + + param1 = torch.randn(4, 4, requires_grad=True) + param2 = torch.randn(4, 4, requires_grad=True) + with TwoTensorMode(): + x = torch.ops.mylib.to_twotensor(param1, param2) + + inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager") + graphs = [] + + def compiler_fn(gm): + graphs.append(gm) + return inner_compiler_fn(gm) + + with compiled_autograd._enable(compiler_fn): + res = fn(x) + res.sum().backward() + + self.assertEqual(param1.grad, 2 * param1) + self.assertEqual(param2.grad, 2 * param2) + self.assertEqual(len(graphs), 1) + + graph_code = normalize_gm(graphs[0].print_readable(print_output=False)) + # The graph should have make_subclass calls in it. + self.assertExpectedInline( + graph_code, + """\ +class CompiledAutograd0(torch.nn.Module): + def forward(self, inputs, sizes, scalars, hooks): + getitem = inputs[0] + getitem_1 = inputs[1] + getitem_2 = inputs[2] + getitem_3 = inputs[3] + getitem_4 = inputs[4]; inputs = None + + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None + getitem_5 = validate_outputs[0]; validate_outputs = None + + sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None + getitem_6 = sum_backward0[0]; sum_backward0 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None + getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None + + getitem_8 = hooks[0]; getitem_8 = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None + aot0_primals_1 = call_aot_bwd_prologue[0] + aot0_primals_2 = call_aot_bwd_prologue[1] + aot0_tangents_1 = call_aot_bwd_prologue[2] + aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None + + aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None + aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None + + aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None + aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None + + make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None + + getitem_13 = hooks[1]; hooks = None + call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None + getitem_16 = call_backward[0] + getitem_17 = call_backward[1]; call_backward = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None + getitem_19 = validate_outputs_2[0] + + accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None + + getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None + + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None + + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None + return [] +""", # noqa: B950 + ) + # https://github.com/pytorch/pytorch/issues/138920 def test_compiled_autograd_does_not_specialize_on_bw_symints(self): class Mod(torch.nn.Module): @@ -3247,7 +3361,7 @@ def inner_compiler(gm_, example_inputs_): # because we ignore all of these guards anyway in CA. # Once we stop using make_fx in CA, we won't have to worry about this specialization. view_nodes = graphs[1].graph.find_nodes( - op="call_function", target=torch.ops.aten.view.default + op="call_function", target=torch.ops.aten.reshape.default ) # First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) @@ -3640,6 +3754,7 @@ def wrap_test_class(orig_cls): "test_tp_compile_comm_reordering", "test_unwrap_async_collective_tensor_tangent", # Uncategorized + "test_not_implemented_grad", # Dynamo changes the types of exceptions } if not HAS_CUDA: diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 9d1a3202f1475..ed50a02d7b274 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -337,7 +337,9 @@ def test_module_backward_hooks_eager(self): self.assertEqual(fw_cnt.frame_count, 1) self.assertEqual(fw_cnt.op_count, 5) self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None - self.assertEqual(bw_cnt.op_count, 48) + self.assertEqual( + bw_cnt.op_count, 72 + ) # Number of ops in the Dynamo-produced graphs def test_module_backward_hooks_aot(self): m1, inp1 = init_module_bw_hooks(True) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index fa1d0ce4bc910..176fa68c53b65 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -107,6 +107,17 @@ ${body} return grad_inputs; } +inline variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args) +{ +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + auto packed_args = PackedArgs(args); + auto needs_input_grad = packed_args.unpack>(); + ${unpack_ivalues} + return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args}); +#endif +} variable_list ${op}::apply(variable_list&& grads) { ${thread_lock} @@ -120,11 +131,35 @@ ${compiled_args} } variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { - ${apply_with_saved_before} - variable_list result = apply(variable_list(grads)); - ${apply_with_saved_after} - return result; +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + ${apply_with_saved_before} + + static bool called = false; + if (!called) { + called = true; + ${compute_schema} + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + pyinterface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema); + } + + variable_list output_result; + + PackedArgs packed_args; + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + packed_args.pack(needs_input_grad); + ${get_packed_args} + + output_result = compiled_autograd_apply_functional(packed_args, next_edges(), saved, grads, name()); + + ${apply_with_saved_after} + return output_result; +#endif } + """ ) @@ -993,14 +1028,38 @@ def emit_derivative( f"{T} {x}" for T, x in zip(apply_functional_args_ref_types, apply_functional_args) ] + get_packed_args = "\n".join( + f"packed_args.pack({name});" for name in apply_functional_args + ) + unpack_ivalues = [] + for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): + if typ.endswith("&"): + typ = typ[:-1] + unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") + + schema_args = [f"std::array"] + for typ in apply_functional_args_ref_types: + if typ.endswith("&"): + typ = typ[:-1] + if typ.startswith("const"): + typ = typ[5:] + schema_args.append(typ.strip()) + compute_schema = ["std::vector schema = {"] + for schema_arg in schema_args: + compute_schema.append( + f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type()," + ) + compute_schema.append("};") return template.substitute( unpacks="\n".join(unpack), op=info.op, + compute_schema="\n".join(compute_schema), apply_functional_args=apply_functional_args, apply_functional_args_signature=apply_functional_args_signature, compute_needs_input_grad=compute_needs_input_grad, num_inputs=len(input_name_to_idx), + unpack_ivalues="\n".join(unpack_ivalues), compute_index_ranges=compute_index_ranges, saved_variables=saved_variables, release_variables=release_variables, @@ -1015,4 +1074,5 @@ def emit_derivative( compiled_args=compiled_args, apply_with_saved_before=apply_with_saved_before, apply_with_saved_after=apply_with_saved_after, + get_packed_args=get_packed_args, ) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 5bc089f67df74..ba5cb3d912c5d 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -15,6 +15,30 @@ using at::TensorList; namespace torch::autograd::generated { +static at::IValue compute_output_metadata(const torch::autograd::edge_list& next_edges) { + auto output_metadata = torch::dynamo::autograd::IValuePacker< + std::vector>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges)); + return output_metadata; +} + +static C10_NOINLINE variable_list compiled_autograd_apply_functional( + const PackedArgs& packed_args, + const edge_list& next_edges, + SwapSavedVariables& saved, + const variable_list& grads, + const std::string& name) { + auto output_metadata = compute_output_metadata(next_edges); + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + return pyinterface->call_function( + saved.get_py_compiler(), + "apply_functional", + name, + grads, + packed_args.vec(), + output_metadata); +} + ${autograd_function_definitions} } // namespace torch::autograd::generated diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index fb7017bc6dc9d..0627789a4bf1d 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -4,10 +4,11 @@ import itertools import operator import time -from collections import defaultdict +from collections import Counter, defaultdict from typing import Any, Optional, TYPE_CHECKING, Union import torch +import torch.utils._pytree as pytree from torch._dynamo.external_utils import ( call_backward, call_hook, @@ -65,6 +66,50 @@ def maybe_clone(x): return x +# We lazily bind "functional backward" variants for PyTorch built-in autograd +# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0 +# Each "functional backward" is bound the first time the node's apply_with_saved +# function is called. It's possible to avoid lazy binding and instead bind +# all of this upfront (perhaps at import time) via codegen changes. +class OpNamespace: + def __init__(self): + self.custom_function_name_counter: Counter[str] = Counter() + + def add(self, name, fn, is_custom_function=False): + if is_custom_function: + name = "CppNode" + name + count = self.custom_function_name_counter[name] + self.custom_function_name_counter[name] += 1 + name = f"{name}{count}" + else: + assert not hasattr(self, name) + + result = Op(name, fn, is_custom_function) + torch._dynamo.allow_in_graph(result) + setattr(self, name, result) + return name + + def get(self, name): + return getattr(self, name) + + +class Op: + def __init__(self, name, fn, is_custom_function): + self.fn = fn + self.is_custom_function = is_custom_function + self.__name__ = name + self.__module__ = "torch._dynamo.compiled_autograd.ops" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): + return self.__module__ + "." + self.__name__ + + +ops = OpNamespace() + + _graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] _impure_targets = OrderedSet( [ @@ -137,7 +182,8 @@ def begin_capture( self.fx_tracer.root = torch.nn.Module() self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) self.fx_tracer.tensor_attrs = {} - args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( + self.symnode_proxy_lookup = {} + args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = ( self.fx_tracer.create_proxy("placeholder", name, (), {}) for name in _graph_placeholders ) @@ -160,7 +206,9 @@ def begin_capture( ) for idx, val in enumerate(sizes) ] - self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins) + self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins) + for i, symint in enumerate(sizes): + self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i] for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -182,7 +230,9 @@ def begin_capture( ) else: raise AssertionError("Unexpected scalar type: ", type(val)) - self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins) + self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins) + for i, symval in enumerate(scalars): + self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr] # TODO(jansel): are all these modes needed? self.stack.enter_context(decompose({})) @@ -197,31 +247,203 @@ def begin_capture( ) return inputs, sizes, scalars - def proxy_call_backward( + def proxy_call_aot_backward( self, - inputs, - output_metadatas, + pinputs, + psaved_tensors, saved_tensors, - backward_idx: int, + pctx, + ctx, + maybe_backward_state_idx, ): - assert self.hooks_proxy is not None - backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index] - proxies = self.fx_tracer.create_proxy( + # The AOTBackward call consists of three things: the prologue, the + # backward graph, and the epilogue. + # Our strategy is: + # - allow_in_graph the prologue (in the CA graph and Dynamo graph), + # - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it + # - trace directly through the epilogue. Anything that gets baked in is + # constant metadata (for example, metadata about the number of outputs, or removing + # RNG arguments or effect tokens). + # If Dynamo graph capture were better, then we could add a node for the prologue + # into the CA graph and have Dynamo trace into it. + + psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()] + + # NOTE: we should only close over constants + CompiledFunction = ctx._forward_cls + metadata = CompiledFunction.metadata + maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata + del CompiledFunction + + @torch._dynamo.allow_in_graph # type: ignore[misc] + def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): + out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( + ctx_saved_tensors, + ctx_symints, + metadata, + maybe_subclass_metadata, + *flat_args, + ) + return out + + pgrads = self.fx_tracer.create_proxy( kind="call_function", - target=call_backward, + target=call_aot_bwd_prologue, args=( - backward_c_function, - self.to_proxy(saved_tensors), - *self.to_proxy(inputs), + psaved_tensors, + psymints, + *pinputs, ), kwargs={}, ) + pbackward_state = None + if maybe_backward_state_idx is not None: + pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index] + + # Copy-paste the AOT backward graph into the compiled autograd graph + def copy_paste_aot_backward_graph(): + def num_inputs(graph): + num_args = 0 + for node in graph.nodes: + if node.op == "placeholder": + num_args += 1 + continue + else: + break + return num_args + + # set up the proxy inputs to ctx._bw_module + # the calling convention is: [*symints, *args (primals and tangents), backward_state] + num_args = num_inputs(ctx._bw_module.graph) + pall_args = [ + pgrads[i] for i in range(num_args - int(pbackward_state is not None)) + ] + # replace the symints with our symints + symints = ctx._get_compiled_autograd_symints() + assert len(symints) == len(ctx.symints) + psymints = [self.to_proxy(e) for e in symints] + pall_args[: len(symints)] = psymints + # Add backward_state + if pbackward_state is not None: + pall_args.append(pbackward_state) + + # run over all nodes of the aot_backward graph. + # copy and paste them all into the compiled autograd graph. + args_idx = 0 + value_remap = {} + poutputs: Optional[list[torch.fx.Proxy]] = None + for node in ctx._bw_module.graph.nodes: + if node.op == "placeholder": + value_remap[node] = pall_args[args_idx].node + args_idx += 1 + elif node.op == "output": + assert len(node.args) == 1 + poutputs = [ + torch.fx.Proxy(value_remap[n], self.fx_tracer) + if isinstance(n, torch.fx.Node) + else n + for n in node.args[0] + ] + elif node.op == "get_attr": + name = node.target + qualname = self.fx_tracer.get_fresh_qualname(name) + setattr( + self.fx_tracer.root, qualname, getattr(ctx._bw_module, name) + ) + result = self.fx_tracer.create_node("get_attr", qualname, (), {}) + value_remap[node] = result + elif node.op == "call_function": + result = self.fx_tracer.graph.node_copy( + node, lambda n: value_remap[n] + ) + value_remap[node] = result + else: + raise AssertionError("shouldn't get here") + assert poutputs is not None + + # In general we don't know what the shapes of the outputs are, so allocate + # some dummy sizes for them. + def dummy(): + with disable_proxy_modes_tracing(): + return torch.zeros(0, 0, 0, 0, 123) + + outputs = [ + dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs + ] + self.bind_tensors_to_proxies(outputs, poutputs) + return outputs + + outputs = copy_paste_aot_backward_graph() + + def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args): + @torch._dynamo.allow_in_graph + def make_subclass(*unwrapped_args): + return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + + punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args) + + poutput = self.fx_tracer.create_proxy( + kind="call_function", + target=make_subclass, + args=tuple(punwrapped_args), + kwargs={}, + ) + + output = self.allocate_dummy() + self.bind_tensors_to_proxies([output], [poutput]) + return output + + results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional( + metadata, + maybe_subclass_metadata, + outputs, + make_subclass_override=proxy_subclass_constructor, + ) + presults = pytree.tree_map(self.to_proxy, results) + return presults + + def proxy_call_backward( + self, + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ctx: torch.autograd.function.BackwardCFunction, + maybe_backward_state_idx: Optional[int], + ): + assert self.hooks_proxy is not None + pctx = self.hooks_proxy[backward_idx] # type: ignore[index] + pinputs = self.to_proxy(inputs) + psaved_tensors = self.to_proxy(saved_tensors) + if hasattr(ctx._forward_cls, "_aot_id"): # type: ignore[attr-defined] + # AOT backward + proxies = self.proxy_call_aot_backward( + pinputs, + psaved_tensors, + saved_tensors, + pctx, + ctx, + maybe_backward_state_idx, + ) + else: + proxies = self.fx_tracer.create_proxy( + kind="call_function", + target=call_backward, + args=( + pctx, + psaved_tensors, + *pinputs, + ), + kwargs={}, + ) + assert proxies is not None + with disable_proxy_modes_tracing(): # create fake Tensors grad_ins: list[Optional[torch.Tensor]] = [] - for output_metadata in output_metadatas: - if output_metadata is None: + for idx, output_metadata in enumerate(output_metadatas): + if output_metadata is None or proxies[idx] is None: grad_ins.append(None) continue @@ -232,6 +454,71 @@ def proxy_call_backward( self.bind_tensors_to_proxies(grad_ins, proxies) return tuple(grad_ins) + def call_copy_slices_prologue(self, inputs, base, view): + args = ( + inputs, + base.sizes(), + base.strides(), + base.storage_offset(), + view.sizes(), + view.strides(), + view.storage_offset(), + ) + return self.proxy_call(copy_slices_prologue, args, [None] * 3) + + def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice): + return self.proxy_call( + copy_slices_epilogue, + (needs_input_grad, result, res, grad_slice), + [None] * len(needs_input_grad), + ) + + def allocate_dummy(self): + with disable_proxy_modes_tracing(): + # Weird quantity so it's easy to grep + return torch.zeros([0, 123456789]) + + def bind_function(self, fn_name, fn, is_custom_function): + """Binds ops.fn_name = fn""" + return ops.add(fn_name, fn, is_custom_function) + + def apply_functional(self, fn_name, grads, args, output_metadata): + """Proxies a call to ops.fn_name(grads, *args) into the graph""" + op = ops.get(fn_name) + return self.proxy_call(op, (grads, *args), output_metadata) + + def proxy_call(self, fn, args, output_metadata): + """Proxies a call to fn(*args) into the graph""" + flat_args, _ = pytree.tree_flatten(args) + proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) + proxy_out = self.fx_tracer.create_proxy( + "call_function", fn, args=proxy_args, kwargs={} + ) + result = [self.allocate_dummy() for _ in output_metadata] + self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))]) + return result + + def validate_outputs(self, _, outputs, args, output_metadata): + """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" + op = ops.get("validate_outputs") + proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) + new_proxy_outputs = self.fx_tracer.create_proxy( + "call_function", op, args=proxy_args, kwargs={} + ) + assert len(output_metadata) == len(outputs) + self.bind_tensors_to_proxies(outputs, new_proxy_outputs) + return outputs + + def accumulate(self, old_var, new_var): + old_var_proxy = self.to_proxy(old_var) + new_var_proxy = self.to_proxy(new_var) + proxy_out = self.fx_tracer.create_proxy( + "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={} + ) + result = self.allocate_dummy() + self.bind_tensors_to_proxies([result], [proxy_out]) + return result + def proxy_call_hook(self, hook, *args, **kwargs): return self.fx_tracer.create_proxy( "call_function", @@ -314,6 +601,7 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: assert nodes[first_getitem_idx] == inputs_users[0] last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 assert nodes[last_getitem_idx] == inputs_users[-1] + # getitem nodes on inputs for i, node in enumerate(inputs_users): if not has_cuda_inputs and node.meta["val"].device.type == "cuda": has_cuda_inputs = True @@ -323,9 +611,16 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: is_scalar = len(node.meta["val"].size()) == 0 if is_cpu and is_scalar: node_users = list(node.users.keys()) + # We can only move the cpu scalar if it is not exposed to user code. if all( - isinstance(user.target, torch._ops.OpOverload) - and user.target.namespace in ("prims", "aten") + ( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + ) + or ( + isinstance(user.target, Op) + and not user.target.is_custom_function + ) for user in node_users ): # all users are prims/aten, can move safely @@ -335,6 +630,7 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: # this is to handle the case where cudagraphs is enabled on a cpu-only graph if has_cuda_inputs: for node in to_move.values(): + verbose_log.debug("Moving node %s from cpu to cuda", node) node.meta["val"] = node.meta["val"].cuda() # return runtime indices we need to move to cuda @@ -368,7 +664,10 @@ def is_impure(node): or (node.op == "call_function" and node.target in _impure_targets) ) + before = len(self.fx_tracer.graph.nodes) self.fx_tracer.graph.eliminate_dead_code(is_impure) + after = len(self.fx_tracer.graph.nodes) + verbose_log.debug("DCE removed %d nodes", before - after) def end_capture(self, outputs): self.fx_tracer.create_proxy( @@ -384,6 +683,18 @@ def end_capture(self, outputs): (self.fx_tracer.create_arg(self.to_proxy(outputs)),), {}, ) + runtime_inputs_to_move: list[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + + # We traced using dummy tensors. Delete all the metadata of the dummy tensors. + # It's probably better to refactor this class to use a different tracer + # than the make_fx tracer, but that is a larger change. + for node in self.fx_tracer.graph.nodes: + for field in ["tensor_meta", "example_value", "val"]: + if field in node.meta: + del node.meta[field] + self.rename_aot_dispatcher_nodes() self.reorder_tensor_pre_hook_nodes() self.reorder_pre_hook_nodes_to_schedule_asap() @@ -402,9 +713,6 @@ def end_capture(self, outputs): # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and # should prevent these ops from going into the CA graph. self.dce() - runtime_inputs_to_move: list[int] = [] - if snapshot_cudagraph_enabled(): - runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}" @@ -778,8 +1086,11 @@ def to_proxy(self, t): return [self.to_proxy(x) for x in t] if isinstance(t, tuple): return tuple(self.to_proxy(x) for x in t) - # can it be torch.SymInt as the code used to imply? - assert isinstance(t, torch.Tensor) + if isinstance(t, (torch.SymInt, torch.SymFloat)): + return self.symnode_proxy_lookup[t.node] + if not isinstance(t, torch.Tensor): + # constant types like device, dtype, str + return t proxy_tensor = fetch_object_proxy(self.fx_tracer, t) assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) return proxy_tensor.proxy @@ -921,3 +1232,39 @@ def reset() -> None: torch._C._dynamo.compiled_autograd.clear_cache() global COMPILE_COUNTER COMPILE_COUNTER = itertools.count() + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. +def copy_slices_prologue( + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, +): + grad = inputs[0] + result = grad.new_empty_strided(base_sizes, base_strides) + assert grad is not None + result.copy_(grad) + offset = view_storage_offset - base_storage_offset + grad_slice = result.as_strided(view_sizes, view_strides, offset) + return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)] + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. +def copy_slices_epilogue(needs_input_grad, result, res, grad_slice): + grad_inputs = [None] * len(needs_input_grad) + for i in range(len(needs_input_grad)): + if needs_input_grad[i]: + if res[i] is None: + continue + if i == 0: + grad_slice.copy_(res[i]) + grad_inputs[i] = result + else: + grad_inputs[i] = res[i] + return grad_inputs diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index c50df4d969693..ec17a6b6d1354 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -116,6 +116,14 @@ def call_backward( return grads +def normalize_as_list(x: Any) -> list[Any]: + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + def untyped_storage_size(x: torch.Tensor) -> int: return x.untyped_storage().size() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 02290ba55fc27..3ef4d94a1385c 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -72,6 +72,8 @@ def radians(x): def accumulate_grad(x, new_grad): + if new_grad is None: + return new_grad = torch.clone(new_grad) if x.grad is None: x.grad = new_grad diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 9e5b048d14947..ef634af4eb082 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3276,6 +3276,7 @@ def _module_dir(m: types.ModuleType): MOD_INLINELIST = [ "torch._decomp", "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.compiled_autograd", "torch._dynamo.comptime", "torch._dynamo.polyfills", "torch._functorch._aot_autograd.subclass_parametrization", diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 954865c76b19d..dc9e5af16da95 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -62,7 +62,6 @@ from .utils import ( call_func_at_runtime_with_args, make_boxed_func, - normalize_as_list, partial_flatten_asdict, strict_zip, ) @@ -1683,7 +1682,9 @@ def _backward_prologue_functional( # NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. -def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out): +def _backward_epilogue_functional( + metadata, maybe_subclass_metadata, out, *, make_subclass_override=None +): # Toss out the backward output tokens num_bw_tokens = metadata.num_backward_tokens if num_bw_tokens > 0: @@ -1703,6 +1704,7 @@ def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out): subclass_metas=maybe_subclass_metadata.grad_input_metas, included_subclass_symints=True, is_runtime=True, + make_subclass_override=make_subclass_override, ) return outs_wrapped return out @@ -1728,6 +1730,13 @@ def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta expected_meta = meta.meta runtime_type = type(x) + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. + if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + runtime_meta = None runtime_subclass_keys: Sequence[str] = [] @@ -2001,23 +2010,9 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): - if ctx._is_compiled_autograd_tracing(): - if lazy_backward_info is None: - raise RuntimeError( - """This compiled backward function was saved by AOTAutogradCache, which does not support - compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`.""" - ) - bw_module = lazy_backward_info.bw_module - # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph - symints = ctx._get_compiled_autograd_symints() - assert len(symints) == len(ctx.symints) - all_args[: len(symints)] = symints - if backward_state_indices: - assert ctx._compiled_autograd_backward_state.proxy is not None - all_args.append(ctx._compiled_autograd_backward_state) - context = torch._C._DisableAutocast if disable_amp else nullcontext - with context(): - return normalize_as_list(bw_module(*all_args)) + assert ( + not ctx._is_compiled_autograd_tracing() + ), "compiled autograd reimplements this function at proxy_call_aot_backward" assert ( not backward_state_indices diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index 2ff94f344c63f..d352f43da371b 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -8,7 +8,7 @@ import collections import typing from collections.abc import Iterable -from typing import Any, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch import torch.utils._pytree as pytree @@ -326,6 +326,7 @@ def wrap_tensor_subclasses( num_fw_outs_saved_for_bw: Optional[int] = None, included_subclass_symints: bool = False, is_runtime: bool = False, + make_subclass_override: Optional[Callable] = None, ) -> tuple[Any, ...]: wrapped_args = [] num_args_tallied = 0 @@ -336,9 +337,15 @@ def wrap_tensor_subclasses( else: assert isinstance(subclass_meta, SubclassCreationMeta) assert subclass_meta.included_subclass_symints == included_subclass_symints - wrapped_args.append( - subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) - ) + + if make_subclass_override: + wrapped_args.append( + make_subclass_override(subclass_meta, is_runtime, unwrapped_args) + ) + else: + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) num_args_tallied += subclass_meta.arg_count # Note: [Partitioner handling for Subclasses, Part 2] diff --git a/torch/autograd/function.py b/torch/autograd/function.py index ca7d70192ca57..1bcb25751454b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -334,6 +334,9 @@ def __init__(cls, name, bases, attrs): backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined] "_compiled_autograd_should_lift", True ) + backward_fn._bw_module = None # type: ignore[attr-defined] + if getattr(cls, "_lazy_backward_info", None): + backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined] cls._backward_cls = backward_fn super().__init__(name, bases, attrs) diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 48e79f9ec5044..1df7efb35d10d 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -503,6 +503,16 @@ void check_variable_result( } } +AutogradContext::AutogradContext(PackedArgs& packed_args) { + saved_data = packed_args.unpack_saved_data(); + saved_variables_override_ = packed_args.unpack(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + materialize_grads_ = packed_args.unpack(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + has_freed_buffers_ = packed_args.unpack(); + needs_input_grad_override_ = packed_args.unpack>(); +} + void AutogradContext::save_for_backward(variable_list to_save) { to_save_ = std::move(to_save); } @@ -527,6 +537,9 @@ void AutogradContext::save_variables() { variable_list AutogradContext::get_saved_variables() const { TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE); + if (saved_variables_override_.has_value()) { + return *saved_variables_override_; + } variable_list saved; saved.reserve(saved_variables_.size()); auto ptr = grad_fn_.lock(); @@ -538,6 +551,9 @@ variable_list AutogradContext::get_saved_variables() const { } bool AutogradContext::needs_input_grad(size_t output_edge_index) const { + if (needs_input_grad_override_.has_value()) { + return needs_input_grad_override_.value().at(output_edge_index); + } auto ptr = grad_fn_.lock(); TORCH_INTERNAL_ASSERT(ptr); return ptr->task_should_compute_output(output_edge_index); @@ -545,6 +561,15 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const { bool AutogradContext::needs_input_grad( std::initializer_list idxs) const { + if (needs_input_grad_override_.has_value()) { + return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { + bool result = false; + for (const auto i : c10::irange(range.first, range.second)) { + result |= needs_input_grad_override_.value().at(i); + } + return result; + }); + } auto ptr = grad_fn_.lock(); TORCH_INTERNAL_ASSERT(ptr); return ptr->task_should_compute_output(idxs); diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index a142f082e831e..34f508ac72abe 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -126,6 +126,8 @@ struct TORCH_API AutogradContext { AutogradContext& operator=(AutogradContext&& other) = delete; ~AutogradContext() = default; + AutogradContext(PackedArgs& packed_args); + /// Can be used to save non-variable data for `backward`. ska::flat_hash_map saved_data; @@ -169,12 +171,103 @@ struct TORCH_API AutogradContext { std::weak_ptr grad_fn_; bool has_freed_buffers_{false}; + // Compiled autograd overrides saved_variables() and needs_input_grad(). + // We store the values we want to return here. + std::optional saved_variables_override_; + std::optional> needs_input_grad_override_; + void save_variables(); template friend struct CppNode; + template + friend variable_list CppNode_apply_functional( + variable_list&& inputs, + AutogradContext& ctx_, + const std::vector& is_variable_input_, + const std::vector& output_info_, + const std::string& name); }; +template +inline variable_list CppNode_apply_functional( + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + variable_list&& inputs, + AutogradContext& ctx_, + const std::vector& is_variable_input_, + const std::vector& output_info_, + const std::string& name) { + at::OptionalDeviceGuard _device_guard; + + auto num_inputs = inputs.size(); + variable_list backward_inputs; + backward_inputs.reserve(num_inputs); + for (const auto i : c10::irange(num_inputs)) { + if (inputs[i].defined() || !ctx_.materialize_grads_) { + backward_inputs.emplace_back(std::move(inputs[i])); + } else { + backward_inputs.emplace_back(output_info_[i].zeros(_device_guard)); + } + } + + auto outputs = T::backward(&ctx_, backward_inputs); + + const auto num_forward_inputs = + static_cast(is_variable_input_.size()); + auto num_outputs = static_cast(outputs.size()); + // Returning too many results is ok, but only as long as they're all + // undefined. Truncate the result vector in that case. + if (num_outputs > num_forward_inputs) { + bool all_undef = true; + for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { + all_undef &= (!outputs[i].defined()); + } + if (all_undef) { + outputs.resize(num_forward_inputs); + num_outputs = num_forward_inputs; + } + } + + if (num_outputs != num_forward_inputs) { + std::string msg("function "); + msg += name + " returned an incorrect number of gradients (expected "; + msg += std::to_string(num_forward_inputs) + ", got "; + msg += std::to_string(num_outputs) + ")"; + throw std::runtime_error(msg); + } + + variable_list results; + results.reserve(num_outputs); + for (const auto i : c10::irange(num_outputs)) { + if (!is_variable_input_[i]) { + if (outputs[i].defined()) { + std::string msg("function "); + msg += name + + " returned a gradient different that is defined at position "; + msg += std::to_string(i + 1) + + ", std the corresponding forward input was not a Variable"; + throw std::runtime_error(msg); + } + continue; + } + results.emplace_back(outputs[i]); + } + return results; +} + +template +inline variable_list CppNode_apply_functional_ivalue( + const variable_list& inputs, + const std::vector& args) { + auto packed_args = PackedArgs(args); + auto ctx = AutogradContext(packed_args); + auto output_info = packed_args.unpack>(); + auto is_variable_input = packed_args.unpack>(); + auto name = packed_args.unpack(); + return CppNode_apply_functional( + variable_list(inputs), ctx, is_variable_input, output_info, name); +} + // CppNode is the Node in the autograd graph that represents the user defined // backward function for Function. Calls to CppNode::apply are forward to // T::backward(). @@ -232,7 +325,64 @@ struct CppNode : public Node { saved.before(ctx_.has_freed_buffers_); saved.before(input_info_); saved.before(output_info_); - auto results = apply(variable_list(inputs)); + + PackedArgs packed_args; + packed_args.pack_saved_data(ctx_.saved_data); + variable_list saved_variables = ctx_.get_saved_variables(); + packed_args.pack(saved_variables); + packed_args.pack(ctx_.materialize_grads_); + packed_args.pack(ctx_.has_freed_buffers_); + + std::vector needs_input_grad; + { + auto ptr = ctx_.grad_fn_.lock(); + TORCH_INTERNAL_ASSERT(ptr); + for (const auto i : c10::irange(ptr->next_edges().size())) { + needs_input_grad.push_back(ptr->task_should_compute_output(i)); + } + } + packed_args.pack(needs_input_grad); + + packed_args.pack(output_info_); + packed_args.pack(is_variable_input_); + packed_args.pack(name()); + auto args = std::move(packed_args).vec(); + + auto output_metadata = torch::dynamo::autograd:: + IValuePacker>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + + // Each time apply_with_saved is called, we bind a new function to Python. + // This is because the schema might be different on compiled autograd cache + // misses. An alternative is to pass the schema to Python so that it can be + // an input to a function, but the schema can't be put into an FX graph + // right now. + std::vector schema; + schema.reserve(args.size()); + for (const auto& ivalue : args) { + if (ivalue.isTensor()) { + schema.emplace_back(at::TensorType::get()); + } else { + schema.emplace_back(ivalue.type()); + } + } + auto fn_name = pyinterface->bind_function( + saved.get_py_compiler(), + std::string(typeid(T).name()), + CppNode_apply_functional_ivalue, + schema, + /*is_custom_function*/ true); + + auto results = pyinterface->call_function( + saved.get_py_compiler(), + "apply_functional", + fn_name, + inputs, + args, + output_metadata); + saved.after(ctx_.saved_data); TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty()); TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty()); @@ -403,68 +553,13 @@ auto Function::apply(Args&&... args) template // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) variable_list CppNode::apply(variable_list&& inputs) { - at::OptionalDeviceGuard _device_guard; - - auto num_inputs = inputs.size(); - variable_list backward_inputs; - backward_inputs.reserve(num_inputs); - for (const auto i : c10::irange(num_inputs)) { - if (inputs[i].defined() || !ctx_.materialize_grads_) { - backward_inputs.emplace_back(std::move(inputs[i])); - } else { - backward_inputs.emplace_back(output_info_[i].zeros(_device_guard)); - } - } - // Acquire lock to here protect thread safety on custom C++ Autograd Node // This is needed for the custom Autograd Node since we don't know if the // user defined Node will write to the shared data during backward. // see Note [Thread Safety on Autograd Node] std::lock_guard lock(mutex_); - - auto outputs = T::backward(&ctx_, backward_inputs); - - const auto num_forward_inputs = - static_cast(is_variable_input_.size()); - auto num_outputs = static_cast(outputs.size()); - // Returning too many results is ok, but only as long as they're all - // undefined. Truncate the result vector in that case. - if (num_outputs > num_forward_inputs) { - bool all_undef = true; - for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { - all_undef &= (!outputs[i].defined()); - } - if (all_undef) { - outputs.resize(num_forward_inputs); - num_outputs = num_forward_inputs; - } - } - - if (num_outputs != num_forward_inputs) { - std::string msg("function "); - msg += name() + " returned an incorrect number of gradients (expected "; - msg += std::to_string(num_forward_inputs) + ", got "; - msg += std::to_string(num_outputs) + ")"; - throw std::runtime_error(msg); - } - - variable_list results; - results.reserve(num_outputs); - for (const auto i : c10::irange(num_outputs)) { - if (!is_variable_input_[i]) { - if (outputs[i].defined()) { - std::string msg("function "); - msg += name() + - " returned a gradient different that is defined at position "; - msg += std::to_string(i + 1) + - ", std the corresponding forward input was not a Variable"; - throw std::runtime_error(msg); - } - continue; - } - results.emplace_back(outputs[i]); - } - return results; + return CppNode_apply_functional( + std::move(inputs), ctx_, is_variable_input_, output_info_, name()); } template diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 2aad75e0e74b9..c8d465211a6f8 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -897,6 +897,19 @@ bool has_input_metadata(const Edge& thing) { return thing.is_valid(); } +std::vector> collect_input_metadata( + const edge_list& edges) { + std::vector> input_metadata; + for (const auto& edge : edges) { + if (!edge.is_valid()) { + input_metadata.emplace_back(std::nullopt); + continue; + } + input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr)); + } + return input_metadata; +} + // Given an vector or vector>, validate the // outputs. This involves using the InputMetadata to check the outputs and also // potentially calling .sum_to on the outputs. diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 4243f1b1d6ee5..5bf00bac5378e 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -47,6 +47,8 @@ TORCH_API void validate_outputs( const std::vector>& input_metadata, variable_list& grads, const std::function& format_error); +TORCH_API std::vector> collect_input_metadata( + const edge_list& edges); struct NodeTask { std::weak_ptr base_; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index ba2f6edbc6c0c..abd11303eafe8 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -34,8 +34,12 @@ using tensor_list = std::vector; using variable_list = std::vector; using edge_list = std::vector; using saved_variable_list = std::vector; +using ivalue_list = std::vector; +using functional_apply_t = std::function< + variable_list(const variable_list&, const std::vector&)>; using IndexRange = std::pair; using torch::dynamo::autograd::CompiledNodeArgs; +using torch::dynamo::autograd::PackedArgs; using torch::dynamo::autograd::SwapSavedVariables; // Custom deleter to prevent stack overflows. @@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this { std::string("apply_with_saved not implemented: ") + name()); } + // If this node is the AOTBackward node produced by torch.compile. + // Compiled Autograd special-cases on this information. + virtual bool is_aot_backward() const { + return false; + } + protected: /// Performs the `Node`'s actual operation. virtual variable_list apply(variable_list&& inputs) = 0; diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 6342bf280a5ce..4e8bba79a169b 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -8,6 +8,7 @@ namespace torch::dynamo::autograd { class CompiledNodeArgs; class SwapSavedVariables; +struct PackedArgs; } // namespace torch::dynamo::autograd // A hook that's called on gradients diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 1c13b60ca7831..a06ebefb85c3c 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -16,6 +16,8 @@ namespace torch::autograd { +using torch::dynamo::autograd::IValuePacker; + static variable_list CopyBackwards_apply_functional( variable_list&& grads, std::array needs_input_grad, @@ -41,6 +43,16 @@ static variable_list CopyBackwards_apply_functional( return grad_inputs; } +static variable_list CopyBackwards_apply_functional_ivalue( + const variable_list& grads, + const ivalue_list& args) { + PackedArgs r(args); + auto needs_input_grad = r.unpack>(); + auto src_options = r.unpack(); + return CopyBackwards_apply_functional( + variable_list(grads), needs_input_grad, src_options); +} + auto CopyBackwards::apply(variable_list&& grads) -> variable_list { return CopyBackwards_apply_functional( std::move(grads), @@ -51,11 +63,43 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list { void CopyBackwards::compiled_args(CompiledNodeArgs& args) { args.collect(src_options); } + variable_list CopyBackwards::apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) { saved.before(src_options); - auto result = apply(variable_list(inputs)); + + static c10::once_flag flag; + c10::call_once(flag, [&]() { + std::vector schema = { + IValuePacker>::packed_type(), + IValuePacker::packed_type()}; + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->bind_function( + saved.get_py_compiler(), + name(), + CopyBackwards_apply_functional_ivalue, + schema); + }); + + PackedArgs packed_args; + packed_args.pack>( + {task_should_compute_output(0), task_should_compute_output(1)}); + packed_args.pack(src_options); + + auto output_metadata = torch::dynamo::autograd:: + IValuePacker>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + auto result = interface->call_function( + saved.get_py_compiler(), + "apply_functional", + name(), + inputs, + std::move(packed_args).vec(), + output_metadata); + saved.after(src_options); return result; } @@ -80,38 +124,7 @@ CopySlices::CopySlices( } } -// common code between apply/apply_with_saved -template -inline variable_list CopySlices::apply_impl( - variable_list&& inputs, - const T& call_fn) { - check_input_variables("CopySlices", inputs, 1, -1, true); - auto& grad = std::move(inputs)[0]; - if (!grad.defined()) { - return variable_list(num_outputs()); - } - - // Acquire lock to here protect thread safety on fn - // see Note [Thread Safety on Autograd Node] - std::lock_guard lock(mutex_); - - if (!fn) { - throw std::runtime_error(ERR_BACKWARD_TWICE); - } - - auto result = - grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides()); - result.copy_(grad); - - at::Tensor grad_slice; - if (view_fn) { - grad_slice = (*view_fn)(result); - } else { - auto offset = view.sym_storage_offset() - base.sym_storage_offset(); - grad_slice = - result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset); - } - +void CopySlices::update_exec_info() { // See Note [View + Inplace update for view tensor] For more details on this // block Since the gradient edge for the 0th input is different between `this` // and `fn`, make sure that the one from `fn` has the same metadata in the @@ -154,6 +167,41 @@ inline variable_list CopySlices::apply_impl( TORCH_INTERNAL_ASSERT( fn->next_edge(i).function.get() == this->next_edge(i).function.get()); } +} + +// common code between apply/apply_with_saved +template +inline variable_list CopySlices::apply_impl( + variable_list&& inputs, + const T& call_fn) { + check_input_variables("CopySlices", inputs, 1, -1, true); + auto& grad = std::move(inputs)[0]; + if (!grad.defined()) { + return variable_list(num_outputs()); + } + + // Acquire lock to here protect thread safety on fn + // see Note [Thread Safety on Autograd Node] + std::lock_guard lock(mutex_); + + if (!fn) { + throw std::runtime_error(ERR_BACKWARD_TWICE); + } + + auto result = + grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides()); + result.copy_(grad); + + at::Tensor grad_slice; + if (view_fn) { + grad_slice = (*view_fn)(result); + } else { + auto offset = view.sym_storage_offset() - base.sym_storage_offset(); + grad_slice = + result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset); + } + + update_exec_info(); // TODO: We clone grad_slice because we modify it below and "fn" might save // it for the backward of res. We might be able to avoid the clone() if @@ -201,17 +249,38 @@ variable_list CopySlices::apply_with_saved( SwapSavedVariables& saved) { saved.before(base); saved.before(view); - int call_count = 0; - variable_list result = apply_impl( - variable_list(grads), - [this, &saved, &call_count](const variable_list& inputs2) { - call_count++; - return fn->apply_with_saved(inputs2, saved); - }); - TORCH_INTERNAL_ASSERT(call_count == 1); + + auto results = variable_list(num_outputs()); + if (grads[0].defined()) { + if (!fn) { + throw std::runtime_error(ERR_BACKWARD_TWICE); + } + update_exec_info(); + + std::vector needs_input_grad; + for (const auto i : c10::irange(num_outputs())) { + needs_input_grad.emplace_back(task_should_compute_output(i)); + } + // Not yet supported, also doesn't happen in typical eager mode execution + // (this only happens by default with torch-xla). + TORCH_INTERNAL_ASSERT(!view_fn); + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + variable_list stuff = interface->call_copy_slices_prologue( + saved.get_py_compiler(), grads, base, view); + TORCH_INTERNAL_ASSERT(stuff.size() == 3); + // These variables are named the same as in CopySlices::apply_impl. + // Follow along there. + auto result = stuff[0]; + auto grad_slice = stuff[1]; + auto grad_slice_clone = stuff[2]; + auto res = fn->apply_with_saved({grad_slice_clone}, saved); + results = interface->call_copy_slices_epilogue( + saved.get_py_compiler(), needs_input_grad, result, res, grad_slice); + } + saved.after(base); saved.after(view); - return result; + return results; } auto CopySlices::apply(variable_list&& inputs1) -> variable_list { diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index e812860e3adc1..4b0c2190ed542 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -172,6 +172,7 @@ struct TORCH_API CopySlices : public Node { variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; + void update_exec_info(); at::TensorGeometry base; // view and view_fn are redundant and view_fn will be used if available. diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 1c1d4f504af23..9a8375480374d 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -131,6 +131,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { if (!ParameterClass) return nullptr; + py::class_(m, "TensorGeometry") + .def("sizes", &at::TensorGeometry::sizes) + .def("strides", &at::TensorGeometry::strides) + .def("storage_offset", &at::TensorGeometry::storage_offset); + py::class_(m, "ProfilerEvent") .def("kind", &LegacyEvent::kindStr) .def("name", [](const LegacyEvent& e) { return e.name(); }) diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 19151cbaafe63..dd0b7a927bfee 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -237,16 +238,23 @@ auto PyNode::defer_to_dynamo( TORCH_INTERNAL_ASSERT( _backward_idx.has_value(), "indices should already be set by compiled_args, called before apply_with_saved"); - TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value()); + PyObject* backward_state_idx = Py_None; + if (_backward_state_idx.has_value()) { + backward_state_idx = THPUtils_packInt64(_backward_state_idx.value()); + // this might be simplifiable now that we no longer inline + Py_CLEAR(py_fn->compiled_autograd_backward_state); + } THPObjectPtr r(PyObject_CallMethod( // NOLINTNEXTLINE(bugprone-unchecked-optional-access) compiler.value(), "proxy_call_backward", - "OOOi", + "OOOiOO", pyInputs.get(), fwdInputMetadatas.get(), saved_tensors.get(), - *_backward_idx)); + *_backward_idx, + obj, + backward_state_idx)); if (!r) throw_python_error(); @@ -288,6 +296,11 @@ auto PyNode::name() const -> std::string { return name; } +bool PyNode::is_aot_backward() const { + py::handle handle(obj); + return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); +} + auto PyNode::compiled_autograd_should_lift() const -> bool { pybind11::gil_scoped_acquire gil; static PyObject* attr_name = @@ -340,11 +353,8 @@ void PyNode::compiled_args(CompiledNodeArgs& args) { args.collect(f->output_info); args.collect(f->input_info); - if (compiled_autograd_should_lift()) { - Py_INCREF(obj); - _backward_idx = - args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); - } + Py_INCREF(obj); + _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); PyObject* bw_state = f->compiled_autograd_backward_state; if (args.cond(bw_state != nullptr)) { @@ -366,28 +376,8 @@ variable_list PyNode::apply_with_saved( saved.before(f->output_info); saved.before(f->input_info); f->compiled_autograd_tracing = true; - variable_list result; - if (!compiled_autograd_should_lift()) { - if (_backward_state_idx.has_value()) { - PyObject* r = PyObject_CallMethod( - saved.get_py_compiler(), - "bind_backward_state", - "i", - *_backward_state_idx); - if (r == nullptr) { - throw python_error(); - } - THPObjectPtr prior(f->compiled_autograd_backward_state); - f->compiled_autograd_backward_state = r; - result = apply(variable_list(inputs)); - Py_CLEAR(f->compiled_autograd_backward_state); - f->compiled_autograd_backward_state = prior.release(); - } else { - result = apply(variable_list(inputs)); - } - } else { - result = defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); - } + variable_list result = + defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); f->compiled_autograd_tracing = false; saved.after(f->compiled_autograd_symints); saved.after(f->saved_variables); @@ -1092,6 +1082,7 @@ PyObject* process_outputs( THPFunction* grad_fn, const UnpackedInput& unpacked, PyObject* inputs, + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) THPObjectPtr&& raw_output, bool is_executable, torch::jit::Node* node, diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 46faff8e46865..2f28c765ab069 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -43,6 +43,8 @@ struct PyNode : public Node { std::string name() const override; bool is_traceable() override; + bool is_aot_backward() const override; + void compiled_args(CompiledNodeArgs& args) override; variable_list apply_with_saved( const variable_list& inputs, diff --git a/torch/csrc/dynamo/compiled_autograd.cpp b/torch/csrc/dynamo/compiled_autograd.cpp new file mode 100644 index 0000000000000..7e2aad576189f --- /dev/null +++ b/torch/csrc/dynamo/compiled_autograd.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace torch::dynamo::autograd { + +std::unique_ptr kPyCompilerInterface; + +const std::unique_ptr& getPyCompilerInterface() { + TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr); + return kPyCompilerInterface; +} + +void setPyCompilerInterface(std::unique_ptr&& impl) { + TORCH_INTERNAL_ASSERT(impl != nullptr); + kPyCompilerInterface = std::move(impl); +} + +void resetPyCompilerInterface() { + kPyCompilerInterface.reset(); +} + +std::vector> get_input_metadata( + const edge_list& edges) { + return torch::autograd::collect_input_metadata(edges); +} + +} // namespace torch::dynamo::autograd diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 383cff14b8e05..b70576ef03129 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -900,6 +900,542 @@ class SwapSavedVariables { StashedVars stashed_ivalues; }; +// NOTE: [Compiled Autograd and backward functions] +// Built-in autograd nodes have functional apply variants +// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph +// capture wants to take a variant of this function and proxy it into the graph. +// Every autograd node defines an apply_with_saved function, that when invoked, +// proxys a call to a function into the Compiled Autograd graph. +// +// Some requirements that we have are: +// - The proxy'ed function must have inputs that are FX-graphable types. +// - Windows has a DLL symbol limit of 65536. +// - Node::apply_with_saved is in libtorch_cpu which does not have direct access +// to Python +// +// There were multiple ways to skin the cat, but what we end up doing is: +// - for e.g. MulBackward0_apply_functional, we create a new C++ function +// MulBackward0_apply_functional_ivalue that accepts vector. +// - We define how to pack and unpack arbitrary C++ types into IValues. +// - apply_with_saved passes MulBackward0_apply_functional_ivalue and +// the IValue arguments to Python via an indirection. +// In Python, these get proxy'ed into a graph. + +// Helper struct for packing/unpacking an arbitrary C++ type into a single +// IValue. There are various full and partial specializations for IValuePacker +// to handle packing specific types (like TensorOptions) into an IValue. +template +struct IValuePacker { + // Defines how to pack T into an IValue. + static at::IValue pack(const T& t) { + return t; + } + // Defines how to unpack an IValue into T. + static T unpack(const at::IValue& t) { + return t.to(); + } + // Returns the TypePtr for the IValue (this is like the "type" of the IValue). + // We use this when passing the packed IValue from Python to C++. + // In Python, the IValue is just a PyObject* with the native type. + // For example, it may be a Python int, a Python List[int], etc. + // When passing this PyObject* into C++, we need to know how to parse it + // into a C++ type that then gets put into an IValue. + // That's what the TypePtr is for: it contains the information to do the + // parsing. See torch::jit::toIValue for more information. + static at::TypePtr packed_type() { + if constexpr (::std::is_same_v) { + return at::TensorType::get(); + } else if constexpr (::std::is_same_v) { + return at::IntType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymIntType::get(); + } else if constexpr (::std::is_same_v) { + return at::BoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::FloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymFloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymBoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::LayoutType::get(); + } else if constexpr (::std::is_same_v) { + return at::StringType::get(); + } else if constexpr (::std::is_same_v) { + return at::DeviceObjType::get(); + } else if constexpr (::std::is_same_v) { + return at::NumberType::get(); + } else if constexpr (::std::is_same_v) { + return at::MemoryFormatType::get(); + } else if constexpr (::std::is_same_v) { + return at::ScalarTypeType::get(); + } else { + // If you got here, you have probably added a member of a new type + // to a built-in C++ autograd node. + // Unfortunately, we don't know how to handle this type yet. + // To get this new type to work with Compiled Autograd, please + // either change it to be an IValue-constructible type, or + // define how to pack and unpack an object of this time into an IValue + // by creating a specialization of IValuePacker for this type. + // See NOTE: [Compiled Autograd and backward functions] for context. + TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type"); + return at::NoneType::get(); + } + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const size_t& t) { + // We generally use size_t as the size of a list of Tensors or number of + // dimensions. The number of dimensions generally do not exceed 64 + // (TensorIterator has that limitation), and lists of Tensors generally do + // not exceed the int64_t max (you'd probably run out of RAM or run into + // significant Tensor overhead). If you run into this limitation the fix is + // to figure out how to pack size_t into int64_t. Note that size_t has some + // weird behavior on Mac OS. + uint64_t maximum_value = std::numeric_limits::max(); + TORCH_INTERNAL_ASSERT( + static_cast(t) <= maximum_value, + "size_t too large to pack into IValue"); + return static_cast(t); // pack as int64_t + } + static size_t unpack(const at::IValue& t) { + return static_cast(t.toInt()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +template <> +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + return t; + } + static std::vector unpack(const at::IValue& t) { + // We need this because there's no t.to>() override? + return t.toSymIntVector(); + } + static at::TypePtr packed_type() { + return at::ListType::create(at::SymIntType::get()); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const VariableInfo& t) { + auto tuple = std::make_tuple( + t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty); + return tuple; + } + static VariableInfo unpack(const at::IValue& t) { + auto tuple = t.toTuple(); + const auto& tuple_elements = tuple->elements(); + const auto elements = tuple_elements.asArrayRef(); + TORCH_INTERNAL_ASSERT(elements.size() == 6); + VariableInfo v; + v.layout = elements[0].toLayout(); + v.device = elements[1].toDevice(); + v.scalar_type = elements[2].toScalarType(); + v.size = elements[3].toSymIntVector(); + v.requires_grad = elements[4].toBool(); + v.is_empty = elements[5].toBool(); + return v; + } + static at::TypePtr packed_type() { + return at::TupleType::create({ + at::LayoutType::get(), + at::DeviceObjType::get(), + at::ScalarTypeType::get(), + at::ListType::create(at::SymIntType::get()), + at::BoolType::get(), + at::BoolType::get(), + }); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const caffe2::TypeMeta& t) { + return at::typeMetaToScalarType(t); // pack as at::ScalarType + } + static caffe2::TypeMeta unpack(const at::IValue& t) { + return caffe2::TypeMeta::fromScalarType(t.to()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +inline std::optional optTypeMetaToScalarType( + const std::optional& t) { + if (t.has_value()) { + return at::typeMetaToScalarType(t.value()); + } else { + return std::nullopt; + } +} + +using packed_tensoroptions_t = std::tuple< + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional>; + +inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) { + auto tuple = std::make_tuple( + t.requires_grad_opt(), + t.memory_format_opt(), + t.device_opt(), + optTypeMetaToScalarType(t.dtype_opt()), + t.layout_opt(), + t.pinned_memory_opt()); + return tuple; +} +inline at::TensorOptions unpack_TensorOptions( + const packed_tensoroptions_t& tuple) { + at::TensorOptions result; + auto maybe_requires_grad = std::get<0>(tuple); + if (maybe_requires_grad.has_value()) { + result = result.requires_grad(maybe_requires_grad.value()); + } + auto maybe_memory_format = std::get<1>(tuple); + if (maybe_memory_format.has_value()) { + result = result.memory_format(maybe_memory_format.value()); + } + auto maybe_device = std::get<2>(tuple); + if (maybe_device.has_value()) { + result = result.device(maybe_device.value()); + } + auto maybe_dtype = std::get<3>(tuple); + if (maybe_dtype.has_value()) { + result = + result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value())); + } + auto maybe_layout = std::get<4>(tuple); + if (maybe_layout.has_value()) { + result = result.layout(maybe_layout.value()); + } + auto maybe_pinned_memory = std::get<5>(tuple); + if (maybe_pinned_memory.has_value()) { + result = result.pinned_memory(maybe_pinned_memory.value()); + } + return result; +} + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorOptions& t) { + return pack_TensorOptions(t); + } + static at::TensorOptions unpack(const at::IValue& t) { + auto tuple = t.to(); + return unpack_TensorOptions(tuple); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {at::OptionalType::create(at::BoolType::get()), + at::OptionalType::create(at::MemoryFormatType::get()), + at::OptionalType::create(at::DeviceObjType::get()), + at::OptionalType::create(at::ScalarTypeType::get()), + at::OptionalType::create(at::LayoutType::get()), + at::OptionalType::create(at::BoolType::get())}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const TypeAndSize& t) { + auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options)); + return tuple; + } + static TypeAndSize unpack(const at::IValue& t) { + auto tuple = + t.to, packed_tensoroptions_t>>(); + TypeAndSize result; + result.sym_sizes = std::get<0>(tuple); + result.options = unpack_TensorOptions(std::get<1>(tuple)); + return result; + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker::packed_type()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::optional& t) { + if (t.has_value()) { + return IValuePacker::pack(t.value()); + } else { + return std::nullopt; + } + } + static std::optional unpack(const at::IValue& t) { + if (t.isNone()) { + return std::nullopt; + } else { + return IValuePacker::unpack(t); + } + } + static at::TypePtr packed_type() { + return at::OptionalType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + if constexpr (::std::is_constructible_v) { + return t; + } + if (t.empty()) { + auto lst = c10::impl::GenericList(at::AnyType::get()); + return lst; + } + auto type_ptr = IValuePacker::pack(t[0]).type(); + auto lst = c10::impl::GenericList(type_ptr); + for (const auto& elt : t) { + lst.emplace_back(IValuePacker::pack(elt)); + } + return lst; + } + static std::vector unpack(const at::IValue& t) { + if constexpr (::std::is_constructible_v) { + return t.to<::std::vector>(); + } + std::vector result; + auto lst = t.toList(); + for (const at::IValue& elt : lst) { + result.emplace_back(IValuePacker::unpack(elt)); + } + return result; + } + static at::TypePtr packed_type() { + return at::ListType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const c10::List& t) { + return IValuePacker>::pack(t.vec()); + } + static c10::List unpack(const at::IValue& t) { + return c10::List(IValuePacker>::unpack(t)); + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::array& t) { + std::vector result(t.begin(), t.end()); + return IValuePacker>::pack(result); + } + static std::array unpack(const at::IValue& t) { + std::array result; + auto packed = IValuePacker>::unpack(t); + for (size_t i = 0; i < packed.size(); i++) { + result[i] = packed[i]; + } + return result; + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorGeometry& t) { + auto tuple = std::make_tuple( + t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset()); + return tuple; + } + static at::TensorGeometry unpack(const at::IValue& t) { + auto tuple = t.to, + std::vector, + at::SymInt>>(); + return at::TensorGeometry( + std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple)); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker>::packed_type(), + at::SymIntType::get()}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const InputMetadata& t) { + TORCH_INTERNAL_ASSERT(!t.is_nested_tensor()); + auto tuple = std::make_tuple( + pack_TensorOptions(t.options()), + t.shape_as_dim_vector().vec(), + t.is_tensor_subclass()); + return tuple; + } + static InputMetadata unpack(const at::IValue& t) { + auto tuple = t.to< + std::tuple, bool>>(); + + return InputMetadata( + unpack_TensorOptions(std::get<0>(tuple)), + SymIntSmallVec(std::get<1>(tuple)), + std::get<2>(tuple), + false); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker::packed_type(), + IValuePacker>::packed_type(), + at::BoolType::get()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const at::OptionalArray& t) { + return IValuePacker>>::pack(t.list); + } + static at::OptionalArray unpack(const at::IValue& t) { + auto result = IValuePacker>>::unpack(t); + if (result.has_value()) { + return {result.value()}; + } else { + return {}; + } + } + static at::TypePtr packed_type() { + return IValuePacker>>::packed_type(); + } +}; + +// This is a helper struct for packing and unpacking multiple arguments into +// an ivalue_list. It leverages IValuePacker. +struct PackedArgs { + PackedArgs() = default; + + explicit PackedArgs(std::vector stack_) + : stack(std::move(stack_)) {} + + const std::vector& vec() const { + return stack; + } + + template + void pack(const T& t) { + stack.emplace_back(IValuePacker::pack(t)); + } + template + T unpack() { + return IValuePacker::unpack(std::move(stack[idx++])); + } + + void pack_saved_data(const ska::flat_hash_map& dct) { + std::vector keys; + std::vector values; + for (const auto& [key, value] : dct) { + keys.emplace_back(key); + values.emplace_back(value); + } + pack(keys); + for (const auto& value : values) { + pack(value); + } + } + + ska::flat_hash_map unpack_saved_data() { + ska::flat_hash_map dct; + auto keys = unpack>(); + for (const auto& key : keys) { + dct.insert({key, std::move(stack[idx++])}); + } + return dct; + } + + private: + std::vector stack; + int64_t idx = 0; +}; + +// This is a layer of indirection for calling methods on the Python +// AutogradCompilerInstance (referred to as the "py_compiler") from +// libtorch_cpu (where Python is not available). +// A PyCompilerInterfaceImpl in libtorch_python subclasses it and +// overrides the methods to do the actual calls back to Python. +struct TORCH_API PyCompilerInterface { + PyCompilerInterface() = default; + PyCompilerInterface(const PyCompilerInterface&) = delete; + PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; + PyCompilerInterface(PyCompilerInterface&&) = delete; + PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; + virtual ~PyCompilerInterface() = default; + + // Invokes py_compiler.bind_function(fn_name, fn) + virtual std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + functional_apply_t fn, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector packed_args_schema, + bool is_custom_function = false) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + // Invokes py_compiler.method_name(fn_name, inputs, packed_args, + // output_metadata) + virtual variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + virtual variable_list call_copy_slices_prologue( + PyObject* py_compiler, + const variable_list& inputs, + const at::TensorGeometry& base, + const at::TensorGeometry& view) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + virtual variable_list call_copy_slices_epilogue( + PyObject* py_compiler, + const std::vector& needs_input_grad, + const at::Tensor& result, + const variable_list& res, + const at::Tensor& grad_slice) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } +}; + +TORCH_API const std::unique_ptr& getPyCompilerInterface(); +TORCH_API void setPyCompilerInterface( + std::unique_ptr&& impl); +TORCH_API void resetPyCompilerInterface(); + +// including torch/csrc/autograd/engine.h breaks BC by somehow introducing +// symbol resolution issues. Instead requiring downstream users to include +// engine.h to access collect_input_metadata, we provide it here (with a +// different name to avoid ambigous symbols...) +TORCH_API std::vector> get_input_metadata( + const edge_list& edges); + } // namespace torch::dynamo::autograd template <> diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index b58f62c3d2fcf..ccdf65e887d9d 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -52,6 +52,156 @@ at trace time. namespace torch::dynamo::autograd { using c10::SymInt; +// List[Optional[Tensor]] in Python can't be directly parsed into a +// List[Tensor], so we need to do this conversion manually. +static std::vector toTensorList( + const std::vector>& inputs) { + std::vector result; + result.reserve(inputs.size()); + for (const auto& inp : inputs) { + if (inp.has_value()) { + result.emplace_back(*inp); + } else { + result.emplace_back(); + } + } + return result; +} + +// Binds a function (that represents some backward computation) to Python. +// All of these functions have a common signature, which is +// (in C++) (vector, vector) -> vector +// (in Python) (List[Optional[Tensor]], *packed_args: IValue) -> +// List[Optional[Tensor]] +// +// The vector are the list of gradient Tensors, each of which may be +// undefined (in C++) which corresponds to None (in Python). +static std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema, + bool is_custom_function) { + // This is the function that can be called from Python. + auto py_func = py::cpp_function( + [packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)]( + std::vector>& inputs, + const py::args& py_args) -> py::object { + // py_args is a tuple of PyObject*. + // We need to reconstruct a vector to invoke `fn`. + // To do so, we use the packed_args_schema to convert each PyObject* + // to its corresponding C++ type that can be stored into IValue. + TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size()); + std::vector args; + args.reserve(py_args.size()); + auto tuple_args = jit::tuple_slice(py_args); + for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) { + if (packed_args_schema[idx]->isSubtypeOf( + *at::ListType::ofTensors())) { + // List[Tensor] might have Nones, not handled in jit::toIValue + auto tmp = py::cast>>( + tuple_args[idx]); + args.emplace_back(toTensorList(tmp)); + } else { + args.emplace_back(jit::toIValue( + tuple_args[idx], packed_args_schema[idx], std::nullopt)); + } + } + // None in Python corresponds to undefined Tensor in C++ + auto inputs_ = toTensorList(inputs); + auto outputs = fn(inputs_, args); + return jit::toPyObject(at::IValue(outputs)); + }); + py::handle handle(py_compiler); + auto result = + handle.attr("bind_function")(fn_name, py_func, is_custom_function); + return result.cast(); +} + +// Invokes py_compiler.method_name(fn_name, inputs, packed_args, +// output_metadata) +static variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + // convert ivalue_list -> PyObject* + PyObject* py_packed_args = + PyTuple_New(static_cast(packed_args.size())); + for (const auto i : c10::irange(packed_args.size())) { + py::object obj = jit::toPyObject(packed_args[i]); + Py_INCREF(obj.ptr()); + PyTuple_SET_ITEM(py_packed_args, i, obj.ptr()); + } + + // call the corresponding method on the py_compiler + py::handle handle(py_compiler); + py::object stuff = handle.attr(method_name)( + fn_name, + inputs, + py::handle(py_packed_args), + jit::toPyObject(output_metadata)); + + // Convert the output from PyObject* to vector + auto tmp = py::cast>>(stuff); + return toTensorList(tmp); +} + +struct PyCompilerInterfaceImpl : PyCompilerInterface { + std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema, + bool is_custom_function = false) override { + return torch::dynamo::autograd::bind_function( + py_compiler, + fn_name, + std::move(fn), + std::move(packed_args_schema), + is_custom_function); + } + variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) override { + return torch::dynamo::autograd::call_function( + py_compiler, + method_name, + fn_name, + inputs, + packed_args, + output_metadata); + } + variable_list call_copy_slices_prologue( + PyObject* py_compiler, + const variable_list& inputs, + const at::TensorGeometry& base, + const at::TensorGeometry& view) override { + py::handle handle(py_compiler); + py::object stuff = + handle.attr("call_copy_slices_prologue")(inputs, base, view); + return py::cast>(stuff); + } + variable_list call_copy_slices_epilogue( + PyObject* py_compiler, + const std::vector& needs_input_grad, + const at::Tensor& result, + const variable_list& res, + const at::Tensor& grad_slice) override { + py::handle handle(py_compiler); + py::object stuff = handle.attr("call_copy_slices_epilogue")( + needs_input_grad, result, res, grad_slice); + auto output = py::cast>>(stuff); + return toTensorList(output); + } +}; + static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -88,6 +238,22 @@ static void check(bool result) { check(nullptr); } +static variable_list validate_outputs( + const variable_list& outputs, + const ivalue_list& args) { + auto r = PackedArgs(args); + auto value = r.unpack>>(); + auto new_outputs = outputs; + + torch::autograd::validate_outputs( + value, new_outputs, [&](const std::string& msg) { + std::ostringstream ss; + ss << "[Compiled Autograd Tracing:]" << msg; + return ss.str(); + }); + return new_outputs; +} + // snapshot of python verbose logging toggle static PyObject* python_verbose_logger = nullptr; @@ -498,6 +664,21 @@ static void set_ivalue_proxies( } } +static at::Tensor call_accumulate( + PyObject* py_compiler, + const at::Tensor& old_var, + const at::Tensor& new_var) { + if (!old_var.defined()) { + return new_var; + } + if (!new_var.defined()) { + return old_var; + } + py::handle handle(py_compiler); + py::object stuff = handle.attr("accumulate")(old_var, new_var); + return py::cast(stuff); +} + static TraceState call_begin_capture( PyObject* self, CacheNode& cache, @@ -657,6 +838,8 @@ static CacheNode* _compiled_autograd_impl( ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); + setPyCompilerInterface(std::make_unique()); + TraceState state = call_begin_capture( py_compiler, *cache, compiler_call, output_edges.size()); InputBuffers input_buffers; @@ -723,16 +906,52 @@ static CacheNode* _compiled_autograd_impl( SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call); variable_list outputs = call.node->apply_with_saved(inputs, saved); - saved.debug_asserts(); saved.before(call.node->next_edges()); - validate_outputs( - call.node->next_edges(), outputs, [&](const std::string& msg) { - std::ostringstream ss; - ss << "[Compiled Autograd Tracing: " << call.node->name() << "] " - << msg; - return ss.str(); - }); + + auto input_metadata = get_input_metadata(call.node->next_edges()); + TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size()); + + // Lazily bind the `validate_outputs` function to Python. + static c10::once_flag flag; + c10::call_once(flag, [&]() { + auto schema = std::vector{IValuePacker< + std::vector>>::packed_type()}; + bind_function( + py_compiler.get(), + "validate_outputs", + validate_outputs, + schema, + false); + }); + + // Don't emit validate_outputs nodes that follow a CompiledBackward node. + // These nodes would otherwise prevent reordering of accumulate_grad + // nodes. + // + // Note that this will not cause correctness issues, because + // 1) AOTAutograd already coerces gradients to have the same metadata as + // the inputs. 2) the AOTAutograd graph already has the necessary + // aten::sum_to nodes in it (so it doesn't need to rely on + // validate_outputs to handle that). + // + // However, we may be dropping some (edge case) safety checks compared to + // eager: a backward that would have errored out in eager may not error + // out in compiled autograd (for example, if the user provided an + // incorrect number of gradients). + if (!call.node->is_aot_backward()) { + PackedArgs args; + args.pack(input_metadata); + ivalue_list input_metadata_state = std::move(args).vec(); + outputs = call_function( + py_compiler, + "validate_outputs", + "validate_outputs", + outputs, + input_metadata_state, + input_metadata_state[0]); + } + saved.after(call.node->next_edges()); saved.debug_asserts(); @@ -754,13 +973,14 @@ static CacheNode* _compiled_autograd_impl( auto& output = outputs[i]; const auto& next = call.node->next_edge(i); if (next.is_valid() && output.defined()) { - input_buffers.lookup(next.function.get()) - .add( - next.input_nr, std::move(output), std::nullopt, std::nullopt); + auto& buffer = input_buffers.lookup(next.function.get()); + buffer.buffer[next.input_nr] = call_accumulate( + py_compiler, buffer.buffer[next.input_nr], output); } } } + resetPyCompilerInterface(); PyObject* res = check(call_end_capture(py_compiler, state.outputs)); TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); TORCH_CHECK(