Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions data/options.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import argparse

def _str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')

def option():
# Training settings
parser = argparse.ArgumentParser(description='CIDNet')
Expand All @@ -9,17 +17,17 @@ def option():
parser.add_argument('--start_epoch', type=int, default=0, help='number of epochs to start, >0 is retrained a pre-trained pth')
parser.add_argument('--snapshots', type=int, default=10, help='Snapshots for save checkpoints pth')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate')
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--gpu_mode', type=_str2bool, default=True)
parser.add_argument('--shuffle', type=_str2bool, default=True)
parser.add_argument('--threads', type=int, default=16, help='number of threads for dataloader to use')

# choose a scheduler
parser.add_argument('--cos_restart_cyclic', type=bool, default=False)
parser.add_argument('--cos_restart', type=bool, default=True)
parser.add_argument('--cos_restart_cyclic', type=_str2bool, default=False)
parser.add_argument('--cos_restart', type=_str2bool, default=True)

# warmup training
parser.add_argument('--warmup_epochs', type=int, default=3, help='warmup_epochs')
parser.add_argument('--start_warmup', type=bool, default=True, help='turn False to train without warmup')
parser.add_argument('--start_warmup', type=_str2bool, default=True, help='turn False to train without warmup')

# train datasets
parser.add_argument('--data_train_lol_blur' , type=str, default='./datasets/LOL_blur/train')
Expand Down Expand Up @@ -60,22 +68,25 @@ def option():
parser.add_argument('--P_weight', type=float, default=1e-2)

# use random gamma function (enhancement curve) to improve generalization
parser.add_argument('--gamma', type=bool, default=False)
parser.add_argument('--gamma', type=_str2bool, default=False)
parser.add_argument('--start_gamma', type=int, default=60)
parser.add_argument('--end_gamma', type=int, default=120)

# auto grad, turn off to speed up training
parser.add_argument('--grad_detect', type=bool, default=False, help='if gradient explosion occurs, turn-on it')
parser.add_argument('--grad_clip', type=bool, default=True, help='if gradient fluctuates too much, turn-on it')
parser.add_argument('--grad_detect', type=_str2bool, default=False, help='if gradient explosion occurs, turn-on it')
parser.add_argument('--grad_clip', type=_str2bool, default=True, help='if gradient fluctuates too much, turn-on it')


# choose which dataset you want to train, please only set one "True"
parser.add_argument('--lol_v1', type=bool, default=True)
parser.add_argument('--lolv2_real', type=bool, default=False)
parser.add_argument('--lolv2_syn', type=bool, default=False)
parser.add_argument('--lol_blur', type=bool, default=False)
parser.add_argument('--SID', type=bool, default=False)
parser.add_argument('--SICE_mix', type=bool, default=False)
parser.add_argument('--SICE_grad', type=bool, default=False)
parser.add_argument('--fivek', type=bool, default=False)
# choose which dataset you want to train
parser.add_argument('--dataset', type=str, default='lol_v1',
choices=['lol_v1',
'lolv2_real',
'lolv2_syn',
'lol_blur',
'SID',
'SICE_mix',
'SICE_grad',
'fivek'],
help='Select the dataset to train on (default: %(default)s)')

return parser
55 changes: 28 additions & 27 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,6 @@
from loss.losses import *
from net.CIDNet import CIDNet

eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--perc', action='store_true', help='trained with perceptual loss')
eval_parser.add_argument('--lol', action='store_true', help='output lolv1 dataset')
eval_parser.add_argument('--lol_v2_real', action='store_true', help='output lol_v2_real dataset')
eval_parser.add_argument('--lol_v2_syn', action='store_true', help='output lol_v2_syn dataset')
eval_parser.add_argument('--SICE_grad', action='store_true', help='output SICE_grad dataset')
eval_parser.add_argument('--SICE_mix', action='store_true', help='output SICE_mix dataset')
eval_parser.add_argument('--fivek', action='store_true', help='output FiveK dataset')

