From d99e73ae1f372d7cafd8417aef2aa4bb81de8aec Mon Sep 17 00:00:00 2001 From: xiyuren <761346811@qq.com> Date: Thu, 22 Aug 2024 14:56:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=87=86=E5=A4=87=E5=8A=A0=E5=85=A5realsrgan?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DualTrain.py | 65 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/DualTrain.py b/DualTrain.py index 512b59d..0f8fe46 100644 --- a/DualTrain.py +++ b/DualTrain.py @@ -16,8 +16,19 @@ from torchvision.models import vgg16 from loss.perceptual_similarity.perceptual_loss import PerceptualLoss import kornia -import colour import numpy as np +import random + +# 设置随机种子 +seed = 42 +torch.manual_seed(seed) +np.random.seed(seed) +random.seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -54,34 +65,42 @@ optimizer = optim.AdamW(model.parameters(), lr=0.001) import torch.nn.functional as F -from kornia.losses import ssim_loss #psnr_loss +from kornia.losses import ssim_loss def lab_loss(output, target): - # Calculate PSNR loss - # psnr = psnr_loss(output, target, max_val=255.0) # Assuming 8-bit color depth - # psnr_loss_value = 1.0 - psnr / 100.0 # Convert PSNR to a loss (0 to 1 range) - + # 计算 Delta E 2000 颜色差异 + # Separate channel losses l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2) # Normalize by square of range a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2) # Normalize by square of range b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2) # Normalize by square of range - # SSIM loss (applied on L channel) + # SSIM loss (applied on L channel only) ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11) - # Total Variation loss (you might want to normalize this too) - tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) + - torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0 - - high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1) - high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1) - high_freq_loss = F.mse_loss(high_freq_output, high_freq_target) + # Separate TV loss for each channel + tv_loss_l = (torch.mean(torch.abs(output[:, 0, :, :-1] - output[:, 0, :, 1:])) + + torch.mean(torch.abs(output[:, 0, :-1, :] - output[:, 0, 1:, :]))) / 100.0 # L range is 0-100 + tv_loss_a = (torch.mean(torch.abs(output[:, 1, :, :-1] - output[:, 1, :, 1:])) + + torch.mean(torch.abs(output[:, 1, :-1, :] - output[:, 1, 1:, :]))) / 255.0 # a range is -128 to 127 + tv_loss_b = (torch.mean(torch.abs(output[:, 2, :, :-1] - output[:, 2, :, 1:])) + + torch.mean(torch.abs(output[:, 2, :-1, :] - output[:, 2, 1:, :]))) / 255.0 # b range is -128 to 127 + tv_loss = tv_loss_l + 0.5 * (tv_loss_a + tv_loss_b) # Weighting channels differently + + # Separate high frequency loss for each channel + high_freq_output_l = output[:, 0] - F.avg_pool2d(output[:, 0].unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1) + high_freq_target_l = target[:, 0] - F.avg_pool2d(target[:, 0].unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1) + high_freq_loss_l = F.mse_loss(high_freq_output_l, high_freq_target_l) / (100.0 ** 2) + + high_freq_output_ab = output[:, 1:] - F.avg_pool2d(output[:, 1:], kernel_size=3, stride=1, padding=1) + high_freq_target_ab = target[:, 1:] - F.avg_pool2d(target[:, 1:], kernel_size=3, stride=1, padding=1) + high_freq_loss_ab = F.mse_loss(high_freq_output_ab, high_freq_target_ab) / (255.0 ** 2) + + high_freq_loss = high_freq_loss_l + 0.5 * high_freq_loss_ab # Weighting channels differently - # Combine losses (adjust weights as needed) - total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss+ 0.1 * high_freq_loss - return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # Return individual losses for monitoring - # total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss - # return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss + # 更新总损失计算 + total_loss = 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss + 0.1 * high_freq_loss + return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # 训练函数 @@ -102,7 +121,7 @@ def train(epoch): # Combine losses (you can adjust the weights) # loss =loss_perceptual# + 0.5 * loss_mse # LAB-specific loss - loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data) + loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data) # Combine losses loss = 0.3 * loss_perceptual + 0.7 * loss_lab @@ -112,8 +131,8 @@ def train(epoch): total_loss += loss.item() - print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, ' - f'DE: , L: {l_loss.item():.4f}, ' + print(f'Epoch {epoch}, Batch {batch_idx}, Backward Loss: {loss.item():.4f}, ' + f' L: {l_loss.item():.4f}, ' f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, ' f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}') print(f'pLoss: {loss_perceptual.item():.4f}', @@ -137,7 +156,7 @@ def save_image_comparison(epoch, data, output): axes[1, i].set_title('Reconstructed') plt.tight_layout() - plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png') + plt.savefig(f'results/256*256LAB_FULL_stl10_epoch_{epoch}.png') plt.close()