RuntimeError: Trying to backward through the graph a second time #76

Open
daviddanan opened this issue Jul 4, 2023 · 0 comments
Labels
bug Something isn't working


daviddanan commented Jul 4, 2023

Bug

Hello, I recently discovered your library and am trying to use it for my research. While solving a constrained minimization problem, I encountered an error.
Here is the full stack trace for a simplified case:

Traceback (most recent call last):
  File "/home/david.danan/psinns/Tests/ReportIssue.py", line 137, in <module>
    EnergySolver(nodes=gridNodes,
  File "/home/david.danan/psinns/Tests/ReportIssue.py", line 112, in EnergySolver
    coop.step(loss.closure,predictedField,positions)
  File "/opt/miniconda/envs/FEMINNSEnv/lib/python3.9/site-packages/cooper/constrained_optimizer.py", line 222, in step
    lagrangian = self.formulation.composite_objective(
  File "/opt/miniconda/envs/FEMINNSEnv/lib/python3.9/site-packages/cooper/lagrangian_formulation.py", line 232, in composite_objective
    cmp_state = closure(*closure_args, **closure_kwargs)
  File "/home/david.danan/psinns/Tests/ReportIssue.py", line 84, in closure
    duxdxyz = grad(displacement[:, 0].unsqueeze(1), position, torch.ones(position.size()[0], 1, device=device), create_graph=True, retain_graph=True)[0]
  File "/opt/miniconda/envs/FEMINNSEnv/lib/python3.9/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

As far as I can tell, the case where computing the loss itself requires an explicit call to autograd (here, torch.autograd.grad of the network output with respect to its input) is not handled for now.
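A minimal sketch in plain PyTorch (no Cooper) that appears to reproduce the same failure, assuming the cause is that the network output is computed once and then reused. `torch.autograd.grad(..., create_graph=True)` builds a graph that leads back through the network's forward pass, so a later `.backward()` without `retain_graph=True` frees the saved tensors of that forward pass, and the next `grad` call on the same output fails. The network size, the `Tanh` activation and the name `energy` are illustrative choices, not taken from the script.

```python
import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3))
position = torch.rand(5, 3, requires_grad=True)

# Forward pass evaluated ONCE and then reused, like predictedField in the script.
displacement = net(position)


def energy(displacement, position):
    # d(u_x)/d(position), kept differentiable (create_graph=True) so that the
    # loss can itself be backpropagated to the network parameters.
    dux = torch.autograd.grad(
        displacement[:, 0].unsqueeze(1), position,
        torch.ones(position.size(0), 1), create_graph=True, retain_graph=True,
    )[0]
    return (dux[:, 0] + 1).sum()


loss = energy(displacement, position)
loss.backward()  # no retain_graph: frees saved tensors of the network's forward graph

try:
    energy(displacement, position)  # reuses the forward graph that was just freed
    error = None
except RuntimeError as exc:
    error = str(exc)

print(error)
```

In the issue's script, the second evaluation is the closure call made inside `coop.step`, which is where the stack trace above points.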

I tried to redefine the behaviour of cooper.LagrangianFormulation by inheriting from it and adding the retain_graph=True option to the backward call:

from typing import no_type_check

import cooper
import torch


class LagrangianFormulationConserveGraph(cooper.LagrangianFormulation):
    @no_type_check
    def _populate_gradients(
        self, lagrangian: torch.Tensor, ignore_primal: bool = False
    ):

        if ignore_primal and self.cmp.is_constrained:
            pass
        else:
            # Added option: keep the saved tensors of the graph so that the
            # closure can be evaluated again after this backward call.
            lagrangian.backward(retain_graph=True)

        if self.cmp.is_constrained:
            for violation_for_update in self.state_update:
                dual_vars = [_ for _ in self.state() if _ is not None]
                violation_for_update.backward(inputs=dual_vars)

The error disappeared, but the optimization got stuck (neither the loss nor the constraints changed over the iterations). So I guess this was not the correct way to do it.
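The stuck iterations are consistent with how the script below uses the network: `predictedField = modelOutput(positions)` is evaluated once, before the loop, and that same tensor is passed to every closure call. Its values are therefore fixed, so the equality defect computed from it cannot change whatever the optimizer does to the parameters, and `retain_graph=True` only keeps that stale graph alive. A possible alternative is to pass the model to the closure and run the forward pass inside it, so that every evaluation (including the extra one made inside `coop.step`, as in the stack trace above) builds a fresh graph. The sketch below illustrates that idea in plain PyTorch rather than with Cooper: the Lagrangian `loss + λ·defect`, the manual descent/ascent updates, the toy grid, the learning rates and the mean-based stand-in for the integrated energy are all assumptions, not Cooper's implementation.

```python
import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 3))

# Hypothetical toy grid: the first 10 nodes lie on the x = 0 face and are constrained.
position = torch.rand(50, 3)
position[:10, 0] = 0.0
position.requires_grad_(True)
constraint_idx = torch.arange(10)
constraint_values = torch.zeros(10, 3)


def closure(model, position):
    # The forward pass runs INSIDE the closure, so every call builds a new graph.
    displacement = model(position)
    dux = torch.autograd.grad(
        displacement[:, 0].unsqueeze(1), position,
        torch.ones(position.size(0), 1), create_graph=True,
    )[0]
    loss = (dux[:, 0] + 1).mean()  # stand-in for the integrated energy
    eq_defect = (displacement[constraint_idx] - constraint_values).flatten()
    return loss, eq_defect


multipliers = torch.zeros(30)  # one multiplier per constrained component
primal_opt = torch.optim.SGD(net.parameters(), lr=1e-2)
dual_lr = 1e-2
history = []
for _ in range(20):
    primal_opt.zero_grad()
    loss, eq_defect = closure(net, position)
    lagrangian = loss + torch.dot(multipliers, eq_defect)
    lagrangian.backward()  # frees only this iteration's graph
    primal_opt.step()      # descent step on the network parameters
    with torch.no_grad():
        multipliers += dual_lr * eq_defect  # ascent step on the multipliers
    history.append((loss.item(), eq_defect.norm().item()))

print(history[0], history[-1])
```

With Cooper, the analogous change would presumably be giving `closure` the model instead of `predictedField` and passing `modelOutput` to both `composite_objective` and `coop.step`; this has not been checked against Cooper itself.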

Steps

The script below should be enough to reproduce this bug.

Expected behavior

Being able to run the script and see the loss and the constraints actually change over the iterations.

Environment

For cooper, I followed the installation instructions in the README (pip):
- cooper: 0.1.dev8+geae7c5a

All other dependencies were installed in a conda environment:
- python: 3.9.16
- numpy: 1.22.3
- torch: 1.13.1

Context

Since the full script has external dependencies unrelated to this issue, I share below a shorter, simplified version of the problematic part mentioned in the stack trace.

In a nutshell, here is what I am trying to do:
- I have an input field that I know (input) and another one that I don't (output).
- Both fields live on a grid (3 values per node in both cases).
- Using a neural network, I would like to describe the relationship between these 2 fields.
- The loss function represents a property of the output field; the optimal solution should minimize this quantity.
- The constraints enforce the value of the output field on a specific part of the grid (here, 0).
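For reference, the simplified script below appears to pose the following problem. Here $u_\theta$ denotes the network with parameters $\theta$, mapping a node position $p = (x, y, z)$ to a 3-component output, $\Omega = [0, 1]^3$, $\Gamma$ is the set of grid nodes with $x = 0$, and the integral is approximated with nested trapezoidal rules. Writing the Lagrangian as the loss plus multipliers times the equality defects is an assumption about what `cooper.LagrangianFormulation` builds, not something stated in this issue. Note that $E$ involves $\partial u_{\theta,x} / \partial x$, a derivative with respect to the network *input*, which is why `closure` has to call `torch.autograd.grad` on the network output.

$$
\min_{\theta} \; E(\theta) = \int_{\Omega} \left( \frac{\partial u_{\theta,x}}{\partial x}(p) + 1 \right) \mathrm{d}p
\quad \text{subject to} \quad u_{\theta}(p_i) = 0 \quad \forall\, p_i \in \Gamma
$$

$$
\mathcal{L}(\theta, \lambda) = E(\theta) + \sum_{p_i \in \Gamma} \lambda_i^{\top} \big( u_{\theta}(p_i) - 0 \big),
\qquad \min_{\theta} \max_{\lambda} \; \mathcal{L}(\theta, \lambda)
$$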

The only explicit dependencies of the script are numpy, cooper and torch.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np

import torch
from torch.autograd import grad
import torch.nn as nn

import cooper
from cooper.problem import CMPState, Formulation

device = torch.device('cpu')
if torch.cuda.is_available():
    print("CUDA is available, running on GPU")
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    print("CUDA not available, running on CPU")

class TorchStraightForwardBuilder():
    def __init__(self,
                 operations):
        self.operations=operations

    def BuildModel(self):
        seqOperations=self.BuildSeqOperationsFromBlockOperations()
        return TorchClassicalStraightForwardNetwork(operations=seqOperations)

    def BuildSeqOperationsFromBlockOperations(self):
        seqOperations=[self.BuildNNComponent(nnComponent) for nnComponent in self.operations]
        return seqOperations

    def BuildNNComponent(self,nnComponent:tuple):
        brickName,params=nnComponent
        if params is None:
            return CreateTorchComponent(brickName)
        return CreateTorchComponent(brickName,params)


class TorchClassicalStraightForwardNetwork(nn.Module):
    def __init__(self,operations:list):
        super(TorchClassicalStraightForwardNetwork, self).__init__()
        self.allComponents=operations
        self.torchComponents = nn.ModuleList([op for op in operations if isinstance(op,nn.Module)])

    def forward(self, x):
        inputVal=x
        for component in self.allComponents:
            output=component(inputVal)
            inputVal=output
        return output

def CreateTorchComponent(name,ops=None):
    if ops is None:
        if name=="ReLU":
            return nn.ReLU()
        else:
            raise Exception("Activation function not recognized")
    else:
        if name=="Linear":
            return nn.Linear(**ops)
        else:
            raise Exception("Layer not recognized")


def trapz(y,dx,axis=-1):
    nd = y.ndimension()
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    integralValue = torch.sum(dx * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0, axis)
    return integralValue

class MinimizationEnergyProblem(cooper.ConstrainedMinimizationProblem):
    def __init__(self,
                 constraintCondition,
                 integrationDomain):
        super().__init__(is_constrained=True)
        self.constraintIndices,self.constraintValues=constraintCondition
        self.integrationDomain=integrationDomain

    def closure(self, displacement, position):
        duxdxyz = grad(displacement[:, 0].unsqueeze(1), position, torch.ones(position.size()[0], 1, device=device), create_graph=True, retain_graph=True)[0]
        internalEnergy = duxdxyz[:, 0].unsqueeze(1) + 1
        dx,dy,dz=self.integrationDomain["spacing"]
        nx,ny,nz=self.integrationDomain["nbNodes"]
        integrand = internalEnergy.reshape(nx,ny,nz)
        internalEnergyIntegrated=trapz(trapz(trapz(integrand, dx=dz), dx=dy), dx=dx)
        eq_defect=torch.flatten(displacement[self.constraintIndices] - self.constraintValues)
        return cooper.CMPState(loss=internalEnergyIntegrated, eq_defect=eq_defect)

def EnergySolver(nodes,model,constraintCondition,integrationDomain):
    modelFC=TorchStraightForwardBuilder(model)
    modelOutput=modelFC.BuildModel().to(device)
    loss = MinimizationEnergyProblem(constraintCondition=constraintCondition,
                              integrationDomain=integrationDomain)

    formulation = cooper.LagrangianFormulation(loss)#LagrangianFormulationConserveGraph(loss)
    positions=torch.from_numpy(nodes).to(device).requires_grad_(True)
    predictedField = modelOutput(positions)

    primal_optimizer = cooper.optim.ExtraSGD(modelOutput.parameters(), lr=3e-3, momentum=0.5)
    dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=9e-3, momentum=0.5)

    coop = cooper.ConstrainedOptimizer(formulation, primal_optimizer, dual_optimizer)
    state_history = cooper.StateLogger(save_metrics=["loss", "eq_defect", "eq_multipliers"])
    for iter_num in range(10):
        coop.zero_grad()
        lagrangian = formulation.composite_objective(loss.closure,predictedField,positions)
        formulation.custom_backward(lagrangian)
        coop.step(loss.closure,predictedField,positions)

if __name__ == '__main__':
    model=[
        ("Linear",{"in_features":3, "out_features":50}),
        ("ReLU",None),
        ("Linear",{"in_features":50, "out_features":50}),
        ("ReLU",None),
        ("Linear",{"in_features":50, "out_features":3})
          ]

    nNodesX,nNodesY,nNodesZ=10,20,30
    x = np.linspace(0, 1, nNodesX).astype(np.float32)
    y = np.linspace(0, 1, nNodesY).astype(np.float32)
    z = np.linspace(0, 1, nNodesZ).astype(np.float32)
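    # indexing="ij" gives arrays of shape (nNodesX, nNodesY, nNodesZ), matching the
    # reshape(nx, ny, nz) and the z -> y -> x trapz order used in closure
    x,y,z=np.meshgrid(x,y,z,indexing="ij")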
    x,y,z=x.flatten(),y.flatten(),z.flatten()
    gridNodes=np.vstack((x,y,z)).transpose()

    constraintIndices=np.where(x==0)[0]
    constraintCondition=np.where(x==0)[0],torch.zeros((len(constraintIndices),3))
    
    hSpacings=1/(nNodesX-1),1/(nNodesY-1),1/(nNodesZ-1)
    integrationDomain={"spacing":hSpacings,"nbNodes":(nNodesX,nNodesY,nNodesZ)}

    EnergySolver(nodes=gridNodes,
                 model=model,
                 constraintCondition=constraintCondition,
                 integrationDomain=integrationDomain)
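The nested `trapz` calls in `closure` integrate the last axis first (with `dz`), then `dy`, then `dx`, so they assume the flattened node values reshape to `(nx, ny, nz)` in x, y, z order. That ordering is what `np.meshgrid(x, y, z, indexing="ij")` produces; NumPy's default `indexing="xy"` returns arrays of shape `(ny, nx, nz)` instead. Below is a small standalone check of that assumption, using a linear test integrand (for which the trapezoidal rule is exact, so the result should be 3.0 up to float32 rounding) and `torch.trapezoid` as a cross-check. The test function `x + 2y + 3z` and the variable names are illustrative choices.

```python
import numpy as np
import torch


def trapz(y, dx, axis=-1):
    # Composite trapezoidal rule along one axis (same helper as in the script).
    nd = y.ndimension()
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    return torch.sum(dx * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0, axis)


nx, ny, nz = 10, 20, 30
hx, hy, hz = 1 / (nx - 1), 1 / (ny - 1), 1 / (nz - 1)
xs = np.linspace(0, 1, nx).astype(np.float32)
ys = np.linspace(0, 1, ny).astype(np.float32)
zs = np.linspace(0, 1, nz).astype(np.float32)

print(np.meshgrid(xs, ys, zs)[0].shape)            # default "xy": (20, 10, 30)
X, Y, Z = np.meshgrid(xs, ys, zs, indexing="ij")   # shape (nx, ny, nz)

# Flatten like gridNodes, then reshape the nodal values as closure() does.
values = torch.from_numpy((X + 2 * Y + 3 * Z).flatten())
f = values.reshape(nx, ny, nz)

# Innermost call integrates z, then y, then x.
result = trapz(trapz(trapz(f, dx=hz), dx=hy), dx=hx)
reference = torch.trapezoid(
    torch.trapezoid(torch.trapezoid(f, dx=hz, dim=-1), dx=hy, dim=-1), dx=hx, dim=-1
)
print(float(result), float(reference))
```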
    
@daviddanan daviddanan added the bug Something isn't working label Jul 4, 2023