Commit
A barely usable version
181404010226 committed Aug 22, 2024
1 parent c723c4b commit 50c8a71
Showing 22 changed files with 1,895 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -6,3 +6,8 @@ __pycache__
*.json
*.bin
*.pth
*.txt
STL10Data
!loss/perceptual_similarity/weights/v0.1/vgg.pth
!loss/perceptual_similarity/weights/v0.1/squeeze.pth
!loss/perceptual_similarity/weights/v0.1/alex.pth
131 changes: 131 additions & 0 deletions DualConvMixer.py
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# class Residual(nn.Module):
# def __init__(self, fn):
# super().__init__()
# self.fn = fn

# def forward(self, x):
# return self.fn(x) + x


# class ConvBlock(nn.Module):
# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
# super(ConvBlock, self).__init__()
# self.conv = nn.Sequential(
# Residual(nn.Sequential(
# nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding),
# nn.GELU(),
# nn.BatchNorm2d(in_channels)
# )),
# nn.Conv2d(in_channels, out_channels, kernel_size=1),
# nn.GELU(),
# nn.BatchNorm2d(out_channels)
# )

# def forward(self, x):
# return self.conv(x)

class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
return self.relu(self.bn(self.conv(x)))

class UpBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(UpBlock, self).__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
# self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False)
self.conv = ConvBlock(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

def forward(self, x):
x = self.upsample(x)
return self.conv(x)

# class DeconvBlock(nn.Module):
# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0):
# super(DeconvBlock, self).__init__()
# self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
# self.bn = nn.BatchNorm2d(out_channels)
# self.relu = nn.ReLU(inplace=True)

# def forward(self, x):
# return self.relu(self.bn(self.deconv(x)))

class HourglassModel(nn.Module):
def __init__(self, input_channels=3, latent_dim=128):
super(HourglassModel, self).__init__()

# Encoder (downsampling convolution path)
self.encoder = nn.Sequential(
ConvBlock(input_channels, 64),
ConvBlock(64, 64),
nn.MaxPool2d(2),
ConvBlock(64, 128),
ConvBlock(128, 128),
nn.MaxPool2d(2),
ConvBlock(128, 256),
ConvBlock(256, 256),
nn.MaxPool2d(2),
ConvBlock(256, 512),
ConvBlock(512, 512),
nn.MaxPool2d(2),
ConvBlock(512, latent_dim)
)

# Decoder
self.decoder = nn.Sequential(
UpBlock(latent_dim, 512),
UpBlock(512, 256),
UpBlock(256, 128),
UpBlock(128, 64),
nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1),
nn.Tanh()
)

# # Decoder (transposed-convolution variant)
# self.decoder = nn.Sequential(
# DeconvBlock(latent_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
# DeconvBlock(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
# DeconvBlock(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
# DeconvBlock(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
# nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1),
# nn.Tanh()
# )

def forward(self, x):
latent = self.encoder(x)
output = self.decoder(latent)
return output

def encode(self, x):
return self.encoder(x)

def decode(self, latent):
return self.decoder(latent)

# Quick model test
if __name__ == "__main__":
# Create a sample input
batch_size = 1
channels = 3
height, width = 32, 32
x = torch.randn(batch_size, channels, height, width)

# Initialize the model
model = HourglassModel()

# Forward pass
output = model(x)

print(model.encode(x).shape)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
# print(f"Model architecture:\n{model}")
163 changes: 163 additions & 0 deletions DualTrain.py
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from DualConvMixer import HourglassModel
from torchvision.models import vgg16
from torchvision.transforms.functional import to_pil_image
from torch.optim.lr_scheduler import OneCycleLR
from network.AutoEncoder import AutoEncoder
import numpy as np
import torch.nn.functional as F
from loss.perceptual_similarity.perceptual_loss import PerceptualLoss
import kornia
import colour

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create output folders for results and checkpoints
os.makedirs("results", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# Data preprocessing

transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Lambda(lambda x: kornia.color.rgb_to_lab(x.unsqueeze(0)).squeeze(0))
])

batch_size = 128
# Load the STL10 dataset
train_dataset = datasets.STL10(root='./STL10Data', split='train+unlabeled', download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Load the CIFAR-10 dataset (alternative)
# train_dataset = datasets.CIFAR10(root='./CIFAR10RawData', train=True, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# Initialize the model
# model = HourglassModel().to(device)
model = AutoEncoder(image_dims=(3, 256, 256), batch_size=batch_size,
                    C=16, activation='leaky_relu').to(device)

perceptual_loss = PerceptualLoss(model='net-lin', net='alex',
colorspace='Lab', use_gpu=torch.cuda.is_available()).to(device)
# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.001)

from kornia.losses import ssim_loss  # psnr_loss

def lab_loss(output, target):
# Calculate PSNR loss
# psnr = psnr_loss(output, target, max_val=255.0) # Assuming 8-bit color depth
# psnr_loss_value = 1.0 - psnr / 100.0 # Convert PSNR to a loss (0 to 1 range)

# Separate channel losses
l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2) # Normalize by square of range
a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2) # Normalize by square of range
b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2) # Normalize by square of range

# SSIM loss (applied on L channel)
ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11)

# Total Variation loss (you might want to normalize this too)
tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) +
torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0

high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1)
high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1)
high_freq_loss = F.mse_loss(high_freq_output, high_freq_target)

# Combine losses (adjust weights as needed)
total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss + 0.1 * high_freq_loss
return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # Return individual losses for monitoring
# total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss
# return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss
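# A minimal usage sketch of lab_loss (illustrative values only; kornia's rgb_to_lab
# yields L roughly in [0, 100] and a/b roughly in [-128, 127], which is what the
# per-channel normalization above assumes):
#   fake_out = torch.cat([torch.rand(2, 1, 64, 64) * 100,
#                         torch.rand(2, 2, 64, 64) * 255 - 128], dim=1).requires_grad_()
#   fake_tgt = torch.cat([torch.rand(2, 1, 64, 64) * 100,
#                         torch.rand(2, 2, 64, 64) * 255 - 128], dim=1)
#   total, l, a, b, s, tv, hf = lab_loss(fake_out, fake_tgt)
#   total.backward()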


# Training function
def train(epoch):
model.train()
total_loss = 0
for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
data = data.to(device)
optimizer.zero_grad()
output = model(data)

# Calculate perceptual loss and take the mean to get a scalar
loss_perceptual = perceptual_loss(output, data).mean()

# MSE loss is already a scalar
# loss_mse = nn.MSELoss()(output, data)

# Combine losses (you can adjust the weights)
# loss =loss_perceptual# + 0.5 * loss_mse
# LAB-specific loss
loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data)

# Combine losses
loss = 0.3 * loss_perceptual + 0.7 * loss_lab

loss.backward()
optimizer.step()

total_loss += loss.item()

print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, '
f'L: {l_loss.item():.4f}, '
f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, '
f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}')
print(f'pLoss: {loss_perceptual.item():.4f}',
f'lr: {optimizer.param_groups[0]["lr"]:.8f}')
return total_loss / len(train_loader)



# Save a side-by-side image comparison
def save_image_comparison(epoch, data, output):
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(5):
# Convert LAB back to RGB before plotting
axes[0, i].imshow(to_pil_image(kornia.color.lab_to_rgb(data[i].unsqueeze(0)).squeeze(0).cpu()))
axes[0, i].axis('off')
axes[0, i].set_title('Original')
axes[1, i].imshow(to_pil_image(kornia.color.lab_to_rgb(output[i].unsqueeze(0)).squeeze(0).cpu().clamp(0, 1)))

# axes[1, i].imshow(to_pil_image(output[i].cpu().clamp(-1, 1)))
axes[1, i].axis('off')
axes[1, i].set_title('Reconstructed')

plt.tight_layout()
plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png')
plt.close()


# Training loop
num_epochs = 50

for epoch in range(1, num_epochs + 1):
avg_loss = train(epoch)
print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}')

# Save a model checkpoint
if epoch % 5 == 0:
torch.save(model.state_dict(), f'checkpoints/model_epoch_{epoch}.pth')

# Generate and save an image comparison
if epoch % 1 == 0:
model.eval()
with torch.no_grad():
data = next(iter(train_loader))[0][:5].to(device)
output = model(data)
save_image_comparison(epoch, data, output)

print("Training completed!")
Empty file added loss/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions loss/losses.py
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from src.helpers.utils import get_scheduled_params

def weighted_rate_loss(config, total_nbpp, total_qbpp, step_counter, ignore_schedule=False):
"""
Heavily penalize the rate with weight lambda_A >> lambda_B if it exceeds
some target r_t, otherwise penalize with lambda_B
"""
lambda_A = get_scheduled_params(config.lambda_A, config.lambda_schedule, step_counter, ignore_schedule)
lambda_B = get_scheduled_params(config.lambda_B, config.lambda_schedule, step_counter, ignore_schedule)

assert lambda_A > lambda_B, "Expected lambda_A > lambda_B, got (A) {} <= (B) {}".format(
lambda_A, lambda_B)

target_bpp = get_scheduled_params(config.target_rate, config.target_schedule, step_counter, ignore_schedule)

total_qbpp = total_qbpp.item()
if total_qbpp > target_bpp:
rate_penalty = lambda_A
else:
rate_penalty = lambda_B
weighted_rate = rate_penalty * total_nbpp

return weighted_rate, float(rate_penalty)

def _non_saturating_loss(D_real_logits, D_gen_logits, D_real=None, D_gen=None):

D_loss_real = F.binary_cross_entropy_with_logits(input=D_real_logits,
target=torch.ones_like(D_real_logits))
D_loss_gen = F.binary_cross_entropy_with_logits(input=D_gen_logits,
target=torch.zeros_like(D_gen_logits))
D_loss = D_loss_real + D_loss_gen

G_loss = F.binary_cross_entropy_with_logits(input=D_gen_logits,
target=torch.ones_like(D_gen_logits))

return D_loss, G_loss

def _least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None):
D_loss_real = torch.mean(torch.square(D_real - 1.0))
D_loss_gen = torch.mean(torch.square(D_gen))
D_loss = 0.5 * (D_loss_real + D_loss_gen)

G_loss = 0.5 * torch.mean(torch.square(D_gen - 1.0))

return D_loss, G_loss

def gan_loss(gan_loss_type, disc_out, mode='generator_loss'):

if gan_loss_type == 'non_saturating':
loss_fn = _non_saturating_loss
elif gan_loss_type == 'least_squares':
loss_fn = _least_squares_loss
else:
raise ValueError('Invalid GAN loss')

D_loss, G_loss = loss_fn(D_real=disc_out.D_real, D_gen=disc_out.D_gen,
D_real_logits=disc_out.D_real_logits, D_gen_logits=disc_out.D_gen_logits)

loss = G_loss if mode == 'generator_loss' else D_loss

return loss
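For reference, gan_loss only needs an object exposing D_real, D_gen, D_real_logits and D_gen_logits. A minimal sketch, with SimpleNamespace standing in for whatever structure the discriminator actually returns:

import torch
from types import SimpleNamespace
from loss.losses import gan_loss

logits_real = torch.randn(8, 1)
logits_gen = torch.randn(8, 1)
disc_out = SimpleNamespace(D_real=torch.sigmoid(logits_real), D_gen=torch.sigmoid(logits_gen),
                           D_real_logits=logits_real, D_gen_logits=logits_gen)

g_loss = gan_loss('non_saturating', disc_out, mode='generator_loss')
d_loss = gan_loss('non_saturating', disc_out, mode='discriminator_loss')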
