Skip to content

Commit

Permalink
准备加入realsrgan
Browse files Browse the repository at this point in the history
  • Loading branch information
181404010226 committed Aug 22, 2024
1 parent 9e29f58 commit d99e73a
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions DualTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@
from torchvision.models import vgg16
from loss.perceptual_similarity.perceptual_loss import PerceptualLoss
import kornia
import colour
import numpy as np
import random

# 设置随机种子
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -54,34 +65,42 @@
optimizer = optim.AdamW(model.parameters(), lr=0.001)

import torch.nn.functional as F
from kornia.losses import ssim_loss #psnr_loss
from kornia.losses import ssim_loss

def lab_loss(output, target):
# Calculate PSNR loss
# psnr = psnr_loss(output, target, max_val=255.0) # Assuming 8-bit color depth
# psnr_loss_value = 1.0 - psnr / 100.0 # Convert PSNR to a loss (0 to 1 range)

# 计算 Delta E 2000 颜色差异

# Separate channel losses
l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2) # Normalize by square of range
a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2) # Normalize by square of range
b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2) # Normalize by square of range

# SSIM loss (applied on L channel)
# SSIM loss (applied on L channel only)
ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11)

# Total Variation loss (you might want to normalize this too)
tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) +
torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0

high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1)
high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1)
high_freq_loss = F.mse_loss(high_freq_output, high_freq_target)
# Separate TV loss for each channel
tv_loss_l = (torch.mean(torch.abs(output[:, 0, :, :-1] - output[:, 0, :, 1:])) +
torch.mean(torch.abs(output[:, 0, :-1, :] - output[:, 0, 1:, :]))) / 100.0 # L range is 0-100
tv_loss_a = (torch.mean(torch.abs(output[:, 1, :, :-1] - output[:, 1, :, 1:])) +
torch.mean(torch.abs(output[:, 1, :-1, :] - output[:, 1, 1:, :]))) / 255.0 # a range is -128 to 127
tv_loss_b = (torch.mean(torch.abs(output[:, 2, :, :-1] - output[:, 2, :, 1:])) +
torch.mean(torch.abs(output[:, 2, :-1, :] - output[:, 2, 1:, :]))) / 255.0 # b range is -128 to 127
tv_loss = tv_loss_l + 0.5 * (tv_loss_a + tv_loss_b) # Weighting channels differently

# Separate high frequency loss for each channel
high_freq_output_l = output[:, 0] - F.avg_pool2d(output[:, 0].unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
high_freq_target_l = target[:, 0] - F.avg_pool2d(target[:, 0].unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
high_freq_loss_l = F.mse_loss(high_freq_output_l, high_freq_target_l) / (100.0 ** 2)

high_freq_output_ab = output[:, 1:] - F.avg_pool2d(output[:, 1:], kernel_size=3, stride=1, padding=1)
high_freq_target_ab = target[:, 1:] - F.avg_pool2d(target[:, 1:], kernel_size=3, stride=1, padding=1)
high_freq_loss_ab = F.mse_loss(high_freq_output_ab, high_freq_target_ab) / (255.0 ** 2)

high_freq_loss = high_freq_loss_l + 0.5 * high_freq_loss_ab # Weighting channels differently

# Combine losses (adjust weights as needed)
total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss+ 0.1 * high_freq_loss
return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # Return individual losses for monitoring
# total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss
# return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss
# 更新总损失计算
total_loss = 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss + 0.1 * high_freq_loss
return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss


# 训练函数
Expand All @@ -102,7 +121,7 @@ def train(epoch):
# Combine losses (you can adjust the weights)
# loss =loss_perceptual# + 0.5 * loss_mse
# LAB-specific loss
loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data)
loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data)

# Combine losses
loss = 0.3 * loss_perceptual + 0.7 * loss_lab
Expand All @@ -112,8 +131,8 @@ def train(epoch):

total_loss += loss.item()

print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, '
f'DE: , L: {l_loss.item():.4f}, '
print(f'Epoch {epoch}, Batch {batch_idx}, Backward Loss: {loss.item():.4f}, '
f' L: {l_loss.item():.4f}, '
f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, '
f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}')
print(f'pLoss: {loss_perceptual.item():.4f}',
Expand All @@ -137,7 +156,7 @@ def save_image_comparison(epoch, data, output):
axes[1, i].set_title('Reconstructed')

plt.tight_layout()
plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png')
plt.savefig(f'results/256*256LAB_FULL_stl10_epoch_{epoch}.png')
plt.close()


Expand Down

0 comments on commit d99e73a

Please sign in to comment.