import logging
import os
import random
import shutil
from datetime import datetime

import numpy as np
import pytz
import torch
import yaml

import dgl
from dgl import function as fn

# Dataset name groups.
CPF_data = ["cora", "citeseer", "pubmed", "a-computer", "a-photo"]
OGB_data = ["ogbn-arxiv", "ogbn-products"]
NonHom_data = ["pokec", "penn94"]
BGNN_data = ["house_class", "vk_class"]


def set_seed(seed):
    """Seed all RNGs (torch, numpy, dgl, python) and force deterministic kernels."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    dgl.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    # Required for deterministic cuBLAS kernels on CUDA >= 10.2.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def get_training_config(config_path, model_name, dataset):
    """Load the YAML config and merge the "global" section with the section
    for this dataset and model (the latter takes precedence)."""
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    global_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    if model_specific_config is not None:
        specific_config = dict(global_config, **model_specific_config)
    else:
        specific_config = dict(global_config)  # copy so we don't mutate the loaded config

    specific_config["model_name"] = model_name
    return specific_config
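
# A minimal sketch of the YAML layout this loader assumes (the file name and
# the keys under each section are illustrative, not taken from the repo):
#
#   global:
#     seed: 0
#   cora:
#     SAGE:
#       learning_rate: 0.01
#
#   >>> conf = get_training_config("train.conf.yaml", "SAGE", "cora")
#   >>> conf["model_name"]
#   'SAGE'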


def check_writable(path, overwrite=True):
    """Create `path`; if it already exists and `overwrite` is set, wipe and recreate it."""
    if not os.path.exists(path):
        os.makedirs(path)
    elif overwrite:
        shutil.rmtree(path)
        os.makedirs(path)


def check_readable(path):
    if not os.path.exists(path):
        raise ValueError(f"No such file or directory! {path}")


def timetz(*args):
    """Time converter for logging.Formatter: report timestamps in US/Pacific."""
    tz = pytz.timezone("US/Pacific")
    return datetime.now(tz).timetuple()


def get_logger(filename, console_log=False, log_level=logging.INFO):
    logger = logging.getLogger(__name__)
    logger.propagate = False  # avoid duplicate logging
    logger.setLevel(log_level)

    # Clean the logger first to avoid duplicated handlers.
    for hdlr in logger.handlers[:]:
        logger.removeHandler(hdlr)

    file_handler = logging.FileHandler(filename)
    formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%b%d %H-%M-%S")
    formatter.converter = timetz
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    if console_log:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger
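
# Usage sketch (the file name is illustrative):
#
#   >>> logger = get_logger("output.log", console_log=True)
#   >>> logger.info("training started")  # goes to both the file and the console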


def idx_split(idx, ratio, seed=0):
    """Randomly split `idx` into two disjoint parts holding a `ratio` fraction
    and a `1 - ratio` fraction of the elements."""
    set_seed(seed)  # note: reseeds the global RNGs as a side effect
    n = len(idx)
    cut = int(n * ratio)
    idx_idx_shuffle = torch.randperm(n)
    idx1_idx, idx2_idx = idx_idx_shuffle[:cut], idx_idx_shuffle[cut:]
    idx1, idx2 = idx[idx1_idx], idx[idx2_idx]
    # assert (torch.cat([idx1, idx2]).sort()[0] == idx.sort()[0]).all()
    return idx1, idx2
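
# Sketch: a 30/70 split of ten indices (the split is random but seeded):
#
#   >>> idx1, idx2 = idx_split(torch.arange(10), ratio=0.3)
#   >>> len(idx1), len(idx2)
#   (3, 7)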


def graph_split(idx_train, idx_val, idx_test, rate, seed):
    """
    The original setting is transductive: the full graph is observed and
    idx_train takes up a small portion. For inductive evaluation, further
    divide idx_test into [idx_test_tran, idx_test_ind], where
    rate = |idx_test_ind| / |idx_test| is the fraction of test nodes hidden
    from the observed graph.

    Ex. ogbn-products:
        loaded:      train : val : test = 8 : 2 : 90, rate = 0.2
        after split: train : val : test_tran : test_ind = 8 : 2 : 72 : 18

    Returns:
        Indices starting with 'obs_' are node positions within the observed
        subgraph, whereas indices starting with 'idx_' are node indices in
        the original graph.
    """
    idx_test_ind, idx_test_tran = idx_split(idx_test, rate, seed)

    idx_obs = torch.cat([idx_train, idx_val, idx_test_tran])
    N1, N2 = idx_train.shape[0], idx_val.shape[0]
    obs_idx_all = torch.arange(idx_obs.shape[0])
    obs_idx_train = obs_idx_all[:N1]
    obs_idx_val = obs_idx_all[N1 : N1 + N2]
    obs_idx_test = obs_idx_all[N1 + N2 :]

    return obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind
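
# Sketch matching the ogbn-products ratios above (indices illustrative):
#
#   >>> tr, va, te = torch.arange(8), torch.arange(8, 10), torch.arange(10, 100)
#   >>> obs_tr, obs_va, obs_te, idx_obs, idx_ind = graph_split(tr, va, te, rate=0.2, seed=0)
#   >>> len(idx_obs), len(idx_ind)  # 8 + 2 + 72 observed, 18 held out
#   (82, 18)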


def get_evaluator(dataset, baseline=False):
    def evaluator(model, logits, labels):
        # Assumes the model exposes a `prototypes` matrix of class prototypes.
        pred = (logits @ model.prototypes.T).argmax(dim=1)
        return pred.eq(labels).float().mean().item()

    def evaluator_baseline(model, logits, labels):
        pred = logits.argmax(dim=1)
        return pred.eq(labels).float().mean().item()

    return evaluator_baseline if baseline else evaluator
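
# Sketch: accuracy of raw logits against labels; the baseline path ignores
# `model`, so None is fine (`logits` and `labels` here are placeholders):
#
#   >>> acc_fn = get_evaluator("cora", baseline=True)
#   >>> acc_fn(None, logits, labels)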


def compute_min_cut_loss(g, out):
    """Normalized min-cut objective tr(S^T A S) / tr(S^T D S), where S = exp(out)
    is the soft assignment matrix (assuming `out` holds log-probabilities)."""
    out = out.to("cpu")
    S = out.exp()
    A = g.adj().to_dense()
    D = g.in_degrees().float().diag()
    min_cut = (
        torch.matmul(torch.matmul(S.transpose(1, 0), A), S).trace()
        / torch.matmul(torch.matmul(S.transpose(1, 0), D), S).trace()
    )
    return min_cut.item()
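
# Sketch: min-cut value for random soft assignments on a toy graph
# (graph and logits illustrative):
#
#   >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))  # undirected 3-node path
#   >>> out = torch.log_softmax(torch.randn(3, 2), dim=1)
#   >>> compute_min_cut_loss(g, out)  # scalar in [0, 1]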


def feature_prop(feats, g, k):
    """
    Augment node features by propagating them within the k-hop neighborhood.
    The propagation is done SGC-style, i.e. hop by hop and symmetrically
    normalized by node degrees.
    """
    assert feats.shape[0] == g.num_nodes()
    degs = g.in_degrees().float().clamp(min=1)
    norm = torch.pow(degs, -0.5).unsqueeze(1)
    # compute (D^-1/2 A D^-1/2)^k X
    for _ in range(k):
        feats = feats * norm
        g.ndata["h"] = feats
        g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
        feats = g.ndata.pop("h")
        feats = feats * norm
    return feats
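
# Sketch: 2-hop propagation of one-hot features on a toy graph (illustrative):
#
#   >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))  # undirected 3-node path
#   >>> feature_prop(torch.eye(3), g, k=2).shape
#   torch.Size([3, 3])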