-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutil.py
39 lines (30 loc) · 1 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
import torch
import logging
import math
import torch.nn as nn
import numpy as np
# Fixed weights used by loss_function_v2: alpha multiplies the
# (SSE / sigma) term and beta multiplies the log(sigma) term.
# They live on the GPU when one is available; falling back to the CPU keeps
# `import` of this module from crashing on machines without CUDA (the
# original unconditional `.cuda()` raised at import time there).
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
alpha, beta = torch.tensor([0.6], device=_device), torch.tensor([0.4], device=_device)
def get_logger(filepath, log_info):
    """Return an INFO-level logger that writes to *filepath*.

    The logger is looked up by the name ``filepath``, so repeated calls with
    the same path return the same ``logging.Logger`` object. A
    ``logging.FileHandler`` for that file is attached only once, and a banner
    line (*log_info* framed by 30 dashes on each side) is logged on every call.

    Args:
        filepath: Path of the log file; also used as the logger's name.
        log_info: Text placed in the middle of the banner line.

    Returns:
        The configured ``logging.Logger``.
    """
    import os

    logger = logging.getLogger(filepath)
    logger.setLevel(logging.INFO)

    # logging.getLogger caches loggers by name, so without this check every
    # call stacked another FileHandler on the same logger and each subsequent
    # record was written to the file once per call made so far.
    # FileHandler stores the absolute path in `baseFilename`.
    target = os.path.abspath(filepath)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        fh = logging.FileHandler(filepath)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)

    logger.info('-' * 30 + log_info + '-' * 30)
    return logger
def log_and_print(logger, msg):
    """Log *msg* at INFO level on *logger*, then print it to stdout.

    Args:
        logger: Any ``logging.Logger``; *msg* reaches whatever handlers it has.
        msg: The message to emit; passed unchanged to both sinks.
    """
    # Order matters for observers: the logger receives the message first.
    for sink in (logger.info, print):
        sink(msg)
def loss_function_v2(sigma, x, mu):
    """Return ``alpha / sigma * SSE(x, mu) + beta * log(sigma)``.

    ``SSE`` is the summed squared error (``nn.MSELoss(reduction='sum')``), and
    ``alpha`` / ``beta`` are the module-level weight tensors.

    Args:
        sigma: Scale term; a tensor, array-like, or Python number. It should be
            strictly positive, otherwise ``log(sigma)`` is ``-inf`` / ``nan``.
        x: Values to compare; a tensor or array-like.
        mu: Reference values with the same (or a broadcastable) shape as ``x``.

    Returns:
        The loss tensor, on the same device as ``alpha``.
    """
    # torch.as_tensor replaces the legacy torch.Tensor(...) constructor.
    # torch.Tensor(3) allocates an *uninitialized* 3-element tensor, and
    # torch.Tensor(0.5) raises TypeError; as_tensor converts scalars, lists
    # and arrays by value, and returns an already-matching tensor unchanged.
    # The device follows `alpha` (cuda in this module) so the weights and the
    # inputs always share a device.
    device, dtype = alpha.device, torch.get_default_dtype()
    sigma = torch.as_tensor(sigma, dtype=dtype, device=device)
    x = torch.as_tensor(x, dtype=dtype, device=device)
    mu = torch.as_tensor(mu, dtype=dtype, device=device)

    MSE_loss = nn.MSELoss(reduction='sum')
    rec_loss = alpha / sigma * MSE_loss(x, mu)
    sup_loss = beta * torch.log(sigma)
    return rec_loss + sup_loss
def loss_function(recon_x, x, mu):
    """Return ``SSE(recon_x, x) + SSE(x, mu)``.

    ``SSE`` is the summed squared error computed by
    ``nn.MSELoss(reduction='sum')``.

    Args:
        recon_x: Tensor compared against ``x``.
        x: Tensor compared against both ``recon_x`` and ``mu``.
        mu: Tensor compared against ``x``.

    Returns:
        A scalar tensor holding the sum of the two squared-error totals.
    """
    sse = nn.MSELoss(reduction='sum')
    recon_term = sse(recon_x, x)
    mu_term = sse(x, mu)
    return recon_term + mu_term