
Nan of smooothing_loss and large number of cross_correlation_loss #2

Open
wendy127green opened this issue Aug 29, 2019 · 17 comments

@wendy127green

Hello,
I used your loss part in my registration task, but I got NaN from smooothing_loss and a very large cross_correlation_loss, like 3.3388e+10. I wonder why this happened and whether there are any solutions?
Thanks

@Hsankesara
Owner

Hey @elfprincess3, can you please give me more information about the problem?

@wendy127green
Author

wendy127green commented Aug 29, 2019

I used the following code from your program:

```python
import numpy as np
import torch

def cross_correlation_loss(I, J, n):
    # I = I.permute(0, 3, 1, 2)
    # J = J.permute(0, 3, 1, 2)
    batch_size, channels, xdim, ydim = I.shape
    I2 = torch.mul(I, I)
    J2 = torch.mul(J, J)
    IJ = torch.mul(I, J)
    sum_filter = torch.ones((1, channels, n, n))
    sum_filter = sum_filter.cuda()
    I_sum = torch.conv2d(I, sum_filter, padding=1, stride=(1, 1))
    J_sum = torch.conv2d(J, sum_filter, padding=1, stride=(1, 1))
    I2_sum = torch.conv2d(I2, sum_filter, padding=1, stride=(1, 1))
    J2_sum = torch.conv2d(J2, sum_filter, padding=1, stride=(1, 1))
    IJ_sum = torch.conv2d(IJ, sum_filter, padding=1, stride=(1, 1))
    win_size = n ** 2
    u_I = I_sum / win_size
    u_J = J_sum / win_size
    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    cc = cross * cross / (I_var * J_var + np.finfo(float).eps)
    # print("CC", torch.mean(cc), "np.finfo(float).eps", np.finfo(float).eps)
    return torch.mean(cc)

def smooothing_loss(y_pred):
    dy = torch.abs(y_pred[:, 1:, :, :] - y_pred[:, :-1, :, :])
    dx = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :])
    dx = torch.mul(dx, dx)
    dy = torch.mul(dy, dy)
    d = torch.mean(dx) + torch.mean(dy)
    return d / 2.0

def vox_morph_loss(y, ytrue, n=9, lamda=0.01):
    cc = cross_correlation_loss(y, ytrue, n)
    sm = smooothing_loss(y)
    print("CC Loss", cc, "Gradient Loss", sm)
    loss = -1.0 * cc + lamda * sm
    return loss
```

`y` is the output of the registration network, and `ytrue` is the fixed image.
I got this:

[screenshot: printed loss values showing nan for the smoothing loss and very large cross-correlation loss values]

But for the samples that have an abnormal cross-correlation loss, the MSE loss is normal.
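One likely source of NaNs and huge values in a windowed NCC like this is the local variance terms dropping to (or slightly below) zero from floating-point rounding, with only machine epsilon (~2.2e-16) guarding the division. A minimal sketch of a more defensive variant, assuming the same `(B, C, H, W)` layout; the name `ncc_loss`, the clamping, and the `eps` value are my own choices, not from this repo:

```python
import torch

def ncc_loss(I, J, n=9, eps=1e-5):
    """Windowed normalized cross-correlation; a defensive sketch, not the repo's code."""
    B, C, H, W = I.shape
    win = torch.ones((1, C, n, n), dtype=I.dtype, device=I.device)
    pad = n // 2  # 'same' padding, matched to the window size

    def box(x):  # windowed sum via convolution with a box filter
        return torch.conv2d(x, win, padding=pad)

    I_sum, J_sum = box(I), box(J)
    I2_sum, J2_sum, IJ_sum = box(I * I), box(J * J), box(I * J)
    win_size = n * n
    u_I, u_J = I_sum / win_size, J_sum / win_size
    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    # Rounding can push the variances slightly negative, which is one way
    # the ratio blows up or becomes NaN; clamp them at zero and use a
    # larger stabilizer than machine epsilon.
    I_var = torch.clamp(I_var, min=0.0)
    J_var = torch.clamp(J_var, min=0.0)
    cc = cross * cross / (I_var * J_var + eps)
    return torch.mean(cc)
```

With this, identical images give a value close to 1, and a constant image (zero variance everywhere in the interior) stays finite instead of producing NaN.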

@John1231983

Hi, you should smooth the deformation field, not the warped image.

@John1231983

And I think the smooth loss implementation is wrong.

@wendy127green
Author

wendy127green commented Aug 29, 2019

So, is 'y' in vox_morph_loss the deformation field? What about y_true?
In your code, cc = cross_correlation_loss(y, ytrue, n),
but VoxelMorph is an unsupervised method, so I am a little confused.

@John1231983

No. y is the warped/moved image and y_true is the fixed image. I think this project reproduces the result incorrectly. You can check the TF version. The smoothing must be on the displacement vector field.
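As I understand this point, the smoothness penalty in the original VoxelMorph is on the spatial gradients of the displacement field (shape `(B, 2, H, W)` in 2D), not of the warped image. A rough sketch of that, assuming forward differences and mean squared gradients; the name `gradient_loss` is mine:

```python
import torch

def gradient_loss(flow):
    """Smoothness penalty on a displacement field of shape (B, 2, H, W).

    Mean squared forward difference along each spatial axis, so a
    constant field costs nothing and rapid spatial change is penalized.
    """
    dy = flow[:, :, 1:, :] - flow[:, :, :-1, :]  # difference along H
    dx = flow[:, :, :, 1:] - flow[:, :, :, :-1]  # difference along W
    return (torch.mean(dx * dx) + torch.mean(dy * dy)) / 2.0
```

For example, a field whose values increase by 1 per pixel along W has `dx == 1` everywhere and `dy == 0`, giving a loss of 0.5, while a constant field gives exactly 0.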

@Hsankesara
Owner

Hey @John1231983, yeah, you are right that y is the warped image and y_true is the fixed image. And thanks for pointing out that I made a mistake in implementing the smooth loss. That was my mistake and I apologise for it. I'll correct it as soon as possible.

@Hsankesara
Owner

@elfprincess3, it's a bad choice of name and I'll change it in my next commit. I apologise for the trouble.

@Hsankesara
Owner

Hsankesara commented Aug 29, 2019

Hey everyone, if you find any other place where I may have made a mistake, kindly share it. I'll be grateful for your feedback and will improve it as soon as possible.

@John1231983

John1231983 commented Aug 29, 2019

@Hsankesara: This is a unit test for the smooth loss in 3D. Hope it helps:

https://colab.research.google.com/drive/1GJl1zWxTPF4KyHwiqK8elraz0xqb855R

For the spatial transformation layer, could we simply use the grid_sample() function in PyTorch instead of writing it ourselves? Could you try it?

For the NCC loss, you implemented it correctly, but when testing against TF I found a big gap between TF and PyTorch. I opened a discussion at https://discuss.pytorch.org/t/big-error-between-tensorflow-code-and-reproduce-in-pytorch/54635/2

If you find out why it happens, let me know?
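For reference, a minimal sketch of a spatial transformer built on `F.grid_sample`, assuming a pixel-unit displacement field of shape `(B, 2, H, W)` with channel 0 as x and channel 1 as y; the function name and those conventions are my assumptions, not from the repo:

```python
import torch
import torch.nn.functional as F

def warp(image, flow):
    """Warp image (B, C, H, W) by a pixel-unit displacement field (B, 2, H, W)."""
    B, _, H, W = image.shape
    # Identity sampling grid in pixel coordinates.
    ys, xs = torch.meshgrid(
        torch.arange(H, dtype=image.dtype),
        torch.arange(W, dtype=image.dtype),
        indexing="ij",
    )
    new_x = xs + flow[:, 0]  # (B, H, W)
    new_y = ys + flow[:, 1]
    # grid_sample expects (B, H, W, 2) with (x, y) normalized to [-1, 1].
    grid = torch.stack(
        (2 * new_x / (W - 1) - 1, 2 * new_y / (H - 1) - 1), dim=-1
    )
    return F.grid_sample(image, grid, mode="bilinear", align_corners=True)
```

A sanity check of the convention: a zero displacement field should reproduce the input image exactly (up to bilinear interpolation round-off), since the normalized grid is then the identity.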

@Hsankesara
Owner

Thanks, @John1231983, I'll definitely update the code accordingly. Thanks for your help.

@wendy127green
Author

@Hsankesara, that's all right, thanks!

@wendy127green
Author

@John1231983, thanks for your answer. I will check the TF version. I want to find a suitable loss for image registration, but it seems cross-correlation has some problems in PyTorch.

@John1231983

OK. Now both losses are close to the TensorFlow version. You can use them.

@Hsankesara
Owner

Thanks, @John1231983, for improving it. I'll add them as soon as possible.

@JianHangChen

Hi John. May I ask if there is any update on this version?

@domadaaaa

> Hi, you should smooth the deformation field, not the warped image.

Hi, I smoothed the deformation field, but got a negative loss. Did you solve it, or run it successfully on 2D/3D images without abnormal results?
