Hi! Thanks for the repo.
Would you mind clarifying why the max_dist variable isn't detached in the "margin" loss function? As the code currently stands, wouldn't gradients flow through it? I've tried detaching it, and the loss does behave differently.
max_dist = (dist * mask).max()
https://github.com/wzhouad/Contra-OOD/blob/2a1d63a61c8b03efdc27ca08b22f5fab2bc6001d/model.py#L42C17-L42C47
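To illustrate what I mean, here is a minimal sketch. `toy_margin_loss` is a made-up stand-in I wrote for this question, not the loss in model.py, and the `detach_max` flag and the example values are also my own assumptions. It only shows that taking `.max()` without `.detach()` lets gradients reach the entry that attains the maximum.

```python
import torch


def toy_margin_loss(dist, mask, detach_max):
    # Hypothetical stand-in for the repo's loss, used only for illustration.
    max_dist = (dist * mask).max()
    if detach_max:
        # Treat the maximum as a constant so no gradient flows through it.
        max_dist = max_dist.detach()
    # Penalize masked entries that fall short of the maximum distance.
    return (torch.relu(max_dist - dist) * mask).sum()


values = [1.0, 3.0, 2.0, 5.0]
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # last entry is masked out

# Case 1: max_dist stays in the autograd graph.
dist_attached = torch.tensor(values, requires_grad=True)
toy_margin_loss(dist_attached, mask, detach_max=False).backward()
grad_attached = dist_attached.grad.clone()

# Case 2: max_dist is detached before it is used.
dist_detached = torch.tensor(values, requires_grad=True)
toy_margin_loss(dist_detached, mask, detach_max=True).backward()
grad_detached = dist_detached.grad.clone()

print("grad with max_dist attached:", grad_attached)
print("grad with max_dist detached:", grad_detached)
```

In this toy case the two gradients differ only at the entry that attains the maximum: it receives a gradient when max_dist is attached and none when it is detached. That would be consistent with the different loss behavior I observed, although the actual loss in the repo may of course behave differently.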
Thanks!