forked from TencentARC/GFPGAN
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
750 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
00000000.png (512,512,3) 1 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Unit-test options for FFHQDegradationDataset.
name: UnitTest
type: FFHQDegradationDataset
dataroot_gt: tests/data/gt
# `type` below belongs to io_backend (the test reads io_backend['type']),
# not to the top-level mapping.
io_backend:
  type: disk

use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512

# blur / downsample / noise / jpeg degradation settings
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]

# color jitter and gray
color_jitter_prob: 1
color_jitter_shift: 20
color_jitter_pt_prob: 1
gray_prob: 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Model options: network structures, pretrained paths, training and
# validation settings.
num_gpu: 1
manual_seed: 0
is_train: True
dist: False

# network structures
network_g:
  type: GFPGANv1
  out_size: 512
  num_style_feat: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]
  decoder_load_path: ~
  fix_decoder: true
  num_mlp: 8
  lr_mlp: 0.01
  input_is_latent: true
  different_w: true
  narrow: 0.5
  sft_half: true

network_d:
  type: StyleGAN2Discriminator
  out_size: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]

network_d_left_eye:
  type: FacialComponentDiscriminator

network_d_right_eye:
  type: FacialComponentDiscriminator

network_d_mouth:
  type: FacialComponentDiscriminator

network_identity:
  type: ResNetArcFace
  block: IRBlock
  layers: [2, 2, 2, 2]
  use_se: False

# path
path:
  pretrain_network_g: ~
  param_key_g: params_ema
  strict_load_g: ~
  pretrain_network_d: ~
  pretrain_network_d_left_eye: ~
  pretrain_network_d_right_eye: ~
  pretrain_network_d_mouth: ~
  pretrain_network_identity: ~
  # resume
  resume_state: ~
  ignore_resume_networks: ['network_identity']

# training settings
train:
  optim_g:
    type: Adam
    lr: !!float 2e-3
  optim_d:
    type: Adam
    lr: !!float 2e-3
  optim_component:
    type: Adam
    lr: !!float 2e-3

  scheduler:
    type: MultiStepLR
    milestones: [600000, 700000]
    gamma: 0.5

  total_iter: 800000
  warmup_iter: -1  # no warm up

  # losses
  # pixel loss
  pixel_opt:
    type: L1Loss
    loss_weight: !!float 1e-1
    reduction: mean
  # L1 loss used in pyramid loss, component style loss and identity loss
  L1_opt:
    type: L1Loss
    loss_weight: 1
    reduction: mean

  # image pyramid loss
  pyramid_loss_weight: 1
  remove_pyramid_loss: 50000
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1
    style_weight: 50
    range_norm: true
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: wgan_softplus
    loss_weight: !!float 1e-1
  # r1 regularization for discriminator
  r1_reg_weight: 10
  # facial component loss
  gan_component_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1
  comp_style_weight: 200
  # identity loss
  identity_weight: 10

  net_d_iters: 1
  net_d_init_iters: 0
  net_d_reg_every: 1

# validation settings
val:
  val_freq: !!float 5e3
  save_img: True
  use_pbar: True

  metrics:
    psnr:  # metric name
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import torch | ||
|
||
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace | ||
|
||
|
||
def test_resnetarcface():
    """Test arch: ResNetArcFace, with and without SE blocks.

    The forward passes only run when CUDA is available; without a GPU the
    test returns immediately and asserts nothing.
    """
    if not torch.cuda.is_available():
        return

    # arguments shared by both network variants
    common_kwargs = dict(block='IRBlock', layers=(2, 2, 2, 2))

    # -------------------- with SE block ----------------------- #
    net = ResNetArcFace(use_se=True, **common_kwargs).cuda().eval()
    # one random single-channel 128x128 input, reused for both variants
    img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
    assert net(img).shape == (1, 512)

    # -------------------- without SE block ----------------------- #
    net = ResNetArcFace(use_se=False, **common_kwargs).cuda().eval()
    assert net(img).shape == (1, 512)