eval_parser.add_argument('--best_GT_mean', action='store_true', help='output lol_v2_real dataset best_GT_mean')
eval_parser.add_argument('--best_PSNR', action='store_true', help='output lol_v2_real dataset best_PSNR')
eval_parser.add_argument('--best_SSIM', action='store_true', help='output lol_v2_real dataset best_SSIM')

eval_parser.add_argument('--custome', action='store_true', help='output custome dataset')
eval_parser.add_argument('--custome_path', type=str, default='./YOLO')
eval_parser.add_argument('--unpaired', action='store_true', help='output unpaired dataset')
eval_parser.add_argument('--DICM', action='store_true', help='output DICM dataset')
eval_parser.add_argument('--LIME', action='store_true', help='output LIME dataset')
eval_parser.add_argument('--MEF', action='store_true', help='output MEF dataset')
eval_parser.add_argument('--NPE', action='store_true', help='output NPE dataset')
eval_parser.add_argument('--VV', action='store_true', help='output VV dataset')
eval_parser.add_argument('--alpha', type=float, default=1.0)
eval_parser.add_argument('--gamma', type=float, default=1.0)
eval_parser.add_argument('--unpaired_weights', type=str, default='./weights/LOLv2_syn/w_perc.pth')

ep = eval_parser.parse_args()


def eval(model, testing_data_loader, model_path, output_folder,norm_size=True,LOL=False,v2=False,unpaired=False,alpha=1.0,gamma=1.0):
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -79,6 +52,34 @@ def eval(model, testing_data_loader, model_path, output_folder,norm_size=True,LO

if __name__ == '__main__':

eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--perc', action='store_true', help='trained with perceptual loss')
eval_parser.add_argument('--lol', action='store_true', help='output lolv1 dataset')
eval_parser.add_argument('--lol_v2_real', action='store_true', help='output lol_v2_real dataset')
eval_parser.add_argument('--lol_v2_syn', action='store_true', help='output lol_v2_syn dataset')
eval_parser.add_argument('--SICE_grad', action='store_true', help='output SICE_grad dataset')
eval_parser.add_argument('--SICE_mix', action='store_true', help='output SICE_mix dataset')
eval_parser.add_argument('--fivek', action='store_true', help='output FiveK dataset')

eval_parser.add_argument('--best_GT_mean', action='store_true', help='output lol_v2_real dataset best_GT_mean')
eval_parser.add_argument('--best_PSNR', action='store_true', help='output lol_v2_real dataset best_PSNR')
eval_parser.add_argument('--best_SSIM', action='store_true', help='output lol_v2_real dataset best_SSIM')

eval_parser.add_argument('--custome', action='store_true', help='output custome dataset')
eval_parser.add_argument('--custome_path', type=str, default='./YOLO')
eval_parser.add_argument('--unpaired', action='store_true', help='output unpaired dataset')
eval_parser.add_argument('--DICM', action='store_true', help='output DICM dataset')
eval_parser.add_argument('--LIME', action='store_true', help='output LIME dataset')
eval_parser.add_argument('--MEF', action='store_true', help='output MEF dataset')
eval_parser.add_argument('--NPE', action='store_true', help='output NPE dataset')
eval_parser.add_argument('--VV', action='store_true', help='output VV dataset')
eval_parser.add_argument('--alpha', type=float, default=1.0)
eval_parser.add_argument('--gamma', type=float, default=1.0)
eval_parser.add_argument('--unpaired_weights', type=str, default='./weights/LOLv2_syn/w_perc.pth')

ep = eval_parser.parse_args()


cuda = True
if cuda and not torch.cuda.is_available():
raise Exception("No GPU found, or need to change CUDA_VISIBLE_DEVICES number")
Expand Down
20 changes: 10 additions & 10 deletions eval_SID_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@
from net.CIDNet import CIDNet


eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--SID', action='store_true')
eval_parser.add_argument('--Blur', action='store_true')
ep = eval_parser.parse_args()

cuda = True
if cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")


def eval(model, testing_data_loader, model_path, output_folder):
torch.set_grad_enabled(False)
model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
Expand All @@ -43,6 +33,16 @@ def eval(model, testing_data_loader, model_path, output_folder):

if __name__ == '__main__':

eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--SID', action='store_true')
eval_parser.add_argument('--Blur', action='store_true')
ep = eval_parser.parse_args()

cuda = True
if cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")


net = CIDNet().cuda()
if ep.Blur:
for index in range(1,257):
Expand Down
72 changes: 37 additions & 35 deletions eval_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
import platform
from PIL import Image

eval_parser = argparse.ArgumentParser(description='EvalHF')
eval_parser.add_argument('--path', type=str, default="Fediory/HVI-CIDNet-LOLv1-wperc", help='You can change this path to our method weights mentioned here: https://huggingface.co/papers/2502.20272.')
eval_parser.add_argument('--input_img', type=str, default="../datasets/DICM/01.jpg", help='The path of your image.')
eval_parser.add_argument('--alpha_s', type=float, default=1.0)
eval_parser.add_argument('--alpha_i', type=float, default=1.0)
eval_parser.add_argument('--gamma', type=float, default=1.0)
el = eval_parser.parse_args()

def from_pretrained(cls, pretrained_model_name_or_path: str):
model_id = str(pretrained_model_name_or_path)
Expand All @@ -35,34 +28,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: str):
return cls


