# train_and_eval.py
import copy
from tqdm import tqdm
import torch
import dgl
import torch.nn.functional as F


def get_prototypes(model, dataloader, conf):
    # Accumulate class prototypes: sum the model outputs of all nodes that share
    # a label, over every mini-batch, then L2-normalise the per-class sums.
    model.prototypes = torch.zeros(conf["num_classes"], conf["label_dim"]).to(conf["device"])
    for input_nodes, output_nodes, blocks in dataloader:
        input_features = blocks[0].srcdata['feat']
        output_labels = blocks[-1].dstdata['label']
        logits = model(blocks, input_features * model.p)
        num_batch_nodes = len(output_nodes)
        # One-hot (num_classes, batch) mask so that mask @ logits sums the
        # outputs of the batch nodes belonging to each class.
        mask = torch.zeros(conf["num_classes"], num_batch_nodes).to(conf["device"])
        mask[output_labels, torch.arange(num_batch_nodes, device=conf["device"])] = 1
        sum_features_for_each_label = mask @ logits
        model.prototypes += sum_features_for_each_label
    model.prototypes = F.normalize(model.prototypes)


def train(model, dataloader, criterion, evaluator, optimizer, conf):
    model.train()
    total_loss, total_logits, total_labels = 0, [], []
    dataloader_tqdm = tqdm(dataloader, ncols=120)
    for input_nodes, output_nodes, blocks in dataloader_tqdm:
        # Prototypes are recomputed over the whole dataloader for every
        # mini-batch so that they track the current model parameters.
        get_prototypes(model, dataloader, conf)
        input_features = blocks[0].srcdata['feat']
        output_labels = blocks[-1].dstdata['label']
        logits = model(blocks, input_features * model.p)
        loss = criterion(model, logits, output_labels)
        total_loss += loss.item()
        total_logits.append(logits)
        total_labels.append(output_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        dataloader_tqdm.set_description(f'total training loss: {total_loss:.4f}')
    total_logits = torch.cat(total_logits, dim=0)
    total_labels = torch.cat(total_labels, dim=0)
    score = evaluator(model, total_logits, total_labels)
    return total_loss / len(dataloader), score


@torch.no_grad()
def evaluate(model, dataloader, criterion, evaluator, conf):
    model.eval()  # Set model to evaluation mode
    total_loss, total_logits, total_labels = 0, [], []
    # For evaluation the prototypes are computed once, before iterating the batches.
    get_prototypes(model, dataloader, conf)
    dataloader_tqdm = tqdm(dataloader, ncols=120)
    for input_nodes, output_nodes, blocks in dataloader_tqdm:
        input_features = blocks[0].srcdata['feat']
        output_labels = blocks[-1].dstdata['label']
        logits = model(blocks, input_features * model.p, inference=True)
        loss = criterion(model, logits, output_labels)
        total_loss += loss.item()
        total_logits.append(logits)
        total_labels.append(output_labels)
        dataloader_tqdm.set_description(f'total evaluation loss: {total_loss:.4f}')
    total_logits = torch.cat(total_logits, dim=0)
    total_labels = torch.cat(total_labels, dim=0)
    score = evaluator(model, total_logits, total_labels)
    return total_loss / len(dataloader), score


def run_transductive(
    conf,
    model,
    g,
    feats,
    labels,
    indices,
    criterion,
    evaluator,
    optimizer,
    logger,
    loss_and_score,
):
    """
    Train and evaluate under the transductive setting.

    The train/valid/test split is specified by `indices`. The input graph is
    assumed to be large, so neighbour-sampled mini-batches (as in GraphSAGE)
    are used for GNNs and plain mini-batches for MLPs.

    loss_and_score: stores losses and scores.

    A minimal, illustrative usage sketch is provided at the bottom of this file.
    """
    # Neighbour sampling for training; full neighbourhoods for validation and testing.
    sampler = dgl.dataloading.NeighborSampler([int(fanout) for fanout in conf["fan_out"].split(",")])
    full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(model.num_layers)
    dataloader_train = dgl.dataloading.DataLoader(
        g,
        indices[0],
        sampler,
        device=conf["device"],
        batch_size=conf["batch_size"],
        use_uva=True
    )
    dataloader_val = dgl.dataloading.DataLoader(
        g,
        indices[1],
        full_sampler,
        device=conf["device"],
        batch_size=conf["batch_size"],
        use_uva=True
    )
    dataloader_test = dgl.dataloading.DataLoader(
        g,
        indices[2],
        full_sampler,
        device=conf["device"],
        batch_size=conf["batch_size"],
        use_uva=True
    )
    best_epoch, best_score_val, count = 0, 0, 0
    state = copy.deepcopy(model.state_dict())
    for epoch in range(1, conf["max_epoch"] + 1):
        loss, score = train(model, dataloader_train, criterion, evaluator, optimizer, conf)
        logger.debug(f"Ep {epoch:3d} | loss_train: {loss:.4f} | acc_train: {score:.4f}")
        if epoch % conf["eval_interval"] == 0:
            loss, score = evaluate(model, dataloader_val, criterion, evaluator, conf)
            logger.debug(f"Ep {epoch:3d} | loss_val: {loss:.4f} | acc_val: {score:.4f}")
            if score >= best_score_val:
                # Keep a copy of the best model (by validation score) for final testing.
                best_epoch = epoch
                best_score_val = score
                state = copy.deepcopy(model.state_dict())
                count = 0
            else:
                count += 1
        if count == conf["patience"] or epoch == conf["max_epoch"]:
            break
    model.load_state_dict(state)
    loss, score = evaluate(model, dataloader_test, criterion, evaluator, conf)
    logger.info(f"Best valid model at epoch: {best_epoch:3d}, score_val: {best_score_val:.4f}, score_test: {score:.4f}")
    return None, loss, score


def run_inductive(*args, **kwargs):
    pass
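

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training utilities
# above).  The toy ProtoSAGE model, the prototype-based criterion/evaluator
# and the conf values below are assumptions; they merely satisfy the interface
# that run_transductive and its helpers expect: a model exposing `p`,
# `prototypes` and `num_layers` with a forward(blocks, feats, inference=False)
# signature, plus criterion(model, logits, labels) and
# evaluator(model, logits, labels) callables.  Because the dataloaders are
# built with use_uva=True, a CUDA device is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging

    import torch.nn as nn
    from dgl.nn import SAGEConv

    class ProtoSAGE(nn.Module):
        """Two-layer GraphSAGE producing `label_dim`-sized embeddings."""

        def __init__(self, in_dim, hidden_dim, label_dim):
            super().__init__()
            self.num_layers = 2
            self.layers = nn.ModuleList([
                SAGEConv(in_dim, hidden_dim, "mean"),
                SAGEConv(hidden_dim, label_dim, "mean"),
            ])
            self.p = 1.0            # feature scaling used by the helpers above
            self.prototypes = None  # filled in by get_prototypes()

        def forward(self, blocks, x, inference=False):
            for layer, block in zip(self.layers, blocks):
                x = layer(block, x)
            return x

    def criterion(model, logits, labels):
        # Classify each node by cosine similarity to the class prototypes.
        scores = F.normalize(logits) @ model.prototypes.t()
        return F.cross_entropy(scores, labels)

    def evaluator(model, logits, labels):
        scores = F.normalize(logits) @ model.prototypes.t()
        return (scores.argmax(dim=1) == labels).float().mean().item()

    conf = {
        "device": "cuda:0",   # use_uva=True needs a GPU
        "num_classes": 5,
        "label_dim": 5,
        "fan_out": "5,10",    # one fanout per GNN layer
        "batch_size": 64,
        "max_epoch": 20,
        "eval_interval": 1,
        "patience": 5,
    }

    # Random toy graph so the sketch is self-contained.
    num_nodes, in_dim = 1000, 16
    g = dgl.rand_graph(num_nodes, 20 * num_nodes)
    g.ndata["feat"] = torch.randn(num_nodes, in_dim)
    g.ndata["label"] = torch.randint(0, conf["num_classes"], (num_nodes,))
    perm = torch.randperm(num_nodes).to(conf["device"])
    indices = [perm[:600], perm[600:800], perm[800:]]  # train / val / test

    model = ProtoSAGE(in_dim, 32, conf["label_dim"]).to(conf["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("train_and_eval")

    run_transductive(
        conf, model, g, g.ndata["feat"], g.ndata["label"], indices,
        criterion, evaluator, optimizer, logger, loss_and_score=[],
    )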