|
||
|
||
def test_basicblock():
    """Test the BasicBlock in arcface_arch.

    Runs on CUDA when it is available and falls back to the CPU otherwise,
    so the test does not error out on machines without a GPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # stride 1, no downsample module: a (1, 1, 12, 12) input keeps its
    # spatial size and comes out with 3 channels
    block = BasicBlock(1, 3, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 3, 12, 12)

    # ----------------- use the downsample module --------------- #
    # stride 2 with a 0.5x nearest-neighbour resampler passed as the
    # `downsample` argument: the spatial size is halved
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = BasicBlock(1, 3, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 3, 6, 6)
|
||
|
||
def test_bottleneck():
    """Test the Bottleneck in arcface_arch.

    Runs on CUDA when it is available and falls back to the CPU otherwise,
    so the test does not error out on machines without a GPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # stride 1, no downsample module: a (1, 1, 12, 12) input keeps its
    # spatial size and comes out with 4 channels
    block = Bottleneck(1, 1, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 4, 12, 12)

    # ----------------- use the downsample module --------------- #
    # stride 2 with a 0.5x nearest-neighbour resampler passed as the
    # `downsample` argument: the spatial size is halved
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = Bottleneck(1, 1, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 4, 6, 6)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import pytest | ||
import yaml | ||
|
||
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset | ||
|
||
|
||
def _check_dataset_config(dataset, backend_type, color_jitter_prob):
    """Assert the options a freshly built dataset is expected to expose."""
    assert dataset.io_backend_opt['type'] == backend_type  # io backend
    assert len(dataset) == 1  # whether the meta info was read correctly
    assert dataset.kernel_list == ['iso', 'aniso']  # degradation configuration was initialized correctly
    assert dataset.color_jitter_prob == color_jitter_prob


def _check_sample(sample, expected_gt_path, extra_keys=()):
    """Assert that a dataset sample has the expected keys, image shapes and gt path."""
    expected_keys = {'gt', 'lq', 'gt_path', *extra_keys}
    assert expected_keys.issubset(set(sample.keys()))
    # check shape and contents
    assert sample['gt'].shape == (3, 512, 512)
    assert sample['lq'].shape == (3, 512, 512)
    assert sample['gt_path'] == expected_gt_path


def test_ffhq_degradation_dataset():
    """Test FFHQDegradationDataset with the disk and lmdb backends.

    Covers, in order: the options as loaded from the yml file, the same
    options with all color augmentation probabilities set to 0, the lmdb
    backend, the ``crop_components`` mode, and the ValueError expected when
    the lmdb backend is given a data root that does not end with ``lmdb``.
    The ``opt`` dict is mutated step by step, so the order of the sections
    matters.
    """
    with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # ------------------ test with the options from the yml file -------------------- #
    dataset = FFHQDegradationDataset(opt)
    _check_dataset_config(dataset, backend_type='disk', color_jitter_prob=1)
    _check_sample(dataset[0], 'tests/data/gt/00000000.png')

    # ------------------ test with probability = 0 -------------------- #
    opt['color_jitter_prob'] = 0
    opt['color_jitter_pt_prob'] = 0
    opt['gray_prob'] = 0
    # NOTE(review): io_backend is rebuilt before every new dataset, presumably
    # because the dataset mutates the dict it is given - confirm in the dataset.
    opt['io_backend'] = dict(type='disk')
    dataset = FFHQDegradationDataset(opt)
    _check_dataset_config(dataset, backend_type='disk', color_jitter_prob=0)
    _check_sample(dataset[0], 'tests/data/gt/00000000.png')

    # ------------------ test lmdb backend -------------------- #
    opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
    opt['io_backend'] = dict(type='lmdb')
    dataset = FFHQDegradationDataset(opt)
    _check_dataset_config(dataset, backend_type='lmdb', color_jitter_prob=0)
    # with the lmdb backend the gt path is the bare key, not a file path
    _check_sample(dataset[0], '00000000')

    # ------------------ test with crop_components -------------------- #
    opt['crop_components'] = True
    opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
    opt['eye_enlarge_ratio'] = 1.4
    opt['gt_gray'] = True
    opt['io_backend'] = dict(type='lmdb')
    dataset = FFHQDegradationDataset(opt)
    assert dataset.crop_components is True

    component_keys = ('loc_left_eye', 'loc_right_eye', 'loc_mouth')
    result = dataset[0]
    _check_sample(result, '00000000', extra_keys=component_keys)
    for key in component_keys:
        assert result[key].shape == (4, )

    # ------------------ lmdb backend should have paths ending with lmdb -------------------- #
    opt['dataroot_gt'] = 'tests/data/gt'
    opt['io_backend'] = dict(type='lmdb')
    # only the constructor call is expected to raise
    with pytest.raises(ValueError):
        FFHQDegradationDataset(opt)
Oops, something went wrong.