diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h index 41f14a15ba99c..06a064063c4e2 100644 --- a/aten/src/ATen/TensorGeometry.h +++ b/aten/src/ATen/TensorGeometry.h @@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry { has_symbolic_sizes_strides_( t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + // true if the tensor is contiguous bool is_contiguous() const; diff --git a/build_variables.bzl b/build_variables.bzl index 8bd8ad3a8df0a..a95c03cd0b345 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -138,6 +138,7 @@ core_trainer_sources = [ "torch/csrc/autograd/variable.cpp", "torch/csrc/autograd/utils/warnings.cpp", "torch/csrc/autograd/jit_decomp_interface.cpp", + "torch/csrc/dynamo/compiled_autograd.cpp", "torch/csrc/jit/frontend/name_mangler.cpp", "torch/csrc/jit/ir/type_hashing.cpp", "torch/csrc/jit/serialization/pickler.cpp", diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 14e3f2e044c10..2aa5ee3e71894 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -121,23 +121,30 @@ def fn(x, y): out.backward(grad_out) actual = normalize_gm(graph.print_readable(False)) self.assertEqual(x.grad, grad_out * grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list): l_inputs_ = L_inputs_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - result: "f32[s0]" = getitem * getitem; getitem = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None + getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad: "f32[2]" = torch.clone(getitem_5) + + result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, - ) + ) graph = None @@ -162,7 +169,7 @@ def inner_compiler(gm_, example_inputs_): gm, backend=inner_compiler, fullgraph=True, dynamic=True ) - for backend in ["eager", "aot_eager", "inductor"]: + for backend in ["inductor"]: torch._dynamo.reset() x = torch.tensor([0.5, 0.5], requires_grad=True) y = torch.tensor([0.5, 0.5], requires_grad=True) @@ -187,26 +194,33 @@ def fn(x, y): actual = normalize_gm(graph.print_readable(False)) self.assertEqual(obj.counter, 1) self.assertEqual(x.grad, grad_out + grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"): + def forward(self, L_inputs_ : list, L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s1)"): l_inputs_ = L_inputs_ - 
l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter + l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None + getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - result: "f32[s0]" = getitem * getitem; getitem = None + new_grad: "f32[2]" = torch.clone(getitem_5) - new_grad_1: "f32[s0]" = torch.clone(result); result = None + add: "Sym(s1 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None + + result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1, add) """, - ) + ) out = fn(x, y) out.backward(grad_out) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index e08b19b3f2b7f..87fceb9bf624e 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -22,6 +22,7 @@ from torch._dynamo import compiled_autograd, config from torch._dynamo.backends.debugging import aot_eager from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import normalize_gm from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase @@ -2821,8 +2822,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): opt_bwd() self.assertEqual(counters["compiled_autograd"]["captures"], 1) - # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops - self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + # Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing + # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. 
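+        # (Concretely: the cudagraphs input-movement pass in move_graph_nodes_to_cuda only
+        # relocates a cpu scalar whose users are all prims/aten OpOverloads or non-custom
+        # Op wrappers; the lifted C++ backward appears as an Op with is_custom_function=True,
+        # so the pass conservatively leaves the scalar on cpu and cudagraphs records a skip.)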
+ # In the future, we can consider having a cpu scalar movement pass sometime after we trace + # into the custom C++ autograd::Function (like in AOTDispatcher) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) def test_logs(self): logs, ctx = logs_to_string( @@ -2941,12 +2945,11 @@ def forward(model, x): expected_logs = [ "code: CompiledFunctionBackward (NodeCall 2)", + "code: CompiledFunctionBackward0 (NodeCall 2)", "aot0_primals_3", "aot0_relu", "aot0_le", "aot0_permute_2", - "code: CompiledFunctionBackward0 (NodeCall 2)", - "aot0_tangents_1", "aot0_full_default", "aot0_where", "aot0_mm", @@ -2996,20 +2999,17 @@ def f(x): expected_logs = [ "CompiledFunctionBackward1", - "aot1_tangents_1", "aot1_sin_1", - "aot1_primals_2", "aot1_neg", "aot0_tangents_2", "aot1_cos_1", - "aot1_primals_1", "aot0_tangents_1", "CompiledFunctionBackward0", + "aot0_sin_1", "aot0_neg", - "aot0_sin", "aot0_mul", + "aot0_cos_1", "aot0_mul_1", - "aot0_cos", "aot0_add", ] @@ -3154,6 +3154,120 @@ def fn(): self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0) + def test_tensor_subclass_basic(self): + from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor") + lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)") + + def to_twotensor_backward(ctx, grad): + return torch.ops.mylib.from_twotensor(grad) + + def from_twotensor_backward(ctx, grad_a, grad_b): + raise AssertionError("shouldn't get hit") + + torch.library.register_autograd( + "mylib::to_twotensor", to_twotensor_backward, lib=lib + ) + torch.library.register_autograd( + "mylib::from_twotensor", from_twotensor_backward, lib=lib + ) + + @torch.library.register_torch_dispatch( + "mylib::to_twotensor", TwoTensorMode, lib=lib + ) + def _(_0, _1, _2, args, kwargs): + assert not kwargs + a, b = args + return TwoTensor(a.clone(), b.clone()) + + @torch.library.register_torch_dispatch( + "mylib::from_twotensor", TwoTensor, lib=lib + ) + def _(_0, _1, _2, args, kwargs): + assert not kwargs + (c,) = args + return c.a.clone(), c.b.clone() + + @torch.compile(backend="aot_eager", fullgraph=True) + def fn(x): + return x * x + 2 + + param1 = torch.randn(4, 4, requires_grad=True) + param2 = torch.randn(4, 4, requires_grad=True) + with TwoTensorMode(): + x = torch.ops.mylib.to_twotensor(param1, param2) + + inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager") + graphs = [] + + def compiler_fn(gm): + graphs.append(gm) + return inner_compiler_fn(gm) + + with compiled_autograd._enable(compiler_fn): + res = fn(x) + res.sum().backward() + + self.assertEqual(param1.grad, 2 * param1) + self.assertEqual(param2.grad, 2 * param2) + self.assertEqual(len(graphs), 1) + + graph_code = normalize_gm(graphs[0].print_readable(print_output=False)) + # The graph should have make_subclass calls in it. 
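+        # (In the expected output below, torch__dynamo_compiled_autograd_make_subclass is the
+        # allow_in_graph'd helper that rebuilds the TwoTensor gradient from its two dense
+        # components before handing it to the lifted mylib::to_twotensor backward, which then
+        # unpacks it again via mylib::from_twotensor.)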
+ self.assertExpectedInline( + graph_code, + """\ +class CompiledAutograd0(torch.nn.Module): + def forward(self, inputs, sizes, scalars, hooks): + getitem = inputs[0] + getitem_1 = inputs[1] + getitem_2 = inputs[2] + getitem_3 = inputs[3] + getitem_4 = inputs[4]; inputs = None + + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None + getitem_5 = validate_outputs[0]; validate_outputs = None + + sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None + getitem_6 = sum_backward0[0]; sum_backward0 = None + validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None + getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None + + getitem_8 = hooks[0]; getitem_8 = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None + aot0_primals_1 = call_aot_bwd_prologue[0] + aot0_primals_2 = call_aot_bwd_prologue[1] + aot0_tangents_1 = call_aot_bwd_prologue[2] + aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None + + aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None + aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None + + aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None + aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None + + make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None + + getitem_13 = hooks[1]; hooks = None + call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None + getitem_16 = call_backward[0] + getitem_17 = call_backward[1]; call_backward = None + validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None + getitem_19 = validate_outputs_2[0] + + accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None + + getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None + + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None + + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None + return [] +""", # noqa: B950 + ) + # https://github.com/pytorch/pytorch/issues/138920 def test_compiled_autograd_does_not_specialize_on_bw_symints(self): class Mod(torch.nn.Module): @@ -3247,7 +3361,7 @@ def inner_compiler(gm_, example_inputs_): # because we ignore all of these guards anyway in CA. # Once we stop using make_fx in CA, we won't have to worry about this specialization. 
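+        # (Under the new capture the dynamic sizes reach this graph as proxied SymInts, via
+        # symnode_proxy_lookup in compiled_autograd.py, which is why the reshape size arguments
+        # checked below must still be fx.Nodes rather than ints burned into the graph.)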
view_nodes = graphs[1].graph.find_nodes( - op="call_function", target=torch.ops.aten.view.default + op="call_function", target=torch.ops.aten.reshape.default ) # First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) @@ -3640,6 +3754,7 @@ def wrap_test_class(orig_cls): "test_tp_compile_comm_reordering", "test_unwrap_async_collective_tensor_tangent", # Uncategorized + "test_not_implemented_grad", # Dynamo changes the types of exceptions } if not HAS_CUDA: diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 9d1a3202f1475..ed50a02d7b274 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -337,7 +337,9 @@ def test_module_backward_hooks_eager(self): self.assertEqual(fw_cnt.frame_count, 1) self.assertEqual(fw_cnt.op_count, 5) self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None - self.assertEqual(bw_cnt.op_count, 48) + self.assertEqual( + bw_cnt.op_count, 72 + ) # Number of ops in the Dynamo-produced graphs def test_module_backward_hooks_aot(self): m1, inp1 = init_module_bw_hooks(True) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index fa1d0ce4bc910..176fa68c53b65 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -107,6 +107,17 @@ ${body} return grad_inputs; } +inline variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args) +{ +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + auto packed_args = PackedArgs(args); + auto needs_input_grad = packed_args.unpack>(); + ${unpack_ivalues} + return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args}); +#endif +} variable_list ${op}::apply(variable_list&& grads) { ${thread_lock} @@ -120,11 +131,35 @@ ${compiled_args} } variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { - ${apply_with_saved_before} - variable_list result = apply(variable_list(grads)); - ${apply_with_saved_after} - return result; +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + ${apply_with_saved_before} + + static bool called = false; + if (!called) { + called = true; + ${compute_schema} + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + pyinterface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema); + } + + variable_list output_result; + + PackedArgs packed_args; + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + packed_args.pack(needs_input_grad); + ${get_packed_args} + + output_result = compiled_autograd_apply_functional(packed_args, next_edges(), saved, grads, name()); + + ${apply_with_saved_after} + return output_result; +#endif } + """ ) @@ -993,14 +1028,38 @@ def emit_derivative( f"{T} {x}" for T, x in zip(apply_functional_args_ref_types, apply_functional_args) ] + get_packed_args = "\n".join( + f"packed_args.pack({name});" for name in apply_functional_args + ) + unpack_ivalues = [] + for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): + if typ.endswith("&"): + typ = typ[:-1] + unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") + + schema_args = [f"std::array"] + for typ in apply_functional_args_ref_types: + if 
typ.endswith("&"): + typ = typ[:-1] + if typ.startswith("const"): + typ = typ[5:] + schema_args.append(typ.strip()) + compute_schema = ["std::vector schema = {"] + for schema_arg in schema_args: + compute_schema.append( + f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type()," + ) + compute_schema.append("};") return template.substitute( unpacks="\n".join(unpack), op=info.op, + compute_schema="\n".join(compute_schema), apply_functional_args=apply_functional_args, apply_functional_args_signature=apply_functional_args_signature, compute_needs_input_grad=compute_needs_input_grad, num_inputs=len(input_name_to_idx), + unpack_ivalues="\n".join(unpack_ivalues), compute_index_ranges=compute_index_ranges, saved_variables=saved_variables, release_variables=release_variables, @@ -1015,4 +1074,5 @@ def emit_derivative( compiled_args=compiled_args, apply_with_saved_before=apply_with_saved_before, apply_with_saved_after=apply_with_saved_after, + get_packed_args=get_packed_args, ) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 5bc089f67df74..ba5cb3d912c5d 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -15,6 +15,30 @@ using at::TensorList; namespace torch::autograd::generated { +static at::IValue compute_output_metadata(const torch::autograd::edge_list& next_edges) { + auto output_metadata = torch::dynamo::autograd::IValuePacker< + std::vector>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges)); + return output_metadata; +} + +static C10_NOINLINE variable_list compiled_autograd_apply_functional( + const PackedArgs& packed_args, + const edge_list& next_edges, + SwapSavedVariables& saved, + const variable_list& grads, + const std::string& name) { + auto output_metadata = compute_output_metadata(next_edges); + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + return pyinterface->call_function( + saved.get_py_compiler(), + "apply_functional", + name, + grads, + packed_args.vec(), + output_metadata); +} + ${autograd_function_definitions} } // namespace torch::autograd::generated diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index fb7017bc6dc9d..0627789a4bf1d 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -4,10 +4,11 @@ import itertools import operator import time -from collections import defaultdict +from collections import Counter, defaultdict from typing import Any, Optional, TYPE_CHECKING, Union import torch +import torch.utils._pytree as pytree from torch._dynamo.external_utils import ( call_backward, call_hook, @@ -65,6 +66,50 @@ def maybe_clone(x): return x +# We lazily bind "functional backward" variants for PyTorch built-in autograd +# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0 +# Each "functional backward" is bound the first time the node's apply_with_saved +# function is called. It's possible to avoid lazy binding and instead bind +# all of this upfront (perhaps at import time) via codegen changes. 
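+# A rough sketch of the lifecycle, using SumBackward0 (which shows up in the expected graphs
+# above): the first time SumBackward0::apply_with_saved runs under compiled autograd, the
+# generated C++ calls bind_function(), which lands in OpNamespace.add() below and registers an
+# allow_in_graph'd Op; from then on the capture emits calls such as
+#     torch._dynamo.compiled_autograd.ops.SumBackward0([grad], [True], [4, 4])
+# Custom C++ autograd::Functions are instead bound under a fresh CppNode<name><counter> entry
+# on every bind, since their schema can differ across compiled autograd cache misses.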
+class OpNamespace: + def __init__(self): + self.custom_function_name_counter: Counter[str] = Counter() + + def add(self, name, fn, is_custom_function=False): + if is_custom_function: + name = "CppNode" + name + count = self.custom_function_name_counter[name] + self.custom_function_name_counter[name] += 1 + name = f"{name}{count}" + else: + assert not hasattr(self, name) + + result = Op(name, fn, is_custom_function) + torch._dynamo.allow_in_graph(result) + setattr(self, name, result) + return name + + def get(self, name): + return getattr(self, name) + + +class Op: + def __init__(self, name, fn, is_custom_function): + self.fn = fn + self.is_custom_function = is_custom_function + self.__name__ = name + self.__module__ = "torch._dynamo.compiled_autograd.ops" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): + return self.__module__ + "." + self.__name__ + + +ops = OpNamespace() + + _graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] _impure_targets = OrderedSet( [ @@ -137,7 +182,8 @@ def begin_capture( self.fx_tracer.root = torch.nn.Module() self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) self.fx_tracer.tensor_attrs = {} - args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( + self.symnode_proxy_lookup = {} + args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = ( self.fx_tracer.create_proxy("placeholder", name, (), {}) for name in _graph_placeholders ) @@ -160,7 +206,9 @@ def begin_capture( ) for idx, val in enumerate(sizes) ] - self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins) + self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins) + for i, symint in enumerate(sizes): + self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i] for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -182,7 +230,9 @@ def begin_capture( ) else: raise AssertionError("Unexpected scalar type: ", type(val)) - self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins) + self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins) + for i, symval in enumerate(scalars): + self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr] # TODO(jansel): are all these modes needed? self.stack.enter_context(decompose({})) @@ -197,31 +247,203 @@ def begin_capture( ) return inputs, sizes, scalars - def proxy_call_backward( + def proxy_call_aot_backward( self, - inputs, - output_metadatas, + pinputs, + psaved_tensors, saved_tensors, - backward_idx: int, + pctx, + ctx, + maybe_backward_state_idx, ): - assert self.hooks_proxy is not None - backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index] - proxies = self.fx_tracer.create_proxy( + # The AOTBackward call consists of three things: the prologue, the + # backward graph, and the epilogue. + # Our strategy is: + # - allow_in_graph the prologue (in the CA graph and Dynamo graph), + # - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it + # - trace directly through the epilogue. Anything that gets baked in is + # constant metadata (for example, metadata about the number of outputs, or removing + # RNG arguments or effect tokens). + # If Dynamo graph capture were better, then we could add a node for the prologue + # into the CA graph and have Dynamo trace into it. 
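+        # (Mapping the three stages onto this function: the prologue becomes the
+        # call_aot_bwd_prologue proxy emitted just below, the backward graph is inlined by
+        # copy_paste_aot_backward_graph(), and the epilogue runs eagerly through
+        # _backward_epilogue_functional with make_subclass_override so that subclass
+        # reconstruction is proxied into the graph instead of executed.)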
+ + psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()] + + # NOTE: we should only close over constants + CompiledFunction = ctx._forward_cls + metadata = CompiledFunction.metadata + maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata + del CompiledFunction + + @torch._dynamo.allow_in_graph # type: ignore[misc] + def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): + out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( + ctx_saved_tensors, + ctx_symints, + metadata, + maybe_subclass_metadata, + *flat_args, + ) + return out + + pgrads = self.fx_tracer.create_proxy( kind="call_function", - target=call_backward, + target=call_aot_bwd_prologue, args=( - backward_c_function, - self.to_proxy(saved_tensors), - *self.to_proxy(inputs), + psaved_tensors, + psymints, + *pinputs, ), kwargs={}, ) + pbackward_state = None + if maybe_backward_state_idx is not None: + pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index] + + # Copy-paste the AOT backward graph into the compiled autograd graph + def copy_paste_aot_backward_graph(): + def num_inputs(graph): + num_args = 0 + for node in graph.nodes: + if node.op == "placeholder": + num_args += 1 + continue + else: + break + return num_args + + # set up the proxy inputs to ctx._bw_module + # the calling convention is: [*symints, *args (primals and tangents), backward_state] + num_args = num_inputs(ctx._bw_module.graph) + pall_args = [ + pgrads[i] for i in range(num_args - int(pbackward_state is not None)) + ] + # replace the symints with our symints + symints = ctx._get_compiled_autograd_symints() + assert len(symints) == len(ctx.symints) + psymints = [self.to_proxy(e) for e in symints] + pall_args[: len(symints)] = psymints + # Add backward_state + if pbackward_state is not None: + pall_args.append(pbackward_state) + + # run over all nodes of the aot_backward graph. + # copy and paste them all into the compiled autograd graph. + args_idx = 0 + value_remap = {} + poutputs: Optional[list[torch.fx.Proxy]] = None + for node in ctx._bw_module.graph.nodes: + if node.op == "placeholder": + value_remap[node] = pall_args[args_idx].node + args_idx += 1 + elif node.op == "output": + assert len(node.args) == 1 + poutputs = [ + torch.fx.Proxy(value_remap[n], self.fx_tracer) + if isinstance(n, torch.fx.Node) + else n + for n in node.args[0] + ] + elif node.op == "get_attr": + name = node.target + qualname = self.fx_tracer.get_fresh_qualname(name) + setattr( + self.fx_tracer.root, qualname, getattr(ctx._bw_module, name) + ) + result = self.fx_tracer.create_node("get_attr", qualname, (), {}) + value_remap[node] = result + elif node.op == "call_function": + result = self.fx_tracer.graph.node_copy( + node, lambda n: value_remap[n] + ) + value_remap[node] = result + else: + raise AssertionError("shouldn't get here") + assert poutputs is not None + + # In general we don't know what the shapes of the outputs are, so allocate + # some dummy sizes for them. 
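+            # (The placeholders below are never used for real computation: they only give
+            # the copied nodes something to bind proxies to, and end_capture() strips
+            # tensor_meta / example_value / val from every node afterwards. The odd trailing
+            # dimension is there so any accidental use is easy to spot.)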
+ def dummy(): + with disable_proxy_modes_tracing(): + return torch.zeros(0, 0, 0, 0, 123) + + outputs = [ + dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs + ] + self.bind_tensors_to_proxies(outputs, poutputs) + return outputs + + outputs = copy_paste_aot_backward_graph() + + def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args): + @torch._dynamo.allow_in_graph + def make_subclass(*unwrapped_args): + return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + + punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args) + + poutput = self.fx_tracer.create_proxy( + kind="call_function", + target=make_subclass, + args=tuple(punwrapped_args), + kwargs={}, + ) + + output = self.allocate_dummy() + self.bind_tensors_to_proxies([output], [poutput]) + return output + + results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional( + metadata, + maybe_subclass_metadata, + outputs, + make_subclass_override=proxy_subclass_constructor, + ) + presults = pytree.tree_map(self.to_proxy, results) + return presults + + def proxy_call_backward( + self, + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ctx: torch.autograd.function.BackwardCFunction, + maybe_backward_state_idx: Optional[int], + ): + assert self.hooks_proxy is not None + pctx = self.hooks_proxy[backward_idx] # type: ignore[index] + pinputs = self.to_proxy(inputs) + psaved_tensors = self.to_proxy(saved_tensors) + if hasattr(ctx._forward_cls, "_aot_id"): # type: ignore[attr-defined] + # AOT backward + proxies = self.proxy_call_aot_backward( + pinputs, + psaved_tensors, + saved_tensors, + pctx, + ctx, + maybe_backward_state_idx, + ) + else: + proxies = self.fx_tracer.create_proxy( + kind="call_function", + target=call_backward, + args=( + pctx, + psaved_tensors, + *pinputs, + ), + kwargs={}, + ) + assert proxies is not None + with disable_proxy_modes_tracing(): # create fake Tensors grad_ins: list[Optional[torch.Tensor]] = [] - for output_metadata in output_metadatas: - if output_metadata is None: + for idx, output_metadata in enumerate(output_metadatas): + if output_metadata is None or proxies[idx] is None: grad_ins.append(None) continue @@ -232,6 +454,71 @@ def proxy_call_backward( self.bind_tensors_to_proxies(grad_ins, proxies) return tuple(grad_ins) + def call_copy_slices_prologue(self, inputs, base, view): + args = ( + inputs, + base.sizes(), + base.strides(), + base.storage_offset(), + view.sizes(), + view.strides(), + view.storage_offset(), + ) + return self.proxy_call(copy_slices_prologue, args, [None] * 3) + + def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice): + return self.proxy_call( + copy_slices_epilogue, + (needs_input_grad, result, res, grad_slice), + [None] * len(needs_input_grad), + ) + + def allocate_dummy(self): + with disable_proxy_modes_tracing(): + # Weird quantity so it's easy to grep + return torch.zeros([0, 123456789]) + + def bind_function(self, fn_name, fn, is_custom_function): + """Binds ops.fn_name = fn""" + return ops.add(fn_name, fn, is_custom_function) + + def apply_functional(self, fn_name, grads, args, output_metadata): + """Proxies a call to ops.fn_name(grads, *args) into the graph""" + op = ops.get(fn_name) + return self.proxy_call(op, (grads, *args), output_metadata) + + def proxy_call(self, fn, args, output_metadata): + """Proxies a call to fn(*args) into the graph""" + flat_args, _ = pytree.tree_flatten(args) + proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) + 
proxy_out = self.fx_tracer.create_proxy( + "call_function", fn, args=proxy_args, kwargs={} + ) + result = [self.allocate_dummy() for _ in output_metadata] + self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))]) + return result + + def validate_outputs(self, _, outputs, args, output_metadata): + """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" + op = ops.get("validate_outputs") + proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) + new_proxy_outputs = self.fx_tracer.create_proxy( + "call_function", op, args=proxy_args, kwargs={} + ) + assert len(output_metadata) == len(outputs) + self.bind_tensors_to_proxies(outputs, new_proxy_outputs) + return outputs + + def accumulate(self, old_var, new_var): + old_var_proxy = self.to_proxy(old_var) + new_var_proxy = self.to_proxy(new_var) + proxy_out = self.fx_tracer.create_proxy( + "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={} + ) + result = self.allocate_dummy() + self.bind_tensors_to_proxies([result], [proxy_out]) + return result + def proxy_call_hook(self, hook, *args, **kwargs): return self.fx_tracer.create_proxy( "call_function", @@ -314,6 +601,7 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: assert nodes[first_getitem_idx] == inputs_users[0] last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 assert nodes[last_getitem_idx] == inputs_users[-1] + # getitem nodes on inputs for i, node in enumerate(inputs_users): if not has_cuda_inputs and node.meta["val"].device.type == "cuda": has_cuda_inputs = True @@ -323,9 +611,16 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: is_scalar = len(node.meta["val"].size()) == 0 if is_cpu and is_scalar: node_users = list(node.users.keys()) + # We can only move the cpu scalar if it is not exposed to user code. if all( - isinstance(user.target, torch._ops.OpOverload) - and user.target.namespace in ("prims", "aten") + ( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + ) + or ( + isinstance(user.target, Op) + and not user.target.is_custom_function + ) for user in node_users ): # all users are prims/aten, can move safely @@ -335,6 +630,7 @@ def move_graph_nodes_to_cuda(self, graph) -> list[int]: # this is to handle the case where cudagraphs is enabled on a cpu-only graph if has_cuda_inputs: for node in to_move.values(): + verbose_log.debug("Moving node %s from cpu to cuda", node) node.meta["val"] = node.meta["val"].cuda() # return runtime indices we need to move to cuda @@ -368,7 +664,10 @@ def is_impure(node): or (node.op == "call_function" and node.target in _impure_targets) ) + before = len(self.fx_tracer.graph.nodes) self.fx_tracer.graph.eliminate_dead_code(is_impure) + after = len(self.fx_tracer.graph.nodes) + verbose_log.debug("DCE removed %d nodes", before - after) def end_capture(self, outputs): self.fx_tracer.create_proxy( @@ -384,6 +683,18 @@ def end_capture(self, outputs): (self.fx_tracer.create_arg(self.to_proxy(outputs)),), {}, ) + runtime_inputs_to_move: list[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + + # We traced using dummy tensors. Delete all the metadata of the dummy tensors. + # It's probably better to refactor this class to use a different tracer + # than the make_fx tracer, but that is a larger change. 
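+        # (Without this, anything that later consumes the CompiledAutograd graph module, e.g.
+        # Dynamo when it retraces it, could treat the bogus shapes from allocate_dummy()/dummy()
+        # as real example values.)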
+ for node in self.fx_tracer.graph.nodes: + for field in ["tensor_meta", "example_value", "val"]: + if field in node.meta: + del node.meta[field] + self.rename_aot_dispatcher_nodes() self.reorder_tensor_pre_hook_nodes() self.reorder_pre_hook_nodes_to_schedule_asap() @@ -402,9 +713,6 @@ def end_capture(self, outputs): # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and # should prevent these ops from going into the CA graph. self.dce() - runtime_inputs_to_move: list[int] = [] - if snapshot_cudagraph_enabled(): - runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}" @@ -778,8 +1086,11 @@ def to_proxy(self, t): return [self.to_proxy(x) for x in t] if isinstance(t, tuple): return tuple(self.to_proxy(x) for x in t) - # can it be torch.SymInt as the code used to imply? - assert isinstance(t, torch.Tensor) + if isinstance(t, (torch.SymInt, torch.SymFloat)): + return self.symnode_proxy_lookup[t.node] + if not isinstance(t, torch.Tensor): + # constant types like device, dtype, str + return t proxy_tensor = fetch_object_proxy(self.fx_tracer, t) assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) return proxy_tensor.proxy @@ -921,3 +1232,39 @@ def reset() -> None: torch._C._dynamo.compiled_autograd.clear_cache() global COMPILE_COUNTER COMPILE_COUNTER = itertools.count() + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. +def copy_slices_prologue( + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, +): + grad = inputs[0] + result = grad.new_empty_strided(base_sizes, base_strides) + assert grad is not None + result.copy_(grad) + offset = view_storage_offset - base_storage_offset + grad_slice = result.as_strided(view_sizes, view_strides, offset) + return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)] + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. 
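+# Both helpers are driven from CopySlices::apply_with_saved through getPyCompilerInterface():
+# the prologue above returns [result, grad_slice, grad_slice.clone()], fn is replayed on the
+# clone, and this epilogue scatters fn's outputs back, returning the full `result` buffer for
+# input 0 (its grad_slice region is overwritten in place) and res[i] for the remaining inputs.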
+def copy_slices_epilogue(needs_input_grad, result, res, grad_slice): + grad_inputs = [None] * len(needs_input_grad) + for i in range(len(needs_input_grad)): + if needs_input_grad[i]: + if res[i] is None: + continue + if i == 0: + grad_slice.copy_(res[i]) + grad_inputs[i] = result + else: + grad_inputs[i] = res[i] + return grad_inputs diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index c50df4d969693..ec17a6b6d1354 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -116,6 +116,14 @@ def call_backward( return grads +def normalize_as_list(x: Any) -> list[Any]: + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + def untyped_storage_size(x: torch.Tensor) -> int: return x.untyped_storage().size() diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 02290ba55fc27..3ef4d94a1385c 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -72,6 +72,8 @@ def radians(x): def accumulate_grad(x, new_grad): + if new_grad is None: + return new_grad = torch.clone(new_grad) if x.grad is None: x.grad = new_grad diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 9e5b048d14947..ef634af4eb082 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3276,6 +3276,7 @@ def _module_dir(m: types.ModuleType): MOD_INLINELIST = [ "torch._decomp", "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.compiled_autograd", "torch._dynamo.comptime", "torch._dynamo.polyfills", "torch._functorch._aot_autograd.subclass_parametrization", diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 954865c76b19d..dc9e5af16da95 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -62,7 +62,6 @@ from .utils import ( call_func_at_runtime_with_args, make_boxed_func, - normalize_as_list, partial_flatten_asdict, strict_zip, ) @@ -1683,7 +1682,9 @@ def _backward_prologue_functional( # NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. -def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out): +def _backward_epilogue_functional( + metadata, maybe_subclass_metadata, out, *, make_subclass_override=None +): # Toss out the backward output tokens num_bw_tokens = metadata.num_backward_tokens if num_bw_tokens > 0: @@ -1703,6 +1704,7 @@ def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out): subclass_metas=maybe_subclass_metadata.grad_input_metas, included_subclass_symints=True, is_runtime=True, + make_subclass_override=make_subclass_override, ) return outs_wrapped return out @@ -1728,6 +1730,13 @@ def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta expected_meta = meta.meta runtime_type = type(x) + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. 
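+        # (Treating the FunctionalTensor wrapper as a plain torch.Tensor here keeps the
+        # runtime_type comparison below from reporting a spurious tangent metadata mismatch
+        # during the compiled autograd replay.)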
+ if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + runtime_meta = None runtime_subclass_keys: Sequence[str] = [] @@ -2001,23 +2010,9 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): - if ctx._is_compiled_autograd_tracing(): - if lazy_backward_info is None: - raise RuntimeError( - """This compiled backward function was saved by AOTAutogradCache, which does not support - compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`.""" - ) - bw_module = lazy_backward_info.bw_module - # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph - symints = ctx._get_compiled_autograd_symints() - assert len(symints) == len(ctx.symints) - all_args[: len(symints)] = symints - if backward_state_indices: - assert ctx._compiled_autograd_backward_state.proxy is not None - all_args.append(ctx._compiled_autograd_backward_state) - context = torch._C._DisableAutocast if disable_amp else nullcontext - with context(): - return normalize_as_list(bw_module(*all_args)) + assert ( + not ctx._is_compiled_autograd_tracing() + ), "compiled autograd reimplements this function at proxy_call_aot_backward" assert ( not backward_state_indices diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index 2ff94f344c63f..d352f43da371b 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -8,7 +8,7 @@ import collections import typing from collections.abc import Iterable -from typing import Any, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch import torch.utils._pytree as pytree @@ -326,6 +326,7 @@ def wrap_tensor_subclasses( num_fw_outs_saved_for_bw: Optional[int] = None, included_subclass_symints: bool = False, is_runtime: bool = False, + make_subclass_override: Optional[Callable] = None, ) -> tuple[Any, ...]: wrapped_args = [] num_args_tallied = 0 @@ -336,9 +337,15 @@ def wrap_tensor_subclasses( else: assert isinstance(subclass_meta, SubclassCreationMeta) assert subclass_meta.included_subclass_symints == included_subclass_symints - wrapped_args.append( - subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) - ) + + if make_subclass_override: + wrapped_args.append( + make_subclass_override(subclass_meta, is_runtime, unwrapped_args) + ) + else: + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) num_args_tallied += subclass_meta.arg_count # Note: [Partitioner handling for Subclasses, Part 2] diff --git a/torch/autograd/function.py b/torch/autograd/function.py index ca7d70192ca57..1bcb25751454b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -334,6 +334,9 @@ def __init__(cls, name, bases, attrs): backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined] "_compiled_autograd_should_lift", True ) + backward_fn._bw_module = None # type: ignore[attr-defined] + if getattr(cls, "_lazy_backward_info", None): + backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined] cls._backward_cls = backward_fn super().__init__(name, bases, attrs) diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 48e79f9ec5044..1df7efb35d10d 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -503,6 +503,16 @@ 
void check_variable_result( } } +AutogradContext::AutogradContext(PackedArgs& packed_args) { + saved_data = packed_args.unpack_saved_data(); + saved_variables_override_ = packed_args.unpack(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + materialize_grads_ = packed_args.unpack(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + has_freed_buffers_ = packed_args.unpack(); + needs_input_grad_override_ = packed_args.unpack>(); +} + void AutogradContext::save_for_backward(variable_list to_save) { to_save_ = std::move(to_save); } @@ -527,6 +537,9 @@ void AutogradContext::save_variables() { variable_list AutogradContext::get_saved_variables() const { TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE); + if (saved_variables_override_.has_value()) { + return *saved_variables_override_; + } variable_list saved; saved.reserve(saved_variables_.size()); auto ptr = grad_fn_.lock(); @@ -538,6 +551,9 @@ variable_list AutogradContext::get_saved_variables() const { } bool AutogradContext::needs_input_grad(size_t output_edge_index) const { + if (needs_input_grad_override_.has_value()) { + return needs_input_grad_override_.value().at(output_edge_index); + } auto ptr = grad_fn_.lock(); TORCH_INTERNAL_ASSERT(ptr); return ptr->task_should_compute_output(output_edge_index); @@ -545,6 +561,15 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const { bool AutogradContext::needs_input_grad( std::initializer_list idxs) const { + if (needs_input_grad_override_.has_value()) { + return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { + bool result = false; + for (const auto i : c10::irange(range.first, range.second)) { + result |= needs_input_grad_override_.value().at(i); + } + return result; + }); + } auto ptr = grad_fn_.lock(); TORCH_INTERNAL_ASSERT(ptr); return ptr->task_should_compute_output(idxs); diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index a142f082e831e..34f508ac72abe 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -126,6 +126,8 @@ struct TORCH_API AutogradContext { AutogradContext& operator=(AutogradContext&& other) = delete; ~AutogradContext() = default; + AutogradContext(PackedArgs& packed_args); + /// Can be used to save non-variable data for `backward`. ska::flat_hash_map saved_data; @@ -169,12 +171,103 @@ struct TORCH_API AutogradContext { std::weak_ptr grad_fn_; bool has_freed_buffers_{false}; + // Compiled autograd overrides saved_variables() and needs_input_grad(). + // We store the values we want to return here. 
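+  // They are only populated by the AutogradContext(PackedArgs&) constructor used for the
+  // functional replay under compiled autograd; in the normal eager backward they remain
+  // nullopt and the answers are derived from grad_fn_ as before.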
+ std::optional saved_variables_override_; + std::optional> needs_input_grad_override_; + void save_variables(); template friend struct CppNode; + template + friend variable_list CppNode_apply_functional( + variable_list&& inputs, + AutogradContext& ctx_, + const std::vector& is_variable_input_, + const std::vector& output_info_, + const std::string& name); }; +template +inline variable_list CppNode_apply_functional( + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + variable_list&& inputs, + AutogradContext& ctx_, + const std::vector& is_variable_input_, + const std::vector& output_info_, + const std::string& name) { + at::OptionalDeviceGuard _device_guard; + + auto num_inputs = inputs.size(); + variable_list backward_inputs; + backward_inputs.reserve(num_inputs); + for (const auto i : c10::irange(num_inputs)) { + if (inputs[i].defined() || !ctx_.materialize_grads_) { + backward_inputs.emplace_back(std::move(inputs[i])); + } else { + backward_inputs.emplace_back(output_info_[i].zeros(_device_guard)); + } + } + + auto outputs = T::backward(&ctx_, backward_inputs); + + const auto num_forward_inputs = + static_cast(is_variable_input_.size()); + auto num_outputs = static_cast(outputs.size()); + // Returning too many results is ok, but only as long as they're all + // undefined. Truncate the result vector in that case. + if (num_outputs > num_forward_inputs) { + bool all_undef = true; + for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { + all_undef &= (!outputs[i].defined()); + } + if (all_undef) { + outputs.resize(num_forward_inputs); + num_outputs = num_forward_inputs; + } + } + + if (num_outputs != num_forward_inputs) { + std::string msg("function "); + msg += name + " returned an incorrect number of gradients (expected "; + msg += std::to_string(num_forward_inputs) + ", got "; + msg += std::to_string(num_outputs) + ")"; + throw std::runtime_error(msg); + } + + variable_list results; + results.reserve(num_outputs); + for (const auto i : c10::irange(num_outputs)) { + if (!is_variable_input_[i]) { + if (outputs[i].defined()) { + std::string msg("function "); + msg += name + + " returned a gradient different that is defined at position "; + msg += std::to_string(i + 1) + + ", std the corresponding forward input was not a Variable"; + throw std::runtime_error(msg); + } + continue; + } + results.emplace_back(outputs[i]); + } + return results; +} + +template +inline variable_list CppNode_apply_functional_ivalue( + const variable_list& inputs, + const std::vector& args) { + auto packed_args = PackedArgs(args); + auto ctx = AutogradContext(packed_args); + auto output_info = packed_args.unpack>(); + auto is_variable_input = packed_args.unpack>(); + auto name = packed_args.unpack(); + return CppNode_apply_functional( + variable_list(inputs), ctx, is_variable_input, output_info, name); +} + // CppNode is the Node in the autograd graph that represents the user defined // backward function for Function. Calls to CppNode::apply are forward to // T::backward(). 
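+// Note: the PackedArgs layout consumed by CppNode_apply_functional_ivalue above (saved_data,
+// saved variables, materialize_grads_, has_freed_buffers_, needs_input_grad, output_info_,
+// is_variable_input_, name) must stay in sync with the pack() calls in
+// CppNode<T>::apply_with_saved below.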
@@ -232,7 +325,64 @@ struct CppNode : public Node { saved.before(ctx_.has_freed_buffers_); saved.before(input_info_); saved.before(output_info_); - auto results = apply(variable_list(inputs)); + + PackedArgs packed_args; + packed_args.pack_saved_data(ctx_.saved_data); + variable_list saved_variables = ctx_.get_saved_variables(); + packed_args.pack(saved_variables); + packed_args.pack(ctx_.materialize_grads_); + packed_args.pack(ctx_.has_freed_buffers_); + + std::vector needs_input_grad; + { + auto ptr = ctx_.grad_fn_.lock(); + TORCH_INTERNAL_ASSERT(ptr); + for (const auto i : c10::irange(ptr->next_edges().size())) { + needs_input_grad.push_back(ptr->task_should_compute_output(i)); + } + } + packed_args.pack(needs_input_grad); + + packed_args.pack(output_info_); + packed_args.pack(is_variable_input_); + packed_args.pack(name()); + auto args = std::move(packed_args).vec(); + + auto output_metadata = torch::dynamo::autograd:: + IValuePacker>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + + // Each time apply_with_saved is called, we bind a new function to Python. + // This is because the schema might be different on compiled autograd cache + // misses. An alternative is to pass the schema to Python so that it can be + // an input to a function, but the schema can't be put into an FX graph + // right now. + std::vector schema; + schema.reserve(args.size()); + for (const auto& ivalue : args) { + if (ivalue.isTensor()) { + schema.emplace_back(at::TensorType::get()); + } else { + schema.emplace_back(ivalue.type()); + } + } + auto fn_name = pyinterface->bind_function( + saved.get_py_compiler(), + std::string(typeid(T).name()), + CppNode_apply_functional_ivalue, + schema, + /*is_custom_function*/ true); + + auto results = pyinterface->call_function( + saved.get_py_compiler(), + "apply_functional", + fn_name, + inputs, + args, + output_metadata); + saved.after(ctx_.saved_data); TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty()); TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty()); @@ -403,68 +553,13 @@ auto Function::apply(Args&&... args) template // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) variable_list CppNode::apply(variable_list&& inputs) { - at::OptionalDeviceGuard _device_guard; - - auto num_inputs = inputs.size(); - variable_list backward_inputs; - backward_inputs.reserve(num_inputs); - for (const auto i : c10::irange(num_inputs)) { - if (inputs[i].defined() || !ctx_.materialize_grads_) { - backward_inputs.emplace_back(std::move(inputs[i])); - } else { - backward_inputs.emplace_back(output_info_[i].zeros(_device_guard)); - } - } - // Acquire lock to here protect thread safety on custom C++ Autograd Node // This is needed for the custom Autograd Node since we don't know if the // user defined Node will write to the shared data during backward. // see Note [Thread Safety on Autograd Node] std::lock_guard lock(mutex_); - - auto outputs = T::backward(&ctx_, backward_inputs); - - const auto num_forward_inputs = - static_cast(is_variable_input_.size()); - auto num_outputs = static_cast(outputs.size()); - // Returning too many results is ok, but only as long as they're all - // undefined. Truncate the result vector in that case. 
- if (num_outputs > num_forward_inputs) { - bool all_undef = true; - for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { - all_undef &= (!outputs[i].defined()); - } - if (all_undef) { - outputs.resize(num_forward_inputs); - num_outputs = num_forward_inputs; - } - } - - if (num_outputs != num_forward_inputs) { - std::string msg("function "); - msg += name() + " returned an incorrect number of gradients (expected "; - msg += std::to_string(num_forward_inputs) + ", got "; - msg += std::to_string(num_outputs) + ")"; - throw std::runtime_error(msg); - } - - variable_list results; - results.reserve(num_outputs); - for (const auto i : c10::irange(num_outputs)) { - if (!is_variable_input_[i]) { - if (outputs[i].defined()) { - std::string msg("function "); - msg += name() + - " returned a gradient different that is defined at position "; - msg += std::to_string(i + 1) + - ", std the corresponding forward input was not a Variable"; - throw std::runtime_error(msg); - } - continue; - } - results.emplace_back(outputs[i]); - } - return results; + return CppNode_apply_functional( + std::move(inputs), ctx_, is_variable_input_, output_info_, name()); } template diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 2aad75e0e74b9..c8d465211a6f8 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -897,6 +897,19 @@ bool has_input_metadata(const Edge& thing) { return thing.is_valid(); } +std::vector> collect_input_metadata( + const edge_list& edges) { + std::vector> input_metadata; + for (const auto& edge : edges) { + if (!edge.is_valid()) { + input_metadata.emplace_back(std::nullopt); + continue; + } + input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr)); + } + return input_metadata; +} + // Given an vector or vector>, validate the // outputs. This involves using the InputMetadata to check the outputs and also // potentially calling .sum_to on the outputs. diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 4243f1b1d6ee5..5bf00bac5378e 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -47,6 +47,8 @@ TORCH_API void validate_outputs( const std::vector>& input_metadata, variable_list& grads, const std::function& format_error); +TORCH_API std::vector> collect_input_metadata( + const edge_list& edges); struct NodeTask { std::weak_ptr base_; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index ba2f6edbc6c0c..abd11303eafe8 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -34,8 +34,12 @@ using tensor_list = std::vector; using variable_list = std::vector; using edge_list = std::vector; using saved_variable_list = std::vector; +using ivalue_list = std::vector; +using functional_apply_t = std::function< + variable_list(const variable_list&, const std::vector&)>; using IndexRange = std::pair; using torch::dynamo::autograd::CompiledNodeArgs; +using torch::dynamo::autograd::PackedArgs; using torch::dynamo::autograd::SwapSavedVariables; // Custom deleter to prevent stack overflows. @@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this { std::string("apply_with_saved not implemented: ") + name()); } + // If this node is the AOTBackward node produced by torch.compile. + // Compiled Autograd special-cases on this information. + virtual bool is_aot_backward() const { + return false; + } + protected: /// Performs the `Node`'s actual operation. 
virtual variable_list apply(variable_list&& inputs) = 0; diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 6342bf280a5ce..4e8bba79a169b 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -8,6 +8,7 @@ namespace torch::dynamo::autograd { class CompiledNodeArgs; class SwapSavedVariables; +struct PackedArgs; } // namespace torch::dynamo::autograd // A hook that's called on gradients diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 1c13b60ca7831..a06ebefb85c3c 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -16,6 +16,8 @@ namespace torch::autograd { +using torch::dynamo::autograd::IValuePacker; + static variable_list CopyBackwards_apply_functional( variable_list&& grads, std::array needs_input_grad, @@ -41,6 +43,16 @@ static variable_list CopyBackwards_apply_functional( return grad_inputs; } +static variable_list CopyBackwards_apply_functional_ivalue( + const variable_list& grads, + const ivalue_list& args) { + PackedArgs r(args); + auto needs_input_grad = r.unpack>(); + auto src_options = r.unpack(); + return CopyBackwards_apply_functional( + variable_list(grads), needs_input_grad, src_options); +} + auto CopyBackwards::apply(variable_list&& grads) -> variable_list { return CopyBackwards_apply_functional( std::move(grads), @@ -51,11 +63,43 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list { void CopyBackwards::compiled_args(CompiledNodeArgs& args) { args.collect(src_options); } + variable_list CopyBackwards::apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) { saved.before(src_options); - auto result = apply(variable_list(inputs)); + + static c10::once_flag flag; + c10::call_once(flag, [&]() { + std::vector schema = { + IValuePacker>::packed_type(), + IValuePacker::packed_type()}; + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->bind_function( + saved.get_py_compiler(), + name(), + CopyBackwards_apply_functional_ivalue, + schema); + }); + + PackedArgs packed_args; + packed_args.pack>( + {task_should_compute_output(0), task_should_compute_output(1)}); + packed_args.pack(src_options); + + auto output_metadata = torch::dynamo::autograd:: + IValuePacker>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + auto result = interface->call_function( + saved.get_py_compiler(), + "apply_functional", + name(), + inputs, + std::move(packed_args).vec(), + output_metadata); + saved.after(src_options); return result; } @@ -80,38 +124,7 @@ CopySlices::CopySlices( } } -// common code between apply/apply_with_saved -template -inline variable_list CopySlices::apply_impl( - variable_list&& inputs, - const T& call_fn) { - check_input_variables("CopySlices", inputs, 1, -1, true); - auto& grad = std::move(inputs)[0]; - if (!grad.defined()) { - return variable_list(num_outputs()); - } - - // Acquire lock to here protect thread safety on fn - // see Note [Thread Safety on Autograd Node] - std::lock_guard lock(mutex_); - - if (!fn) { - throw std::runtime_error(ERR_BACKWARD_TWICE); - } - - auto result = - grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides()); - result.copy_(grad); - - at::Tensor grad_slice; - if (view_fn) { - grad_slice = (*view_fn)(result); - } else { - auto offset = view.sym_storage_offset() - 
base.sym_storage_offset(); - grad_slice = - result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset); - } - +void CopySlices::update_exec_info() { // See Note [View + Inplace update for view tensor] For more details on this // block Since the gradient edge for the 0th input is different between `this` // and `fn`, make sure that the one from `fn` has the same metadata in the @@ -154,6 +167,41 @@ inline variable_list CopySlices::apply_impl( TORCH_INTERNAL_ASSERT( fn->next_edge(i).function.get() == this->next_edge(i).function.get()); } +} + +// common code between apply/apply_with_saved +template +inline variable_list CopySlices::apply_impl( + variable_list&& inputs, + const T& call_fn) { + check_input_variables("CopySlices", inputs, 1, -1, true); + auto& grad = std::move(inputs)[0]; + if (!grad.defined()) { + return variable_list(num_outputs()); + } + + // Acquire lock to here protect thread safety on fn + // see Note [Thread Safety on Autograd Node] + std::lock_guard lock(mutex_); + + if (!fn) { + throw std::runtime_error(ERR_BACKWARD_TWICE); + } + + auto result = + grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides()); + result.copy_(grad); + + at::Tensor grad_slice; + if (view_fn) { + grad_slice = (*view_fn)(result); + } else { + auto offset = view.sym_storage_offset() - base.sym_storage_offset(); + grad_slice = + result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset); + } + + update_exec_info(); // TODO: We clone grad_slice because we modify it below and "fn" might save // it for the backward of res. We might be able to avoid the clone() if @@ -201,17 +249,38 @@ variable_list CopySlices::apply_with_saved( SwapSavedVariables& saved) { saved.before(base); saved.before(view); - int call_count = 0; - variable_list result = apply_impl( - variable_list(grads), - [this, &saved, &call_count](const variable_list& inputs2) { - call_count++; - return fn->apply_with_saved(inputs2, saved); - }); - TORCH_INTERNAL_ASSERT(call_count == 1); + + auto results = variable_list(num_outputs()); + if (grads[0].defined()) { + if (!fn) { + throw std::runtime_error(ERR_BACKWARD_TWICE); + } + update_exec_info(); + + std::vector needs_input_grad; + for (const auto i : c10::irange(num_outputs())) { + needs_input_grad.emplace_back(task_should_compute_output(i)); + } + // Not yet supported, also doesn't happen in typical eager mode execution + // (this only happens by default with torch-xla). + TORCH_INTERNAL_ASSERT(!view_fn); + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + variable_list stuff = interface->call_copy_slices_prologue( + saved.get_py_compiler(), grads, base, view); + TORCH_INTERNAL_ASSERT(stuff.size() == 3); + // These variables are named the same as in CopySlices::apply_impl. + // Follow along there. 
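+ // As a rough guide (mirroring apply_impl above): stuff[0] ("result") is
+ // grad materialized into a tensor with base's geometry, stuff[1]
+ // ("grad_slice") is the view of result described by `view`, and stuff[2]
+ // ("grad_slice_clone") is a clone of grad_slice that is safe to hand to
+ // `fn`, which may save its input for the backward of `res`.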
+ auto result = stuff[0]; + auto grad_slice = stuff[1]; + auto grad_slice_clone = stuff[2]; + auto res = fn->apply_with_saved({grad_slice_clone}, saved); + results = interface->call_copy_slices_epilogue( + saved.get_py_compiler(), needs_input_grad, result, res, grad_slice); + } + saved.after(base); saved.after(view); - return result; + return results; } auto CopySlices::apply(variable_list&& inputs1) -> variable_list { diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index e812860e3adc1..4b0c2190ed542 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -172,6 +172,7 @@ struct TORCH_API CopySlices : public Node { variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) override; + void update_exec_info(); at::TensorGeometry base; // view and view_fn are redundant and view_fn will be used if available. diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 1c1d4f504af23..9a8375480374d 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -131,6 +131,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { if (!ParameterClass) return nullptr; + py::class_(m, "TensorGeometry") + .def("sizes", &at::TensorGeometry::sizes) + .def("strides", &at::TensorGeometry::strides) + .def("storage_offset", &at::TensorGeometry::storage_offset); + py::class_(m, "ProfilerEvent") .def("kind", &LegacyEvent::kindStr) .def("name", [](const LegacyEvent& e) { return e.name(); }) diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 19151cbaafe63..dd0b7a927bfee 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -237,16 +238,23 @@ auto PyNode::defer_to_dynamo( TORCH_INTERNAL_ASSERT( _backward_idx.has_value(), "indices should already be set by compiled_args, called before apply_with_saved"); - TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value()); + PyObject* backward_state_idx = Py_None; + if (_backward_state_idx.has_value()) { + backward_state_idx = THPUtils_packInt64(_backward_state_idx.value()); + // this might be simplifiable now that we no longer inline + Py_CLEAR(py_fn->compiled_autograd_backward_state); + } THPObjectPtr r(PyObject_CallMethod( // NOLINTNEXTLINE(bugprone-unchecked-optional-access) compiler.value(), "proxy_call_backward", - "OOOi", + "OOOiOO", pyInputs.get(), fwdInputMetadatas.get(), saved_tensors.get(), - *_backward_idx)); + *_backward_idx, + obj, + backward_state_idx)); if (!r) throw_python_error(); @@ -288,6 +296,11 @@ auto PyNode::name() const -> std::string { return name; } +bool PyNode::is_aot_backward() const { + py::handle handle(obj); + return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); +} + auto PyNode::compiled_autograd_should_lift() const -> bool { pybind11::gil_scoped_acquire gil; static PyObject* attr_name = @@ -340,11 +353,8 @@ void PyNode::compiled_args(CompiledNodeArgs& args) { args.collect(f->output_info); args.collect(f->input_info); - if (compiled_autograd_should_lift()) { - Py_INCREF(obj); - _backward_idx = - args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); - } + Py_INCREF(obj); + _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); PyObject* bw_state = f->compiled_autograd_backward_state; if (args.cond(bw_state != nullptr)) { @@ -366,28 +376,8 @@ 
variable_list PyNode::apply_with_saved( saved.before(f->output_info); saved.before(f->input_info); f->compiled_autograd_tracing = true; - variable_list result; - if (!compiled_autograd_should_lift()) { - if (_backward_state_idx.has_value()) { - PyObject* r = PyObject_CallMethod( - saved.get_py_compiler(), - "bind_backward_state", - "i", - *_backward_state_idx); - if (r == nullptr) { - throw python_error(); - } - THPObjectPtr prior(f->compiled_autograd_backward_state); - f->compiled_autograd_backward_state = r; - result = apply(variable_list(inputs)); - Py_CLEAR(f->compiled_autograd_backward_state); - f->compiled_autograd_backward_state = prior.release(); - } else { - result = apply(variable_list(inputs)); - } - } else { - result = defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); - } + variable_list result = + defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); f->compiled_autograd_tracing = false; saved.after(f->compiled_autograd_symints); saved.after(f->saved_variables); @@ -1092,6 +1082,7 @@ PyObject* process_outputs( THPFunction* grad_fn, const UnpackedInput& unpacked, PyObject* inputs, + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) THPObjectPtr&& raw_output, bool is_executable, torch::jit::Node* node, diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 46faff8e46865..2f28c765ab069 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -43,6 +43,8 @@ struct PyNode : public Node { std::string name() const override; bool is_traceable() override; + bool is_aot_backward() const override; + void compiled_args(CompiledNodeArgs& args) override; variable_list apply_with_saved( const variable_list& inputs, diff --git a/torch/csrc/dynamo/compiled_autograd.cpp b/torch/csrc/dynamo/compiled_autograd.cpp new file mode 100644 index 0000000000000..7e2aad576189f --- /dev/null +++ b/torch/csrc/dynamo/compiled_autograd.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace torch::dynamo::autograd { + +std::unique_ptr kPyCompilerInterface; + +const std::unique_ptr& getPyCompilerInterface() { + TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr); + return kPyCompilerInterface; +} + +void setPyCompilerInterface(std::unique_ptr&& impl) { + TORCH_INTERNAL_ASSERT(impl != nullptr); + kPyCompilerInterface = std::move(impl); +} + +void resetPyCompilerInterface() { + kPyCompilerInterface.reset(); +} + +std::vector> get_input_metadata( + const edge_list& edges) { + return torch::autograd::collect_input_metadata(edges); +} + +} // namespace torch::dynamo::autograd diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 383cff14b8e05..b70576ef03129 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -900,6 +900,542 @@ class SwapSavedVariables { StashedVars stashed_ivalues; }; +// NOTE: [Compiled Autograd and backward functions] +// Built-in autograd nodes have functional apply variants +// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph +// capture wants to take a variant of this function and proxy it into the graph. +// Every autograd node defines an apply_with_saved function, that when invoked, +// proxys a call to a function into the Compiled Autograd graph. +// +// Some requirements that we have are: +// - The proxy'ed function must have inputs that are FX-graphable types. +// - Windows has a DLL symbol limit of 65536. 
+// - Node::apply_with_saved is in libtorch_cpu which does not have direct access +// to Python +// +// There were multiple ways to skin the cat, but what we end up doing is: +// - for e.g. MulBackward0_apply_functional, we create a new C++ function +// MulBackward0_apply_functional_ivalue that accepts vector. +// - We define how to pack and unpack arbitrary C++ types into IValues. +// - apply_with_saved passes MulBackward0_apply_functional_ivalue and +// the IValue arguments to Python via an indirection. +// In Python, these get proxy'ed into a graph. + +// Helper struct for packing/unpacking an arbitrary C++ type into a single +// IValue. There are various full and partial specializations for IValuePacker +// to handle packing specific types (like TensorOptions) into an IValue. +template +struct IValuePacker { + // Defines how to pack T into an IValue. + static at::IValue pack(const T& t) { + return t; + } + // Defines how to unpack an IValue into T. + static T unpack(const at::IValue& t) { + return t.to(); + } + // Returns the TypePtr for the IValue (this is like the "type" of the IValue). + // We use this when passing the packed IValue from Python to C++. + // In Python, the IValue is just a PyObject* with the native type. + // For example, it may be a Python int, a Python List[int], etc. + // When passing this PyObject* into C++, we need to know how to parse it + // into a C++ type that then gets put into an IValue. + // That's what the TypePtr is for: it contains the information to do the + // parsing. See torch::jit::toIValue for more information. + static at::TypePtr packed_type() { + if constexpr (::std::is_same_v) { + return at::TensorType::get(); + } else if constexpr (::std::is_same_v) { + return at::IntType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymIntType::get(); + } else if constexpr (::std::is_same_v) { + return at::BoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::FloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymFloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymBoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::LayoutType::get(); + } else if constexpr (::std::is_same_v) { + return at::StringType::get(); + } else if constexpr (::std::is_same_v) { + return at::DeviceObjType::get(); + } else if constexpr (::std::is_same_v) { + return at::NumberType::get(); + } else if constexpr (::std::is_same_v) { + return at::MemoryFormatType::get(); + } else if constexpr (::std::is_same_v) { + return at::ScalarTypeType::get(); + } else { + // If you got here, you have probably added a member of a new type + // to a built-in C++ autograd node. + // Unfortunately, we don't know how to handle this type yet. + // To get this new type to work with Compiled Autograd, please + // either change it to be an IValue-constructible type, or + // define how to pack and unpack an object of this time into an IValue + // by creating a specialization of IValuePacker for this type. + // See NOTE: [Compiled Autograd and backward functions] for context. + TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type"); + return at::NoneType::get(); + } + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const size_t& t) { + // We generally use size_t as the size of a list of Tensors or number of + // dimensions. 
The number of dimensions generally do not exceed 64 + // (TensorIterator has that limitation), and lists of Tensors generally do + // not exceed the int64_t max (you'd probably run out of RAM or run into + // significant Tensor overhead). If you run into this limitation the fix is + // to figure out how to pack size_t into int64_t. Note that size_t has some + // weird behavior on Mac OS. + uint64_t maximum_value = std::numeric_limits::max(); + TORCH_INTERNAL_ASSERT( + static_cast(t) <= maximum_value, + "size_t too large to pack into IValue"); + return static_cast(t); // pack as int64_t + } + static size_t unpack(const at::IValue& t) { + return static_cast(t.toInt()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +template <> +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + return t; + } + static std::vector unpack(const at::IValue& t) { + // We need this because there's no t.to>() override? + return t.toSymIntVector(); + } + static at::TypePtr packed_type() { + return at::ListType::create(at::SymIntType::get()); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const VariableInfo& t) { + auto tuple = std::make_tuple( + t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty); + return tuple; + } + static VariableInfo unpack(const at::IValue& t) { + auto tuple = t.toTuple(); + const auto& tuple_elements = tuple->elements(); + const auto elements = tuple_elements.asArrayRef(); + TORCH_INTERNAL_ASSERT(elements.size() == 6); + VariableInfo v; + v.layout = elements[0].toLayout(); + v.device = elements[1].toDevice(); + v.scalar_type = elements[2].toScalarType(); + v.size = elements[3].toSymIntVector(); + v.requires_grad = elements[4].toBool(); + v.is_empty = elements[5].toBool(); + return v; + } + static at::TypePtr packed_type() { + return at::TupleType::create({ + at::LayoutType::get(), + at::DeviceObjType::get(), + at::ScalarTypeType::get(), + at::ListType::create(at::SymIntType::get()), + at::BoolType::get(), + at::BoolType::get(), + }); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const caffe2::TypeMeta& t) { + return at::typeMetaToScalarType(t); // pack as at::ScalarType + } + static caffe2::TypeMeta unpack(const at::IValue& t) { + return caffe2::TypeMeta::fromScalarType(t.to()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +inline std::optional optTypeMetaToScalarType( + const std::optional& t) { + if (t.has_value()) { + return at::typeMetaToScalarType(t.value()); + } else { + return std::nullopt; + } +} + +using packed_tensoroptions_t = std::tuple< + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional>; + +inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) { + auto tuple = std::make_tuple( + t.requires_grad_opt(), + t.memory_format_opt(), + t.device_opt(), + optTypeMetaToScalarType(t.dtype_opt()), + t.layout_opt(), + t.pinned_memory_opt()); + return tuple; +} +inline at::TensorOptions unpack_TensorOptions( + const packed_tensoroptions_t& tuple) { + at::TensorOptions result; + auto maybe_requires_grad = std::get<0>(tuple); + if (maybe_requires_grad.has_value()) { + result = result.requires_grad(maybe_requires_grad.value()); + } + auto maybe_memory_format = std::get<1>(tuple); + if (maybe_memory_format.has_value()) { + result = result.memory_format(maybe_memory_format.value()); + } + auto maybe_device = std::get<2>(tuple); + if 
(maybe_device.has_value()) { + result = result.device(maybe_device.value()); + } + auto maybe_dtype = std::get<3>(tuple); + if (maybe_dtype.has_value()) { + result = + result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value())); + } + auto maybe_layout = std::get<4>(tuple); + if (maybe_layout.has_value()) { + result = result.layout(maybe_layout.value()); + } + auto maybe_pinned_memory = std::get<5>(tuple); + if (maybe_pinned_memory.has_value()) { + result = result.pinned_memory(maybe_pinned_memory.value()); + } + return result; +} + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorOptions& t) { + return pack_TensorOptions(t); + } + static at::TensorOptions unpack(const at::IValue& t) { + auto tuple = t.to(); + return unpack_TensorOptions(tuple); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {at::OptionalType::create(at::BoolType::get()), + at::OptionalType::create(at::MemoryFormatType::get()), + at::OptionalType::create(at::DeviceObjType::get()), + at::OptionalType::create(at::ScalarTypeType::get()), + at::OptionalType::create(at::LayoutType::get()), + at::OptionalType::create(at::BoolType::get())}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const TypeAndSize& t) { + auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options)); + return tuple; + } + static TypeAndSize unpack(const at::IValue& t) { + auto tuple = + t.to, packed_tensoroptions_t>>(); + TypeAndSize result; + result.sym_sizes = std::get<0>(tuple); + result.options = unpack_TensorOptions(std::get<1>(tuple)); + return result; + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker::packed_type()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::optional& t) { + if (t.has_value()) { + return IValuePacker::pack(t.value()); + } else { + return std::nullopt; + } + } + static std::optional unpack(const at::IValue& t) { + if (t.isNone()) { + return std::nullopt; + } else { + return IValuePacker::unpack(t); + } + } + static at::TypePtr packed_type() { + return at::OptionalType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + if constexpr (::std::is_constructible_v) { + return t; + } + if (t.empty()) { + auto lst = c10::impl::GenericList(at::AnyType::get()); + return lst; + } + auto type_ptr = IValuePacker::pack(t[0]).type(); + auto lst = c10::impl::GenericList(type_ptr); + for (const auto& elt : t) { + lst.emplace_back(IValuePacker::pack(elt)); + } + return lst; + } + static std::vector unpack(const at::IValue& t) { + if constexpr (::std::is_constructible_v) { + return t.to<::std::vector>(); + } + std::vector result; + auto lst = t.toList(); + for (const at::IValue& elt : lst) { + result.emplace_back(IValuePacker::unpack(elt)); + } + return result; + } + static at::TypePtr packed_type() { + return at::ListType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const c10::List& t) { + return IValuePacker>::pack(t.vec()); + } + static c10::List unpack(const at::IValue& t) { + return c10::List(IValuePacker>::unpack(t)); + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::array& t) { + std::vector result(t.begin(), t.end()); + return IValuePacker>::pack(result); + } + static std::array 
unpack(const at::IValue& t) { + std::array result; + auto packed = IValuePacker>::unpack(t); + for (size_t i = 0; i < packed.size(); i++) { + result[i] = packed[i]; + } + return result; + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorGeometry& t) { + auto tuple = std::make_tuple( + t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset()); + return tuple; + } + static at::TensorGeometry unpack(const at::IValue& t) { + auto tuple = t.to, + std::vector, + at::SymInt>>(); + return at::TensorGeometry( + std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple)); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker>::packed_type(), + at::SymIntType::get()}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const InputMetadata& t) { + TORCH_INTERNAL_ASSERT(!t.is_nested_tensor()); + auto tuple = std::make_tuple( + pack_TensorOptions(t.options()), + t.shape_as_dim_vector().vec(), + t.is_tensor_subclass()); + return tuple; + } + static InputMetadata unpack(const at::IValue& t) { + auto tuple = t.to< + std::tuple, bool>>(); + + return InputMetadata( + unpack_TensorOptions(std::get<0>(tuple)), + SymIntSmallVec(std::get<1>(tuple)), + std::get<2>(tuple), + false); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker::packed_type(), + IValuePacker>::packed_type(), + at::BoolType::get()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const at::OptionalArray& t) { + return IValuePacker>>::pack(t.list); + } + static at::OptionalArray unpack(const at::IValue& t) { + auto result = IValuePacker>>::unpack(t); + if (result.has_value()) { + return {result.value()}; + } else { + return {}; + } + } + static at::TypePtr packed_type() { + return IValuePacker>>::packed_type(); + } +}; + +// This is a helper struct for packing and unpacking multiple arguments into +// an ivalue_list. It leverages IValuePacker. +struct PackedArgs { + PackedArgs() = default; + + explicit PackedArgs(std::vector stack_) + : stack(std::move(stack_)) {} + + const std::vector& vec() const { + return stack; + } + + template + void pack(const T& t) { + stack.emplace_back(IValuePacker::pack(t)); + } + template + T unpack() { + return IValuePacker::unpack(std::move(stack[idx++])); + } + + void pack_saved_data(const ska::flat_hash_map& dct) { + std::vector keys; + std::vector values; + for (const auto& [key, value] : dct) { + keys.emplace_back(key); + values.emplace_back(value); + } + pack(keys); + for (const auto& value : values) { + pack(value); + } + } + + ska::flat_hash_map unpack_saved_data() { + ska::flat_hash_map dct; + auto keys = unpack>(); + for (const auto& key : keys) { + dct.insert({key, std::move(stack[idx++])}); + } + return dct; + } + + private: + std::vector stack; + int64_t idx = 0; +}; + +// This is a layer of indirection for calling methods on the Python +// AutogradCompilerInstance (referred to as the "py_compiler") from +// libtorch_cpu (where Python is not available). +// A PyCompilerInterfaceImpl in libtorch_python subclasses it and +// overrides the methods to do the actual calls back to Python. 
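+// A minimal usage sketch (hypothetical MyNode; the real call sites in this
+// patch are CopyBackwards::apply_with_saved and CopySlices::apply_with_saved):
+//
+//   const auto& interface = getPyCompilerInterface();
+//   // once per node type: register the functional variant and its schema
+//   interface->bind_function(
+//       saved.get_py_compiler(), name(),
+//       MyNode_apply_functional_ivalue, schema);
+//   // on every apply_with_saved: proxy the call into the graph
+//   variable_list outputs = interface->call_function(
+//       saved.get_py_compiler(), "apply_functional", name(),
+//       inputs, std::move(packed_args).vec(), output_metadata);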
+struct TORCH_API PyCompilerInterface { + PyCompilerInterface() = default; + PyCompilerInterface(const PyCompilerInterface&) = delete; + PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; + PyCompilerInterface(PyCompilerInterface&&) = delete; + PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; + virtual ~PyCompilerInterface() = default; + + // Invokes py_compiler.bind_function(fn_name, fn) + virtual std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + functional_apply_t fn, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector packed_args_schema, + bool is_custom_function = false) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + // Invokes py_compiler.method_name(fn_name, inputs, packed_args, + // output_metadata) + virtual variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + virtual variable_list call_copy_slices_prologue( + PyObject* py_compiler, + const variable_list& inputs, + const at::TensorGeometry& base, + const at::TensorGeometry& view) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + virtual variable_list call_copy_slices_epilogue( + PyObject* py_compiler, + const std::vector& needs_input_grad, + const at::Tensor& result, + const variable_list& res, + const at::Tensor& grad_slice) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } +}; + +TORCH_API const std::unique_ptr& getPyCompilerInterface(); +TORCH_API void setPyCompilerInterface( + std::unique_ptr&& impl); +TORCH_API void resetPyCompilerInterface(); + +// including torch/csrc/autograd/engine.h breaks BC by somehow introducing +// symbol resolution issues. Instead requiring downstream users to include +// engine.h to access collect_input_metadata, we provide it here (with a +// different name to avoid ambigous symbols...) +TORCH_API std::vector> get_input_metadata( + const edge_list& edges); + } // namespace torch::dynamo::autograd template <> diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index b58f62c3d2fcf..ccdf65e887d9d 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -52,6 +52,156 @@ at trace time. namespace torch::dynamo::autograd { using c10::SymInt; +// List[Optional[Tensor]] in Python can't be directly parsed into a +// List[Tensor], so we need to do this conversion manually. +static std::vector toTensorList( + const std::vector>& inputs) { + std::vector result; + result.reserve(inputs.size()); + for (const auto& inp : inputs) { + if (inp.has_value()) { + result.emplace_back(*inp); + } else { + result.emplace_back(); + } + } + return result; +} + +// Binds a function (that represents some backward computation) to Python. +// All of these functions have a common signature, which is +// (in C++) (vector, vector) -> vector +// (in Python) (List[Optional[Tensor]], *packed_args: IValue) -> +// List[Optional[Tensor]] +// +// The vector are the list of gradient Tensors, each of which may be +// undefined (in C++) which corresponds to None (in Python). 
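+// Illustrative Python-side shape of a bound function (the positional
+// packed_args must line up, in order, with packed_args_schema):
+//   grads_out = bound_fn([grad0, None, ...], *packed_args)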
+static std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema, + bool is_custom_function) { + // This is the function that can be called from Python. + auto py_func = py::cpp_function( + [packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)]( + std::vector>& inputs, + const py::args& py_args) -> py::object { + // py_args is a tuple of PyObject*. + // We need to reconstruct a vector to invoke `fn`. + // To do so, we use the packed_args_schema to convert each PyObject* + // to its corresponding C++ type that can be stored into IValue. + TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size()); + std::vector args; + args.reserve(py_args.size()); + auto tuple_args = jit::tuple_slice(py_args); + for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) { + if (packed_args_schema[idx]->isSubtypeOf( + *at::ListType::ofTensors())) { + // List[Tensor] might have Nones, not handled in jit::toIValue + auto tmp = py::cast>>( + tuple_args[idx]); + args.emplace_back(toTensorList(tmp)); + } else { + args.emplace_back(jit::toIValue( + tuple_args[idx], packed_args_schema[idx], std::nullopt)); + } + } + // None in Python corresponds to undefined Tensor in C++ + auto inputs_ = toTensorList(inputs); + auto outputs = fn(inputs_, args); + return jit::toPyObject(at::IValue(outputs)); + }); + py::handle handle(py_compiler); + auto result = + handle.attr("bind_function")(fn_name, py_func, is_custom_function); + return result.cast(); +} + +// Invokes py_compiler.method_name(fn_name, inputs, packed_args, +// output_metadata) +static variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + // convert ivalue_list -> PyObject* + PyObject* py_packed_args = + PyTuple_New(static_cast(packed_args.size())); + for (const auto i : c10::irange(packed_args.size())) { + py::object obj = jit::toPyObject(packed_args[i]); + Py_INCREF(obj.ptr()); + PyTuple_SET_ITEM(py_packed_args, i, obj.ptr()); + } + + // call the corresponding method on the py_compiler + py::handle handle(py_compiler); + py::object stuff = handle.attr(method_name)( + fn_name, + inputs, + py::handle(py_packed_args), + jit::toPyObject(output_metadata)); + + // Convert the output from PyObject* to vector + auto tmp = py::cast>>(stuff); + return toTensorList(tmp); +} + +struct PyCompilerInterfaceImpl : PyCompilerInterface { + std::string bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema, + bool is_custom_function = false) override { + return torch::dynamo::autograd::bind_function( + py_compiler, + fn_name, + std::move(fn), + std::move(packed_args_schema), + is_custom_function); + } + variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) override { + return torch::dynamo::autograd::call_function( + py_compiler, + method_name, + fn_name, + inputs, + packed_args, + output_metadata); + } + variable_list call_copy_slices_prologue( + PyObject* py_compiler, + const variable_list& inputs, + const at::TensorGeometry& base, + const at::TensorGeometry& view) override { + py::handle handle(py_compiler); + py::object stuff = + 
handle.attr("call_copy_slices_prologue")(inputs, base, view); + return py::cast>(stuff); + } + variable_list call_copy_slices_epilogue( + PyObject* py_compiler, + const std::vector& needs_input_grad, + const at::Tensor& result, + const variable_list& res, + const at::Tensor& grad_slice) override { + py::handle handle(py_compiler); + py::object stuff = handle.attr("call_copy_slices_epilogue")( + needs_input_grad, result, res, grad_slice); + auto output = py::cast>>(stuff); + return toTensorList(output); + } +}; + static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -88,6 +238,22 @@ static void check(bool result) { check(nullptr); } +static variable_list validate_outputs( + const variable_list& outputs, + const ivalue_list& args) { + auto r = PackedArgs(args); + auto value = r.unpack>>(); + auto new_outputs = outputs; + + torch::autograd::validate_outputs( + value, new_outputs, [&](const std::string& msg) { + std::ostringstream ss; + ss << "[Compiled Autograd Tracing:]" << msg; + return ss.str(); + }); + return new_outputs; +} + // snapshot of python verbose logging toggle static PyObject* python_verbose_logger = nullptr; @@ -498,6 +664,21 @@ static void set_ivalue_proxies( } } +static at::Tensor call_accumulate( + PyObject* py_compiler, + const at::Tensor& old_var, + const at::Tensor& new_var) { + if (!old_var.defined()) { + return new_var; + } + if (!new_var.defined()) { + return old_var; + } + py::handle handle(py_compiler); + py::object stuff = handle.attr("accumulate")(old_var, new_var); + return py::cast(stuff); +} + static TraceState call_begin_capture( PyObject* self, CacheNode& cache, @@ -657,6 +838,8 @@ static CacheNode* _compiled_autograd_impl( ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); + setPyCompilerInterface(std::make_unique()); + TraceState state = call_begin_capture( py_compiler, *cache, compiler_call, output_edges.size()); InputBuffers input_buffers; @@ -723,16 +906,52 @@ static CacheNode* _compiled_autograd_impl( SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call); variable_list outputs = call.node->apply_with_saved(inputs, saved); - saved.debug_asserts(); saved.before(call.node->next_edges()); - validate_outputs( - call.node->next_edges(), outputs, [&](const std::string& msg) { - std::ostringstream ss; - ss << "[Compiled Autograd Tracing: " << call.node->name() << "] " - << msg; - return ss.str(); - }); + + auto input_metadata = get_input_metadata(call.node->next_edges()); + TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size()); + + // Lazily bind the `validate_outputs` function to Python. + static c10::once_flag flag; + c10::call_once(flag, [&]() { + auto schema = std::vector{IValuePacker< + std::vector>>::packed_type()}; + bind_function( + py_compiler.get(), + "validate_outputs", + validate_outputs, + schema, + false); + }); + + // Don't emit validate_outputs nodes that follow a CompiledBackward node. + // These nodes would otherwise prevent reordering of accumulate_grad + // nodes. + // + // Note that this will not cause correctness issues, because + // 1) AOTAutograd already coerces gradients to have the same metadata as + // the inputs. 2) the AOTAutograd graph already has the necessary + // aten::sum_to nodes in it (so it doesn't need to rely on + // validate_outputs to handle that). 
+ // + // However, we may be dropping some (edge case) safety checks compared to + // eager: a backward that would have errored out in eager may not error + // out in compiled autograd (for example, if the user provided an + // incorrect number of gradients). + if (!call.node->is_aot_backward()) { + PackedArgs args; + args.pack(input_metadata); + ivalue_list input_metadata_state = std::move(args).vec(); + outputs = call_function( + py_compiler, + "validate_outputs", + "validate_outputs", + outputs, + input_metadata_state, + input_metadata_state[0]); + } + saved.after(call.node->next_edges()); saved.debug_asserts(); @@ -754,13 +973,14 @@ static CacheNode* _compiled_autograd_impl( auto& output = outputs[i]; const auto& next = call.node->next_edge(i); if (next.is_valid() && output.defined()) { - input_buffers.lookup(next.function.get()) - .add( - next.input_nr, std::move(output), std::nullopt, std::nullopt); + auto& buffer = input_buffers.lookup(next.function.get()); + buffer.buffer[next.input_nr] = call_accumulate( + py_compiler, buffer.buffer[next.input_nr], output); } } } + resetPyCompilerInterface(); PyObject* res = check(call_end_capture(py_compiler, state.outputs)); TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); TORCH_CHECK(