Commit 50c8a71 (1 parent: c723c4b). Showing 22 changed files with 1,895 additions and 0 deletions.
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# class Residual(nn.Module):
#     def __init__(self, fn):
#         super().__init__()
#         self.fn = fn

#     def forward(self, x):
#         return self.fn(x) + x


# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
#         super(ConvBlock, self).__init__()
#         self.conv = nn.Sequential(
#             Residual(nn.Sequential(
#                 nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding),
#                 nn.GELU(),
#                 nn.BatchNorm2d(in_channels)
#             )),
#             nn.Conv2d(in_channels, out_channels, kernel_size=1),
#             nn.GELU(),
#             nn.BatchNorm2d(out_channels)
#         )

#     def forward(self, x):
#         return self.conv(x)


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class UpBlock(nn.Module):
    """Bilinear 2x upsampling followed by a ConvBlock.
    Upsample + conv (rather than a transposed convolution) helps avoid checkerboard artifacts."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(UpBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False)
        self.conv = ConvBlock(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv(x)


# class DeconvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0):
#         super(DeconvBlock, self).__init__()
#         self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
#         self.bn = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)

#     def forward(self, x):
#         return self.relu(self.bn(self.deconv(x)))


class HourglassModel(nn.Module):
    def __init__(self, input_channels=3, latent_dim=128):
        super(HourglassModel, self).__init__()

        # Encoder (standard convolutions); four MaxPool2d stages downsample by 16x overall
        self.encoder = nn.Sequential(
            ConvBlock(input_channels, 64),
            ConvBlock(64, 64),
            nn.MaxPool2d(2),
            ConvBlock(64, 128),
            ConvBlock(128, 128),
            nn.MaxPool2d(2),
            ConvBlock(128, 256),
            ConvBlock(256, 256),
            nn.MaxPool2d(2),
            ConvBlock(256, 512),
            ConvBlock(512, 512),
            nn.MaxPool2d(2),
            ConvBlock(512, latent_dim)
        )

        # Decoder: four UpBlocks upsample by 16x, restoring the input resolution
        self.decoder = nn.Sequential(
            UpBlock(latent_dim, 512),
            UpBlock(512, 256),
            UpBlock(256, 128),
            UpBlock(128, 64),
            nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

        # # Decoder (transposed-convolution variant)
        # self.decoder = nn.Sequential(
        #     DeconvBlock(latent_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
        #     DeconvBlock(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
        #     DeconvBlock(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
        #     DeconvBlock(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
        #     nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1),
        #     nn.Tanh()
        # )

    def forward(self, x):
        latent = self.encoder(x)
        output = self.decoder(latent)
        return output

    def encode(self, x):
        return self.encoder(x)

    def decode(self, latent):
        return self.decoder(latent)


# Test the model
if __name__ == "__main__":
    # Create a sample input
    batch_size = 1
    channels = 3
    height, width = 32, 32
    x = torch.randn(batch_size, channels, height, width)

    # Initialize the model
    model = HourglassModel()

    # Forward pass
    output = model(x)

    print(model.encode(x).shape)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    # print(f"Model architecture:\n{model}")
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from DualConvMixer import HourglassModel
from torchvision.models import vgg16
from torchvision.transforms.functional import to_pil_image
from torch.optim.lr_scheduler import OneCycleLR
from network.AutoEncoder import AutoEncoder
import numpy as np
import torch.nn.functional as F
from loss.perceptual_similarity.perceptual_loss import PerceptualLoss
import kornia
import colour

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create folders for saving results
os.makedirs("results", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# Data preprocessing: resize, convert to tensor, then convert RGB to LAB
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: kornia.color.rgb_to_lab(x.unsqueeze(0)).squeeze(0))
])

batch_size = 128
# Load the STL10 dataset
train_dataset = datasets.STL10(root='./STL10Data', split='train+unlabeled', download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Load the CIFAR-10 dataset
# train_dataset = datasets.CIFAR10(root='./CIFAR10RawData', train=True, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# Initialize the model
# model = HourglassModel().to(device)
model = AutoEncoder(image_dims=(3, 256, 256), batch_size=batch_size,
                    C=16, activation='leaky_relu').to(device)

perceptual_loss = PerceptualLoss(model='net-lin', net='alex',
                                 colorspace='Lab', use_gpu=torch.cuda.is_available()).to(device)
# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.001)

from kornia.losses import ssim_loss  # psnr_loss

def lab_loss(output, target):
    # Calculate PSNR loss
    # psnr = psnr_loss(output, target, max_val=255.0)  # Assuming 8-bit color depth
    # psnr_loss_value = 1.0 - psnr / 100.0  # Convert PSNR to a loss (0 to 1 range)

    # Per-channel MSE losses, normalized by the square of each channel's range
    l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2)
    a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2)
    b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2)

    # SSIM loss (applied on the L channel)
    ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11)

    # Total variation loss (you might want to normalize this too)
    tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) +
               torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0

    # High-frequency loss: compare residuals left after removing a local average
    high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1)
    high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1)
    high_freq_loss = F.mse_loss(high_freq_output, high_freq_target)

    # Combine losses (adjust weights as needed)
    total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss + 0.1 * high_freq_loss
    return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss  # Return individual losses for monitoring
    # total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss
    # return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss
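As an illustrative aside (not part of the committed file): a minimal sketch of exercising lab_loss on dummy tensors, assuming lab_loss and its imports above are in scope. The per-channel normalizations above assume kornia's LAB convention, where L lies in [0, 100] and a/b lie roughly in [-128, 127].

dummy_output = torch.rand(2, 3, 64, 64) * 100.0  # fake LAB-like predictions (assumed shapes)
dummy_target = torch.rand(2, 3, 64, 64) * 100.0  # fake LAB-like targets
total, l, a, b, s, tv, hf = lab_loss(dummy_output, dummy_target)
print(f"total={total.item():.4f}, L={l.item():.6f}, SSIM={s.item():.4f}")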
# Training function
def train(epoch):
    model.train()
    total_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)

        # Calculate perceptual loss and take the mean to get a scalar
        loss_perceptual = perceptual_loss(output, data).mean()

        # MSE loss is already a scalar
        # loss_mse = nn.MSELoss()(output, data)

        # Combine losses (you can adjust the weights)
        # loss = loss_perceptual  # + 0.5 * loss_mse
        # LAB-specific loss
        loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data)

        # Combine losses
        loss = 0.3 * loss_perceptual + 0.7 * loss_lab

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, '
              f'L: {l_loss.item():.4f}, '
              f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, '
              f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}')
        print(f'pLoss: {loss_perceptual.item():.4f}',
              f'lr: {optimizer.param_groups[0]["lr"]:.8f}')
    return total_loss / len(train_loader)


# Save image comparisons
def save_image_comparison(epoch, data, output):
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for i in range(5):
        # Convert LAB back to RGB for display
        axes[0, i].imshow(to_pil_image(kornia.color.lab_to_rgb(data[i].unsqueeze(0)).squeeze(0).cpu()))
        axes[0, i].axis('off')
        axes[0, i].set_title('Original')
        axes[1, i].imshow(to_pil_image(kornia.color.lab_to_rgb(output[i].unsqueeze(0)).squeeze(0).cpu().clamp(0, 1)))

        # axes[1, i].imshow(to_pil_image(output[i].cpu().clamp(-1, 1)))
        axes[1, i].axis('off')
        axes[1, i].set_title('Reconstructed')

    plt.tight_layout()
    plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png')
    plt.close()


# Training loop
num_epochs = 50

for epoch in range(1, num_epochs + 1):
    avg_loss = train(epoch)
    print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}')

    # Save a model checkpoint
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'checkpoints/model_epoch_{epoch}.pth')

    # Generate and save image comparisons
    if epoch % 1 == 0:
        model.eval()
        with torch.no_grad():
            data = next(iter(train_loader))[0][:5].to(device)
            output = model(data)
            save_image_comparison(epoch, data, output)

print("Training completed!")
Empty file.
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from src.helpers.utils import get_scheduled_params

def weighted_rate_loss(config, total_nbpp, total_qbpp, step_counter, ignore_schedule=False):
    """
    Heavily penalize the rate with weight lambda_A >> lambda_B if it exceeds
    some target r_t, otherwise penalize with lambda_B.
    """
    lambda_A = get_scheduled_params(config.lambda_A, config.lambda_schedule, step_counter, ignore_schedule)
    lambda_B = get_scheduled_params(config.lambda_B, config.lambda_schedule, step_counter, ignore_schedule)

    assert lambda_A > lambda_B, "Expected lambda_A > lambda_B, got (A) {} <= (B) {}".format(
        lambda_A, lambda_B)

    target_bpp = get_scheduled_params(config.target_rate, config.target_schedule, step_counter, ignore_schedule)

    total_qbpp = total_qbpp.item()
    if total_qbpp > target_bpp:
        rate_penalty = lambda_A
    else:
        rate_penalty = lambda_B
    weighted_rate = rate_penalty * total_nbpp

    return weighted_rate, float(rate_penalty)

def _non_saturating_loss(D_real_logits, D_gen_logits, D_real=None, D_gen=None):

    D_loss_real = F.binary_cross_entropy_with_logits(input=D_real_logits,
                                                     target=torch.ones_like(D_real_logits))
    D_loss_gen = F.binary_cross_entropy_with_logits(input=D_gen_logits,
                                                    target=torch.zeros_like(D_gen_logits))
    D_loss = D_loss_real + D_loss_gen

    G_loss = F.binary_cross_entropy_with_logits(input=D_gen_logits,
                                                target=torch.ones_like(D_gen_logits))

    return D_loss, G_loss

def _least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None):
    D_loss_real = torch.mean(torch.square(D_real - 1.0))
    D_loss_gen = torch.mean(torch.square(D_gen))
    D_loss = 0.5 * (D_loss_real + D_loss_gen)

    G_loss = 0.5 * torch.mean(torch.square(D_gen - 1.0))

    return D_loss, G_loss

def gan_loss(gan_loss_type, disc_out, mode='generator_loss'):

    if gan_loss_type == 'non_saturating':
        loss_fn = _non_saturating_loss
    elif gan_loss_type == 'least_squares':
        loss_fn = _least_squares_loss
    else:
        raise ValueError('Invalid GAN loss')

    D_loss, G_loss = loss_fn(D_real=disc_out.D_real, D_gen=disc_out.D_gen,
                             D_real_logits=disc_out.D_real_logits, D_gen_logits=disc_out.D_gen_logits)

    loss = G_loss if mode == 'generator_loss' else D_loss

    return loss
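An illustrative sketch (not part of the committed file) of how gan_loss consumes its disc_out argument; the DiscOut namedtuple here is hypothetical, standing in for whatever discriminator output object carries the four fields referenced above.

from collections import namedtuple
import torch

DiscOut = namedtuple('DiscOut', ['D_real', 'D_gen', 'D_real_logits', 'D_gen_logits'])

real_logits = torch.randn(8, 1)
gen_logits = torch.randn(8, 1)
disc_out = DiscOut(D_real=torch.sigmoid(real_logits),
                   D_gen=torch.sigmoid(gen_logits),
                   D_real_logits=real_logits,
                   D_gen_logits=gen_logits)

g_loss = gan_loss('non_saturating', disc_out, mode='generator_loss')
d_loss = gan_loss('non_saturating', disc_out, mode='discriminator_loss')
print(g_loss.item(), d_loss.item())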
Empty file.