evaluate_models_utils.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.metrics import get_classification_metrics
from utils.utils import NeighborSampler
from utils.DataLoader import Data


def evaluate_model_classification(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
                                  evaluate_data: Data, loss_func: nn.Module, num_neighbors: int = 20, time_gap: int = 2000, fp: str = '/dev/null'):
    """
    evaluate models on the classification task
    :param model_name: str, name of the model
    :param model: nn.Module, the model to be evaluated
    :param neighbor_sampler: NeighborSampler, neighbor sampler
    :param evaluate_idx_data_loader: DataLoader, evaluate index data loader
    :param evaluate_data: Data, data to be evaluated
    :param loss_func: nn.Module, loss function
    :param num_neighbors: int, number of neighbors to sample for each node
    :param time_gap: int, time gap for neighbors to compute node features
    :param fp: str, file path forwarded to get_classification_metrics (defaults to '/dev/null')
    :return: evaluate_total_loss (average loss over batches) and evaluate_metrics (output of get_classification_metrics)
    """
    if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
        # evaluation phase uses all the graph information
        model[0].set_neighbor_sampler(neighbor_sampler)

    model.eval()
    multiclass = isinstance(loss_func, nn.CrossEntropyLoss)

    with torch.no_grad():
        # store evaluate losses, trues and predicts
        model[1].prototypical_encoding(model[0])
        evaluate_total_loss, evaluate_y_trues, evaluate_y_predicts = 0.0, [], []
        evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
                evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
                evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices], evaluate_data.labels[evaluate_data_indices]

            if model_name in ['TGAT', 'CAWN', 'TCL']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['JODIE', 'DyRep', 'TGN']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      edge_ids=batch_edge_ids,
                                                                      edges_are_positive=True,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['GraphMixer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors,
                                                                      time_gap=time_gap)
            elif model_name in ['DyGFormer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times)
            else:
                raise ValueError(f"Wrong value for model_name {model_name}!")
            # get predictions from the classification head: shape (batch_size, ) for binary labels,
            # or (batch_size, num_classes) when loss_func is CrossEntropyLoss
            predicts = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings, times=batch_node_interact_times)
            labels = torch.from_numpy(batch_labels).to(predicts.device)
            if not multiclass:
                labels = labels.float()
            loss = loss_func(input=predicts, target=labels)

            evaluate_total_loss += loss.item()
            evaluate_y_trues.append(labels)
            evaluate_y_predicts.append(predicts)

            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')
        evaluate_total_loss /= (batch_idx + 1)
        evaluate_y_trues = torch.cat(evaluate_y_trues, dim=0)
        evaluate_y_predicts = torch.cat(evaluate_y_predicts, dim=0)
        evaluate_metrics = get_classification_metrics(predicts=evaluate_y_predicts, labels=evaluate_y_trues, multiclass=multiclass, fp=fp)

    return evaluate_total_loss, evaluate_metrics
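

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of how evaluate_model_classification might be called. It assumes the
# backbone/classifier pair is wrapped so that model[0] is the temporal graph
# encoder (exposing set_neighbor_sampler and
# compute_src_dst_node_temporal_embeddings) and model[1] is the classification
# head (exposing prototypical_encoding and called with input_1/input_2/times),
# a NeighborSampler built over the full graph, and a DataLoader that yields
# integer index batches into `evaluate_data`. The constructors build_backbone,
# build_classifier, full_neighbor_sampler and get_idx_data_loader are
# placeholders, not APIs of this repository.
#
# backbone = build_backbone(model_name='TGAT', ...)        # temporal graph encoder
# classifier = build_classifier(...)                       # classification head
# model = nn.Sequential(backbone, classifier)              # model[0] = encoder, model[1] = head
#
# val_loss, val_metrics = evaluate_model_classification(
#     model_name='TGAT',
#     model=model,
#     neighbor_sampler=full_neighbor_sampler,               # sampler over the full graph for evaluation
#     evaluate_idx_data_loader=get_idx_data_loader(...),    # yields index batches into evaluate_data
#     evaluate_data=val_data,                               # utils.DataLoader.Data object
#     loss_func=nn.BCELoss(),                               # or nn.CrossEntropyLoss() for multiclass labels
#     num_neighbors=20,
#     time_gap=2000,
#     fp='logs/val_classification_report.txt')
# print(f'val loss: {val_loss:.4f}, val metrics: {val_metrics}')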