Skip to content

Commit

Permalink
update utils and unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Nov 28, 2021
1 parent be73d6d commit 37237da
Show file tree
Hide file tree
Showing 15 changed files with 750 additions and 26 deletions.
28 changes: 3 additions & 25 deletions gfpgan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse

from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
Expand Down Expand Up @@ -70,7 +69,8 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg
device=self.device)

if model_path.startswith('https://'):
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
Expand Down Expand Up @@ -128,25 +128,3 @@ def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=Tru
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')

os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)

parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
8 changes: 7 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = gfpgan
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

[codespell]
skip = .git,./docs/build
count =
quiet-level = 3

[aliases]
test=pytest

[tool:pytest]
addopts=tests/
Binary file added tests/data/ffhq_gt.lmdb/data.mdb
Binary file not shown.
Binary file added tests/data/ffhq_gt.lmdb/lock.mdb
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/ffhq_gt.lmdb/meta_info.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
00000000.png (512,512,3) 1
Binary file added tests/data/gt/00000000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_eye_mouth_landmarks.pth
Binary file not shown.
24 changes: 24 additions & 0 deletions tests/data/test_ffhq_degradation_dataset.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: UnitTest
type: FFHQDegradationDataset
dataroot_gt: tests/data/gt
io_backend:
type: disk

use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512

blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]

# color jitter and gray
color_jitter_prob: 1
color_jitter_shift: 20
color_jitter_pt_prob: 1
gray_prob: 1
140 changes: 140 additions & 0 deletions tests/data/test_gfpgan_model.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
num_gpu: 1
manual_seed: 0
is_train: True
dist: False

# network structures
network_g:
type: GFPGANv1
out_size: 512
num_style_feat: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
decoder_load_path: ~
fix_decoder: true
num_mlp: 8
lr_mlp: 0.01
input_is_latent: true
different_w: true
narrow: 0.5
sft_half: true

network_d:
type: StyleGAN2Discriminator
out_size: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]

network_d_left_eye:
type: FacialComponentDiscriminator

network_d_right_eye:
type: FacialComponentDiscriminator

network_d_mouth:
type: FacialComponentDiscriminator

network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False

# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
pretrain_network_d_left_eye: ~
pretrain_network_d_right_eye: ~
pretrain_network_d_mouth: ~
pretrain_network_identity: ~
# resume
resume_state: ~
ignore_resume_networks: ['network_identity']

# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-3
optim_d:
type: Adam
lr: !!float 2e-3
optim_component:
type: Adam
lr: !!float 2e-3

scheduler:
type: MultiStepLR
milestones: [600000, 700000]
gamma: 0.5

total_iter: 800000
warmup_iter: -1 # no warm up

# losses
# pixel loss
pixel_opt:
type: L1Loss
loss_weight: !!float 1e-1
reduction: mean
# L1 loss used in pyramid loss, component style loss and identity loss
L1_opt:
type: L1Loss
loss_weight: 1
reduction: mean

# image pyramid loss
pyramid_loss_weight: 1
remove_pyramid_loss: 50000
# perceptual loss (content and style losses)
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1
style_weight: 50
range_norm: true
criterion: l1
# gan loss
gan_opt:
type: GANLoss
gan_type: wgan_softplus
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
gan_component_opt:
type: GANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1
comp_style_weight: 200
# identity loss
identity_weight: 10

net_d_iters: 1
net_d_init_iters: 0
net_d_reg_every: 1

# validation settings
val:
val_freq: !!float 5e3
save_img: True
use_pbar: True

metrics:
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false
49 changes: 49 additions & 0 deletions tests/test_arcface_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch

from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace


def test_resnetarcface():
"""Test arch: ResNetArcFace."""

# model init and forward (gpu)
if torch.cuda.is_available():
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
output = net(img)
assert output.shape == (1, 512)

# -------------------- without SE block ----------------------- #
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
output = net(img)
assert output.shape == (1, 512)


def test_basicblock():
"""Test the BasicBlock in arcface_arch"""
block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 3, 12, 12)

# ----------------- use the downsmaple module--------------- #
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 3, 6, 6)


def test_bottleneck():
"""Test the Bottleneck in arcface_arch"""
block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 4, 12, 12)

# ----------------- use the downsmaple module--------------- #
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 4, 6, 6)
96 changes: 96 additions & 0 deletions tests/test_ffhq_degradation_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest
import yaml

from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset


def test_ffhq_degradation_dataset():

with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
opt = yaml.load(f, Loader=yaml.FullLoader)

dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 1

# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == 'tests/data/gt/00000000.png'

# ------------------ test with probability = 0 -------------------- #
opt['color_jitter_prob'] = 0
opt['color_jitter_pt_prob'] = 0
opt['gray_prob'] = 0
opt['io_backend'] = dict(type='disk')
dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 0

# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == 'tests/data/gt/00000000.png'

# ------------------ test lmdb backend -------------------- #
opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
opt['io_backend'] = dict(type='lmdb')

dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 0

# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == '00000000'

# ------------------ test with crop_components -------------------- #
opt['crop_components'] = True
opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
opt['eye_enlarge_ratio'] = 1.4
opt['gt_gray'] = True
opt['io_backend'] = dict(type='lmdb')

dataset = FFHQDegradationDataset(opt)
assert dataset.crop_components is True

# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == '00000000'
assert result['loc_left_eye'].shape == (4, )
assert result['loc_right_eye'].shape == (4, )
assert result['loc_mouth'].shape == (4, )

# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
with pytest.raises(ValueError):
opt['dataroot_gt'] = 'tests/data/gt'
opt['io_backend'] = dict(type='lmdb')
dataset = FFHQDegradationDataset(opt)
Loading

0 comments on commit 37237da

Please sign in to comment.