-
Notifications
You must be signed in to change notification settings - Fork 104
Labels
Description
This regression was introduced in commit 290a52e.
Repro:
import torch
import thunder
from torch.testing import make_tensor
def foo(a, c):
    """Return ``a * c``; this is the function compiled with ``thunder.jit`` in the repro."""
    return a * c
# Inputs: a 2x2 float32 CPU tensor that requires grad, and a Python float.
a = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
c = 2.0
# Compile `foo` twice: once with cache="symbolic values" and once with no
# cache argument. Only `dynamic_jit` is exercised below.
dynamic_jit = thunder.jit(foo, cache="symbolic values")
static_jit = thunder.jit(foo)
out = dynamic_jit(a, c)
# Save the last forward and backward traces of `dynamic_jit` to files for inspection.
thunder.last_traces(dynamic_jit)[-1].save_trace("fwd_trc.py")
thunder.last_backward_traces(dynamic_jit)[-1].save_trace("bwd_trc.py")
# First backward pass; retain_graph=True keeps the graph so it can be run again.
torch.autograd.backward(out, torch.rand_like(out), retain_graph=True)
# Second backward pass over the retained graph.
# This fails.
# File "thunder.backward_fn_2", line 19, in backward_fn
# ValueError: not enough values to unpack (expected 1, got 0)
torch.autograd.backward(out, torch.rand_like(out))
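The probable cause is that the forward trace now passes the objects saved for backward in lists, which are cleared after the first backward pass. Previously they were passed in tuples. Compare the final `return` statements of the two traces below: `((), (c,))` before 290a52e versus `([], [c])` after. Because the lists are emptied in place, the second backward over the retained graph finds nothing to unpack. That matches the `ValueError: not enough values to unpack (expected 1, got 0)` shown above.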
Forward Trace Before 290a52e
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
# Forward trace generated by thunder for `foo` before 290a52e.
@torch.no_grad()
@no_autocast
def computation(a, c):
  # a: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # c: "float 2.0"

  # /opt/pytorch/lightning-thunder/test.py:22: return a * c
  t34 = torch.mul(a, c) # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # t34 = ltorch.mul(a, c) # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # i30 = prims.eq(i1, 1) # i30: "bool False"
      # i31 = prims.eq(i1, i1) # i31: "bool True"
      # i32 = prims.eq(i0, 1) # i32: "bool False"
      # i33 = prims.eq(i0, i0) # i33: "bool True"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # t34 = prims.mul(a, c) # t34: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # NOTE(review): per the report, the second returned element carries the
  # objects saved for backward. It is presumably (saved tensors, saved
  # non-tensor values), since the tensor slot is empty and the float `c` is
  # in the second slot. Confirm against thunder. In this trace both containers
  # are tuples, which cannot be emptied in place.
  return {'output': (t34,), 'flat_args': [a, c], 'flat_output': (t34,)}, ((), (c,))
Forward Trace After 290a52e
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
# Forward trace generated by thunder for `foo` after 290a52e.
@torch.no_grad()
@no_autocast
def computation(a, c):
  # a: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # c: "float 2.0"

  # /opt/pytorch/lightning-thunder/test.py:23: return a * c
  t42 = torch.mul(a, c) # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # t42 = ltorch.mul(a, c) # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # i34 = prims.eq(i1, 1) # i34: "bool False"
      # i39 = prims.eq(i1, i1) # i39: "bool True"
      # i40 = prims.eq(i0, 1) # i40: "bool False"
      # i41 = prims.eq(i0, i0) # i41: "bool True"
      # (i0, i1) = prims.shape(a)
      # (i0, i1) = prims.shape(a)
      # t42 = prims.mul(a, c) # t42: "cpu f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # NOTE(review): the saved-for-backward containers are now lists rather than
  # tuples. The report says these lists are cleared after the first backward.
  # If so, a second backward over a retained graph sees `[]` in place of `[c]`.
  # That would explain "not enough values to unpack (expected 1, got 0)".
  # Suspected regression point; confirm in thunder.
  return {'output': (t42,), 'flat_args': [a, c], 'flat_output': (t42,)}, ([], [c])