if __name__ == '__main__':

model = CIDNet().cuda()
model = from_pretrained(cls=model,pretrained_model_name_or_path=el.path)
model.eval()
eval_parser = argparse.ArgumentParser(description='EvalHF')
eval_parser.add_argument('--path', type=str, default="Fediory/HVI-CIDNet-LOLv1-wperc", help='You can change this path to our method weights mentioned here: https://huggingface.co/papers/2502.20272.')
eval_parser.add_argument('--input_img', type=str, default="../datasets/DICM/01.jpg", help='The path of your image.')
eval_parser.add_argument('--alpha_s', type=float, default=1.0)
eval_parser.add_argument('--alpha_i', type=float, default=1.0)
eval_parser.add_argument('--gamma', type=float, default=1.0)
el = eval_parser.parse_args()

pil2tensor = transforms.Compose([transforms.ToTensor()])
img = Image.open(el.input_img).convert('RGB')
input = pil2tensor(img)
factor = 8
h, w = input.shape[1], input.shape[2]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect')
with torch.no_grad():
model.trans.alpha_s = el.alpha_s
model.trans.alpha = el.alpha_i
model.trans.gated = True
model.trans.gated2 = True
output = model(input.cuda()**el.gamma)

model = CIDNet().cuda()
model = from_pretrained(cls=model,pretrained_model_name_or_path=el.path)
model.eval()

output = torch.clamp(output.cuda(),0,1).cuda()
output = output[:, :, :h, :w]
enhanced_img = transforms.ToPILImage()(output.squeeze(0))
output_folder = './output_hf'
if not os.path.exists(output_folder):
os.mkdir(output_folder)
item = el.input_img
name = item.split('/')[-1]
enhanced_img.save(output_folder + "/" + name)
pil2tensor = transforms.Compose([transforms.ToTensor()])
img = Image.open(el.input_img).convert('RGB')
input = pil2tensor(img)
factor = 8
h, w = input.shape[1], input.shape[2]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
input = F.pad(input.unsqueeze(0), (0,padw,0,padh), 'reflect')
with torch.no_grad():
model.trans.alpha_s = el.alpha_s
model.trans.alpha = el.alpha_i
model.trans.gated = True
model.trans.gated2 = True
output = model(input.cuda()**el.gamma)


output = torch.clamp(output.cuda(),0,1).cuda()
output = output[:, :, :h, :w]
enhanced_img = transforms.ToPILImage()(output.squeeze(0))
output_folder = './output_hf'
if not os.path.exists(output_folder):
os.mkdir(output_folder)
item = el.input_img
name = item.split('/')[-1]
enhanced_img.save(output_folder + "/" + name)
20 changes: 11 additions & 9 deletions measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@
import argparse
import platform

