main.py
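"""Entry point for training and evaluating an MTB relation-extraction model.

Loads a TACRED, SemEval, or SMILER dataset according to the Hydra configuration,
fine-tunes the model, and writes the resulting micro-/macro-F1 scores to
``results.json``.
"""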
import logging
import json

import torch
from torch.utils.data import DataLoader
import hydra
from omegaconf import DictConfig, OmegaConf

from mtb.data import TACREDDataset, SemEvalDataset, SmilerDataset
from mtb.model import MTBModel
from mtb.processor import BatchTokenizer, aggregate_batch
from mtb.train_eval import train_and_eval
from mtb.utils import resolve_relative_path, seed_everything

logger = logging.getLogger(__name__)


@hydra.main(config_name="config", config_path="configs")
def main(cfg: DictConfig) -> None:
"""
Conducts evaluation given the configuration.
Args:
cfg: Hydra-format configuration given in a dict.
"""
resolve_relative_path(cfg)
print(OmegaConf.to_yaml(cfg))
seed_everything(cfg.seed)
# prepare dataset: parse raw dataset and do some simple pre-processing such as
# convert special tokens and insert entity markers
entity_marker = True if cfg.variant in ["d", "e", "f"] else False
if "tacred" in cfg.train_file.lower():
train_dataset = TACREDDataset(cfg.train_file, entity_marker=entity_marker)
eval_dataset = TACREDDataset(cfg.eval_file, entity_marker=entity_marker)
layer_norm = False
elif "semeval" in cfg.train_file.lower():
train_dataset = SemEvalDataset(cfg.train_file, entity_marker=entity_marker)
eval_dataset = SemEvalDataset(cfg.eval_file, entity_marker=entity_marker)
layer_norm = True
elif "smiler" in cfg.train_file.lower():
train_dataset = SmilerDataset(cfg.train_file, entity_marker=entity_marker)
eval_dataset = SmilerDataset(cfg.eval_file, entity_marker=entity_marker)
layer_norm = True
label_to_id = train_dataset.label_to_id

    # set dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        pin_memory=True,
        collate_fn=aggregate_batch,
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        pin_memory=True,
        collate_fn=aggregate_batch,
    )

    # set a processor that tokenizes and aligns all the tokens in a batch
    batch_processor = BatchTokenizer(
        tokenizer_name_or_path=cfg.model,
        variant=cfg.variant,
        max_length=cfg.max_length,
    )
    vocab_size = len(batch_processor.tokenizer)

    # set model and device
    model = MTBModel(
        encoder_name_or_path=cfg.model,
        variant=cfg.variant,
        layer_norm=layer_norm,
        vocab_size=vocab_size,
        num_classes=len(label_to_id),
        dropout=cfg.dropout,
    )
    device = (
        torch.device("cuda", cfg.cuda_device)
        if cfg.cuda_device > -1
        else torch.device("cpu")
    )

    micro_f1, macro_f1 = train_and_eval(
        model,
        train_loader,
        eval_loader,
        label_to_id,
        batch_processor,
        num_epochs=cfg.num_epochs,
        lr=cfg.lr,
        device=device,
    )
    logger.info(
        "Evaluation micro-F1: {:.4f}, macro-F1: {:.4f}.".format(micro_f1, macro_f1)
    )

    # save evaluation results to json
    with open("./results.json", "w") as f:
        json.dump({"micro_f1": micro_f1, "macro_f1": macro_f1}, f, indent=4)


if __name__ == "__main__":
    main()
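
# The configuration keys read above (seed, cuda_device, model, variant, max_length,
# dropout, batch_size, num_epochs, lr, train_file, eval_file) are expected in
# configs/config.yaml (Hydra resolves config_name="config" relative to
# config_path="configs"). A minimal sketch, with purely illustrative values that
# may differ from the repository's actual config:
#
#   seed: 42
#   cuda_device: 0
#   model: bert-base-uncased
#   variant: d
#   max_length: 128
#   dropout: 0.1
#   batch_size: 32
#   num_epochs: 5
#   lr: 3.0e-5
#   train_file: data/tacred/train.json
#   eval_file: data/tacred/dev.json
#
# Any of these can be overridden on the command line with Hydra's key=value
# syntax, e.g.:
#
#   python main.py variant=f batch_size=16 cuda_device=-1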