test.py
import torch
import argparse
from modules.dataloader import R2DataLoader
from modules.tokenizers import Tokenizer
from modules.loss import compute_loss
from modules.metrics import compute_scores
from models.models import MedCapModel
from modules.tester import Tester
import numpy as np
import os
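# Clearing CURL_CA_BUNDLE makes libraries that honour it (e.g. requests) skip SSL
# certificate verification; presumably a workaround for proxy/certificate issues
# when downloading pretrained weights.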
os.environ['CURL_CA_BUNDLE'] = ''


def main():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json',
                        help='Path to the annotation json file.')
    parser.add_argument('--image_dir', default='data/mimic_cxr/images/',
                        help='Directory of images.')

    # Dataloader settings
    parser.add_argument('--dataset', default='iu_xray', help='the dataset for training MedCap.')
    parser.add_argument('--bs', type=int, default=16, help='the batch size.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.')

    # Trainer settings
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='./record_dir/',
                        help='the path to save the results of experiments.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')
    parser.add_argument('--save_period', type=int, default=1)
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'],
                        help='whether to maximize or minimize the monitored metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of early stopping.')

    # Training related
    parser.add_argument('--noise_inject', default='no', choices=['yes', 'no'])

    # Sample related
    parser.add_argument('--sample_method', type=str, default='greedy', help='the sample method used to sample a report.')
    parser.add_argument('--prompt', default='/prompt/prompt.pt')
    parser.add_argument('--prompt_load', default='yes', choices=['yes', 'no'])

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=7e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--adam_betas', type=tuple, default=(0.9, 0.98), help='the betas of the Adam optimizer.')
    parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon of the Adam optimizer.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad with Adam.')
    parser.add_argument('--noamopt_warmup', type=int, default=5000, help='the warmup steps of the Noam scheduler.')
    parser.add_argument('--noamopt_factor', type=int, default=1, help='the scaling factor of the Noam scheduler.')

    # Learning rate scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9153, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')
    parser.add_argument('--train_mode', default='base', choices=['base', 'full'],
                        help='Training mode: base (text-only training) or full (fully supervised training).')
    parser.add_argument('--full_supervised_version', default='v1', choices=['v1', 'v2', 'v3'],
                        help='Fully supervised version: v1 (image features only), v2 (feature fusion) or '
                             'v3 (feature fusion + image features).')
    parser.add_argument('--clip_update', default='no', choices=['yes', 'no'])
    parser.add_argument('--load', type=str, help='whether to load the pre-trained model.')
    args = parser.parse_args()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create the tokenizer and the dataloader for the test split
    tokenizer = Tokenizer(args)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build the model and the tester, then run evaluation on the test split
    model = MedCapModel(args, tokenizer)
    tester = Tester(model, criterion, metrics, args, test_dataloader)
    tester.test()


if __name__ == '__main__':
    main()
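
# Example invocation (a sketch only; the dataset name, paths and checkpoint file are
# assumptions based on the defaults above, not values shipped with the repository):
#
#   python test.py --dataset iu_xray \
#       --json_path data/iu_xray/annotation.json \
#       --image_dir data/iu_xray/images/ \
#       --load results/iu_xray/model_best.pth \
#       --bs 16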