Skip to content

Commit

Permalink
Update mobile_hair.py
Browse files Browse the repository at this point in the history
- fix gradient loss with adding a range
- add gradient loss option
  • Loading branch information
davinnovation authored Mar 13, 2019
1 parent 587db78 commit 4146dfd
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions networks/mobile_hair.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def load_pretrained_model(self):
# hell baidu - https://github.com/marvis/pytorch-mobilenet

class HairMattingLoss(nn.modules.loss._Loss):
def __init__(self, ratio_of_Gradient=0.0):
def __init__(self, ratio_of_Gradient=0.0, add_gradient=False):
super(HairMattingLoss, self).__init__()
self.ratio_of_gradient = ratio_of_Gradient

self.add_gradient = add_gradient
self.bce_loss = nn.BCEWithLogitsLoss()

def forward(self, pred, true, image):
Expand All @@ -196,9 +196,15 @@ def forward(self, pred, true, image):
G_y = F.conv2d(pred, sobel_kernel_y)

G = torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2))

rang_grad = 1 - torch.pow(I_x*G_x + I_y*G_y,2)
rang_grad = range_grad if rang_grad > 0 else 0

loss2 = torch.sum(torch.mul(G, 1 - torch.pow(I_x*G_x + I_y*G_y,2)))/torch.sum(G) + 1e-6
loss2 = torch.sum(torch.mul(G, rang_grad))/torch.sum(G) + 1e-6

loss = self.bce_loss(pred, true)
if self.add_gradient:
loss = (1-self.ratio_of_gradient)*self.bce_loss(pred, true) + loss2*self.ratio_of_gradient
else:
loss = self.bce_loss(pred, true)

return loss
return loss

0 comments on commit 4146dfd

Please sign in to comment.