Skip to content

Running backward multiple times with retain_graph=True fails if some object was saved for backward #2221

@kshitij12345

Description

@kshitij12345

This is a regression introduced by commit 290a52e.

Repro:

import torch
import thunder
from torch.testing import make_tensor

# Function under test: multiplies the tensor `a` by the Python scalar `c`.
def foo(a, c):
    return a * c

# `a` requires grad, so calling the jitted function also produces a backward trace.
a = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
# `c` is a plain Python float, not a tensor; in the forward traces quoted in
# this report it shows up in the second element of the returned saved-for-backward pair.
c = 2.0

# The repro exercises the jit created with cache="symbolic values".
dynamic_jit = thunder.jit(foo, cache="symbolic values")
# Default-cache variant; it is created here but never called in this script.
static_jit = thunder.jit(foo)

out = dynamic_jit(a, c)

# Dump the final forward and backward traces to files for inspection.
thunder.last_traces(dynamic_jit)[-1].save_trace("fwd_trc.py")
thunder.last_backward_traces(dynamic_jit)[-1].save_trace("bwd_trc.py")

# First backward pass: retain_graph=True asks autograd to keep the graph so
# that backward can be run again on the same output.
torch.autograd.backward(out, torch.rand_like(out), retain_graph=True)

# Second backward pass on the same output.
# This fails.
#   File "thunder.backward_fn_2", line 19, in backward_fn
# ValueError: not enough values to unpack (expected 1, got 0)
torch.autograd.backward(out, torch.rand_like(out))

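The probable cause is that the forward trace now returns the objects saved for backward as lists (`([], [c])`), and those lists are cleared after the first backward pass, whereas before 290a52e it returned them as tuples (`((), (c,))`). The second backward pass therefore finds an empty list and fails to unpack it.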

Forward Trace Before 290a52e

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, c):
  # a: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # c: "float 2.0"

  # /opt/pytorch/lightning-thunder/test.py:22: 	    return a * c
  t34 = torch.mul(a, c)  # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # t34 = ltorch.mul(a, c)  # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # i30 = prims.eq(i1, 1)  # i30: "bool False"
      # i31 = prims.eq(i1, i1)  # i31: "bool True"
      # i32 = prims.eq(i0, 1)  # i32: "bool False"
      # i33 = prims.eq(i0, i0)  # i33: "bool True"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # t34 = prims.mul(a, c)  # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  return {'output': (t34,), 'flat_args': [a, c], 'flat_output': (t34,)}, ((), (c,))

Forward Trace After 290a52e

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, c):
  # a: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # c: "float 2.0"

  # /opt/pytorch/lightning-thunder/test.py:23: 	    return a * c
  t42 = torch.mul(a, c)  # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # t42 = ltorch.mul(a, c)  # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # i34 = prims.eq(i1, 1)  # i34: "bool False"
      # i39 = prims.eq(i1, i1)  # i39: "bool True"
      # i40 = prims.eq(i0, 1)  # i40: "bool False"
      # i41 = prims.eq(i0, i0)  # i41: "bool True"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # t42 = prims.mul(a, c)  # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  return {'output': (t42,), 'flat_args': [a, c], 'flat_output': (t42,)}, ([], [c])

Metadata

Metadata

Assignees

Labels

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions