diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 36306dbc8..66f1fc794 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -116,6 +116,26 @@ def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
 def new_full(inp, size, value, dtype=None, layout=None, device=None, pin_memory=None):
     return torch.full(size, value, dtype=inp.dtype, device=inp.device)
 
+import torch.fx as fx
+import typing
+class ListCodeGen(fx.CodeGen):
+    def gen_fn_def(self, free_vars, maybe_return_annotation):
+        lst_unpack = f"""
+def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
+    {', '.join(free_vars)} = args_list
+    args_list.clear()
+    """
+        return lst_unpack
+
+    def additional_globals(self):
+        return [('List', typing.List)]
+
+    def process_inputs(self, *inputs):
+        assert(len(inputs) == 1)
+        return inputs[0]
+
+def get_memory():
+    print(torch.cuda.max_memory_allocated()/1e9)
 
 def create_aot_autograd_function(
     flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
@@ -168,18 +188,22 @@ def forward(ctx, *flat_tensor_args):
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
 
                 bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                bw_module.graph.set_codegen(ListCodeGen())
+                bw_module.recompile()
                 compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+            # No way of clearing ctx.saved_tensors right now afaik
+            ctx.saved_values = fw_outs[num_outs:]
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
         def backward(ctx, *flat_args):
             contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            flat_args = list(ctx.saved_values) + list(contiguous_args)
+            ctx.saved_values = None
+            out = normalize_as_list(compiled_bw(flat_args))
             return tuple(out)
 
     return CompiledFunction
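
For context, here is a minimal standalone sketch (not part of the patch) of what ListCodeGen does to the code an fx.GraphModule generates. It assumes ListCodeGen is in scope as defined in the first hunk, that the installed torch.fx exposes CodeGen and Graph.set_codegen, and the traced function f below is made up for illustration. After set_codegen + recompile, the generated forward takes a single list, unpacks it, and clears it, so the caller's references to the inputs are dropped as soon as the graph starts running; combined with ctx.saved_values = None in backward, this is what lets saved tensors be freed early instead of being held for the whole backward call.

    # Hypothetical usage sketch; ListCodeGen as defined in the hunk above.
    import torch
    import torch.fx as fx

    def f(a, b):
        return a * b + a

    gm = fx.symbolic_trace(f)            # generated signature: forward(self, a, b)
    gm.graph.set_codegen(ListCodeGen())  # swap in the list-based calling convention
    gm.recompile()
    print(gm.code)                       # forward(self, args_list: List[torch.Tensor]) ...

    args = [torch.randn(4), torch.randn(4)]
    out = gm(args)                       # one list argument, mirroring compiled_bw(flat_args)
    assert len(args) == 0                # args_list.clear() dropped the caller's references

The same convention is what the second hunk relies on: the backward is handed a single flat_args list (saved values plus contiguous gradients), and since ctx.saved_values has already been set to None, the generated args_list.clear() leaves the compiled backward graph holding the only references to those tensors.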