mea_parser = argparse.ArgumentParser(description='Measure')
mea_parser.add_argument('--use_GT_mean', action='store_true', help='Use the mean of GT to rectify the output of the model')
mea_parser.add_argument('--lol', action='store_true', help='measure lolv1 dataset')
mea_parser.add_argument('--lol_v2_real', action='store_true', help='measure lol_v2_real dataset')
mea_parser.add_argument('--lol_v2_syn', action='store_true', help='measure lol_v2_syn dataset')
mea_parser.add_argument('--SICE_grad', action='store_true', help='measure SICE_grad dataset')
mea_parser.add_argument('--SICE_mix', action='store_true', help='measure SICE_mix dataset')
mea_parser.add_argument('--fivek', action='store_true', help='measure fivek dataset')
mea = mea_parser.parse_args()


def ssim(prediction, target):
C1 = (0.01 * 255)**2
Expand Down Expand Up @@ -123,6 +115,16 @@ def metrics(im_dir, label_dir, use_GT_mean):

if __name__ == '__main__':

mea_parser = argparse.ArgumentParser(description='Measure')
mea_parser.add_argument('--use_GT_mean', action='store_true', help='Use the mean of GT to rectify the output of the model')
mea_parser.add_argument('--lol', action='store_true', help='measure lolv1 dataset')
mea_parser.add_argument('--lol_v2_real', action='store_true', help='measure lol_v2_real dataset')
mea_parser.add_argument('--lol_v2_syn', action='store_true', help='measure lol_v2_syn dataset')
mea_parser.add_argument('--SICE_grad', action='store_true', help='measure SICE_grad dataset')
mea_parser.add_argument('--SICE_mix', action='store_true', help='measure SICE_mix dataset')
mea_parser.add_argument('--fivek', action='store_true', help='measure fivek dataset')
mea = mea_parser.parse_args()

if mea.lol:
im_dir = './output/LOLv1/*.png'
label_dir = './datasets/LOLdataset/eval15/high/'
Expand Down
11 changes: 6 additions & 5 deletions measure_SID_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
from os import listdir
import argparse

mea_parser = argparse.ArgumentParser(description='Measure')
mea_parser.add_argument('--use_GT_mean', action='store_true', help='Use the mean of GT to rectify the output of the model')
mea_parser.add_argument('--SID', action='store_true')
mea_parser.add_argument('--Blur', action='store_true')
mea = mea_parser.parse_args()

def is_image_file(filename):
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".bmp", ".JPG", ".jpeg"])
Expand Down Expand Up @@ -115,6 +110,12 @@ def metrics(im_dir, label_dir, use_GT_mean):

if __name__ == '__main__':

mea_parser = argparse.ArgumentParser(description='Measure')
mea_parser.add_argument('--use_GT_mean', action='store_true', help='Use the mean of GT to rectify the output of the model')
mea_parser.add_argument('--SID', action='store_true')
mea_parser.add_argument('--Blur', action='store_true')
mea = mea_parser.parse_args()

avg_psnr = 0
avg_ssim = 0
avg_lpips = 0
Expand Down
16 changes: 8 additions & 8 deletions measure_niqe_bris.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
from loss.niqe_utils import *
import argparse

eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--DICM', action='store_true', help='output DICM dataset')
eval_parser.add_argument('--LIME', action='store_true', help='output LIME dataset')
eval_parser.add_argument('--MEF', action='store_true', help='output MEF dataset')
eval_parser.add_argument('--NPE', action='store_true', help='output NPE dataset')
eval_parser.add_argument('--VV', action='store_true', help='output VV dataset')
ep = eval_parser.parse_args()


def metrics(im_dir):
avg_niqe = 0
Expand All @@ -39,6 +31,14 @@ def metrics(im_dir):

if __name__ == '__main__':

eval_parser = argparse.ArgumentParser(description='Eval')
eval_parser.add_argument('--DICM', action='store_true', help='output DICM dataset')
eval_parser.add_argument('--LIME', action='store_true', help='output LIME dataset')
eval_parser.add_argument('--MEF', action='store_true', help='output MEF dataset')
eval_parser.add_argument('--NPE', action='store_true', help='output NPE dataset')
eval_parser.add_argument('--VV', action='store_true', help='output VV dataset')
ep = eval_parser.parse_args()

if ep.DICM:
im_dir = './output/DICM/*.jpg'

Expand Down
Loading