-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_count.py
130 lines (111 loc) · 5.85 KB
/
train_count.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
script to train on counting substructure tasks.
"""
from datasets.GraphCountDataset import GraphCountDatasetI2
import torch
import torch.nn as nn
from models.input_encoder import EmbeddingEncoder
import train_utils
from interfaces.pl_model_interface import PlGNNTestonValModule
from interfaces.pl_data_interface import PlPyGDataTestonValModule
from lightning.pytorch import seed_everything
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, Timer
from lightning.pytorch.callbacks.progress import TQDMProgressBar
import wandb
from torchmetrics import MeanAbsoluteError
import torch_geometric.transforms as T
from torch_geometric.data import Data
def add_node_feature(data: "Data") -> "Data":
    r"""Add an identical all-zero initial node feature to a graph.

    Gives every node a single long-typed zero feature so that otherwise
    featureless graphs can be fed to an embedding-based input encoder.

    Args:
        data (Data): PyG data object; ``data.x`` is overwritten in place.

    Returns:
        Data: the same object, with ``data.x`` of shape ``[num_nodes, 1]``
        and dtype ``torch.long``.
    """
    # Allocate directly as long instead of creating a float tensor and
    # casting with .long() — avoids an intermediate allocation.
    data.x = torch.zeros([data.num_nodes, 1], dtype=torch.long)
    return data
def main():
    """Train on a substructure-counting task with repeated runs.

    Parses CLI arguments, builds train/val/test splits of the counting
    dataset, standardizes the regression targets with train+val statistics,
    then runs ``args.runs`` independent training repeats. Each repeat gets
    its own wandb logger and random seed; the best checkpoint (lowest
    validation MAE) is evaluated on both the val and test splits.
    """
    parser = train_utils.args_setup()
    parser.add_argument('--dataset_name', type=str, default="count_cycle", choices=("count_cycle", "count_graphlet"),
                        help='Name of dataset.')
    parser.add_argument('--task', type=int, default=0, choices=(0, 1, 2, 3, 4), help='Train task index.')
    parser.add_argument('--runs', type=int, default=3, help='Number of repeat run.')
    args = parser.parse_args()
    args = train_utils.update_args(args)

    path, pre_transform, follow_batch = train_utils.data_setup(args)

    def build_dataset(split: str) -> GraphCountDatasetI2:
        # All splits share the same pre-/post-transforms; only `split` varies.
        return GraphCountDatasetI2(
            root=path,
            dataname=args.dataset_name,
            split=split,
            pre_transform=T.Compose([add_node_feature, pre_transform]),
            transform=train_utils.PostTransform(args.wo_node_feature,
                                                args.wo_edge_feature,
                                                args.task),
        )

    train_dataset = build_dataset("train")
    val_dataset = build_dataset("val")
    test_dataset = build_dataset("test")

    # Standardize targets with train+val statistics; the same statistics are
    # applied to the test split so no test information leaks into training.
    y_train_val = torch.cat([train_dataset.data.y, val_dataset.data.y], dim=0)
    mean = y_train_val.mean(dim=0)
    std = y_train_val.std(dim=0)
    # Guard against zero variance (constant targets for a task), which would
    # otherwise produce NaN/inf targets from the division below.
    std = torch.where(std == 0, torch.ones_like(std), std)
    for dataset in (train_dataset, val_dataset, test_dataset):
        dataset.data.y = (dataset.data.y - mean) / std

    for i in range(1, args.runs + 1):
        logger = WandbLogger(name=f'run_{str(i)}',
                             project=args.exp_name,
                             save_dir=args.save_dir,
                             offline=args.offline)
        logger.log_hyperparams(args)
        # Generous wall-clock budget; used below to report avg epoch time.
        timer = Timer(duration=dict(weeks=4))

        # Set random seed (fresh seed per repeat for independent runs).
        seed = train_utils.get_seed(args.seed)
        seed_everything(seed)

        datamodule = PlPyGDataTestonValModule(train_dataset=train_dataset,
                                              val_dataset=val_dataset,
                                              test_dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              follow_batch=follow_batch)
        # Regression task: L1 loss, evaluated with mean absolute error,
        # so lower is better -> monitor the metric with mode "min".
        loss_cri = nn.L1Loss()
        evaluator = MeanAbsoluteError()
        args.mode = "min"

        # Two embedding rows: all nodes carry the constant feature 0
        # added by add_node_feature.
        init_encoder = EmbeddingEncoder(2, args.hidden_channels)
        modelmodule = PlGNNTestonValModule(loss_criterion=loss_cri,
                                           evaluator=evaluator,
                                           args=args,
                                           init_encoder=init_encoder)

        trainer = Trainer(
            accelerator="auto",
            devices="auto",
            max_epochs=args.num_epochs,
            enable_checkpointing=True,
            enable_progress_bar=True,
            logger=logger,
            callbacks=[
                TQDMProgressBar(refresh_rate=20),
                # Keep the checkpoint with the best (lowest) validation metric.
                ModelCheckpoint(monitor="val/metric", mode=args.mode),
                LearningRateMonitor(logging_interval="epoch"),
                timer,
            ],
        )

        trainer.fit(modelmodule, datamodule=datamodule)
        # Evaluate the best checkpoint on both val and test splits.
        val_result, test_result = trainer.test(modelmodule, datamodule=datamodule, ckpt_path="best")
        results = {"final/best_val_metric": val_result["val/metric"],
                   "final/best_test_metric": test_result["test/metric"],
                   "final/avg_train_time_epoch": timer.time_elapsed("train") / args.num_epochs,
                   }
        logger.log_metrics(results)
        # Close the wandb run so the next repeat starts a fresh one.
        wandb.finish()
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()