From fe833d5b44730d0ae4bc35885774c8a5beb99967 Mon Sep 17 00:00:00 2001 From: CKHui Date: Thu, 17 Dec 2020 17:26:35 +0800 Subject: [PATCH 1/7] add pna model --- cogdl/models/nn/pyg_pna.py | 136 +++++++++++++++++++++++++++++++++++++ match.yml | 1 + 2 files changed, 137 insertions(+) create mode 100644 cogdl/models/nn/pyg_pna.py diff --git a/cogdl/models/nn/pyg_pna.py b/cogdl/models/nn/pyg_pna.py new file mode 100644 index 00000000..c72ca742 --- /dev/null +++ b/cogdl/models/nn/pyg_pna.py @@ -0,0 +1,136 @@ +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import ModuleList, Embedding +from torch.nn import Sequential, ReLU, Linear + +from .. import BaseModel, register_model +from cogdl.data import DataLoader + +from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool + +@register_model("pyg_pna") +class PNA(BaseModel): + r"""Implements a single convolutional layer of the Principal Neighbourhood Aggregation Networks + in paper `"Principal Neighbourhood Aggregation for Graph Nets" .` + """ + @staticmethod + def add_args(parser): + parser.add_argument("--num_features", type=int) + parser.add_argument("--num_classes", type=int) + parser.add_argument("--hidden_size", type=int, default=60) + parser.add_argument("--avg_deg", type=int, default=1) + + parser.add_argument("--layer", type=int, default=4) + parser.add_argument("--pre_layers", type=int, default=1) + parser.add_argument("--towers", type=int, default=5) + parser.add_argument("--post_layers", type=int, default=1) + parser.add_argument("--edge_dim", type=int, default=None) + + parser.add_argument("--aggregators", type=str, nargs="+", default=['mean', 'min', 'max', 'std']) + parser.add_argument("--scalers", type=str, nargs="+", default=['identity', 'amplification', 'attenuation']) + + parser.add_argument("--divide_input", action='store_true', default=False) + + parser.add_argument("--batch-size", type=int, default=20) + parser.add_argument("--train-ratio", type=float, default=0.7) + parser.add_argument("--test-ratio", type=float, default=0.1) + + @classmethod + def build_model_from_args(cls, args): + return cls( + args.num_features, + args.num_classes, + args.hidden_size, + args.avg_deg, + args.layer, + args.pre_layers, + args.towers, + args.post_layers, + args.edge_dim, + args.aggregators, + args.scalers, + args.divide_input + ) + + @classmethod + def split_dataset(cls, dataset, args): + + random.shuffle(dataset) + train_size = int(len(dataset) * args.train_ratio) + test_size = int(len(dataset) * args.test_ratio) + bs = args.batch_size + train_loader = DataLoader(dataset[:train_size], batch_size=bs) + test_loader = DataLoader(dataset[-test_size:], batch_size=bs) + if args.train_ratio + args.test_ratio < 1: + valid_loader = DataLoader(dataset[train_size:-test_size], batch_size=bs) + else: + valid_loader = test_loader + + return train_loader, valid_loader, test_loader + + def __init__(self, num_feature, num_classes, hidden_size, avg_deg, + layer=4, pre_layers=1,towers=5, post_layers=1, + edge_dim=None, + aggregators=['mean', 'min', 'max', 'std'], + scalers = ['identity', 'amplification', 'attenuation'], + divide_input=False + ): + super(PNA, self).__init__() + + self.hidden_size = hidden_size + self.edge_dim = edge_dim + + avg_deg = torch.tensor(avg_deg) + emd_side = self.hidden_size // num_feature + self.hidden_size = emd_side * num_feature + self.node_emb = Embedding(num_feature, emd_side) + if not self.edge_dim is None: + self.edge_emb = Embedding(4, edge_dim) + + self.convs = ModuleList() + self.batch_norms = ModuleList() + for _ in range(layer): + conv = PNAConv(in_channels=self.hidden_size, out_channels=self.hidden_size, + aggregators=aggregators, scalers=scalers, deg=avg_deg, + edge_dim=edge_dim, towers=towers, pre_layers=pre_layers, post_layers=post_layers, + divide_input=divide_input) + self.convs.append(conv) + self.batch_norms.append(BatchNorm(self.hidden_size)) + + self.mlp = Sequential(Linear(self.hidden_size, self.hidden_size//2), ReLU(), Linear(self.hidden_size//2, self.hidden_size//4), ReLU(), + Linear(self.hidden_size//4, num_classes)) + + def forward(self, b): + x = b.x + edge_index = b.edge_index + edge_attr = b.edge_attr + batch_h = b.batch + n = x.shape[0] + + x = self.node_emb(x.long()) + x = x.reshape([n,-1]) + + if self.edge_dim is None: + edge_attr = None + else: + edge_attr = self.edge_emb(edge_attr) + + for conv, batch_norm in zip(self.convs, self.batch_norms): + x = F.relu(batch_norm(conv(x, edge_index, edge_attr))) + x = global_add_pool(x, batch_h) + out = self.mlp(x) + + if b.y is not None: + return out, self.loss(out, b.y) + print(loss) + return out, None + + def loss(self, prediction, label): + criterion = nn.CrossEntropyLoss() + loss_n = criterion(prediction, label) + return loss_n + diff --git a/match.yml b/match.yml index 2c1d7dec..db9ce96f 100644 --- a/match.yml +++ b/match.yml @@ -68,6 +68,7 @@ graph_classification: - patchy_san - hgpsl - sagpool + - pyg_pna dataset: - mutag - imdb-b From 4a2d09a8c35ea0902fe6086c5b3e0254785bc2b9 Mon Sep 17 00:00:00 2001 From: CKHui Date: Thu, 17 Dec 2020 17:38:12 +0800 Subject: [PATCH 2/7] add pna example --- examples/gnn_models/pna.py | 77 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 examples/gnn_models/pna.py diff --git a/examples/gnn_models/pna.py b/examples/gnn_models/pna.py new file mode 100644 index 00000000..85d5c03e --- /dev/null +++ b/examples/gnn_models/pna.py @@ -0,0 +1,77 @@ +import random +import numpy as np + +import torch + +from utils import print_result, set_random_seed, get_dataset +from cogdl.tasks import build_task +from cogdl.datasets import build_dataset +from cogdl.utils import build_args_from_dict + +DATASET_REGISTRY = {} + +def build_default_args_for_graph_classification(dataset): + cpu = not torch.cuda.is_available() + args = { + "lr": 0.001, + "weight_decay": 5e-4, + "max_epoch": 500, + "patience": 50, + "cpu": cpu, + "device_id": [0], + "seed": [0], + + "train_ratio": 0.7, + "test_ratio": 0.1, + "batch_size": 128, + "kfold": False, + "degree_feature": False, + + "hidden_size": 60, + "avg_deg": 1, + "layer": 4, + "pre_layers": 1, + "towers": 5, + "post_layers": 1, + "edge_dim": None, + "aggregators": ['mean', 'min', 'max', 'std'], + "scalers": ['identity', 'amplification', 'attenuation'], + "divide_input": False, + + "task": "graph_classification", + "model": "pyg_pna", + "dataset": dataset + } + return build_args_from_dict(args) + + +def register_func(name): + def register_func_name(func): + DATASET_REGISTRY[name] = func + return func + return register_func_name + + +@register_func("proteins") +def proteins_config(args): + return args + +def run(dataset_name): + args = build_default_args_for_graph_classification(dataset_name) + args = DATASET_REGISTRY[dataset_name](args) + dataset, args = get_dataset(args) + results = [] + for seed in args.seed: + set_random_seed(seed) + task = build_task(args, dataset=dataset) + result = task.train() + results.append(result) + return results + + +if __name__ == "__main__": + datasets = ["proteins"] + results = [] + for x in datasets: + results += run(x) + print_result(results, datasets, "pyg_pna") From 6d979b1c76df9419ffaa091df03d325c81e14ef5 Mon Sep 17 00:00:00 2001 From: CKHui Date: Thu, 17 Dec 2020 17:49:26 +0800 Subject: [PATCH 3/7] add pna test --- tests/tasks/test_graph_classification.py | 56 ++++++++++++++++++------ 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/tests/tasks/test_graph_classification.py b/tests/tasks/test_graph_classification.py index ac2ff930..c4d4a898 100644 --- a/tests/tasks/test_graph_classification.py +++ b/tests/tasks/test_graph_classification.py @@ -99,6 +99,22 @@ def add_sagpool_args(args): args.pooling_layer_type = "gcnconv" return args +def add_pna_args(args): + args.hidden_size = 60 + args.avg_deg = 3 + args.layer = 1 + args.train_ratio = 0.7 + args.test_ratio = 0.1 + args.pooling_ratio = 0.5 + args.pre_layers = 1 + args.towers = 1 + args.post_layers = 1 + args.edge_dim = None + args.aggregators = ['mean'] + args.scalers = ['identity'] + args.divide_input = False + return args + def test_gin_mutag(): args = get_default_args() args = add_gin_args(args) @@ -256,27 +272,39 @@ def test_sagpool_proteins(): ret = task.train() assert ret["Acc"] > 0 +def test_pna_proteins(): + args = get_default_args() + args = add_pna_args(args) + args.dataset = "proteins" + args.model = "pyg_pna" + args.batch_size = 20 + task = build_task(args) + ret = task.train() + assert ret["Acc"] > 0 + if __name__ == "__main__": - test_gin_imdb_binary() - test_gin_mutag() - test_gin_proteins() + # test_gin_imdb_binary() + # test_gin_mutag() + # test_gin_proteins() + + # test_sortpool_mutag() + # test_sortpool_proteins() - test_sortpool_mutag() - test_sortpool_proteins() + # test_diffpool_mutag() + # test_diffpool_proteins() - test_diffpool_mutag() - test_diffpool_proteins() + # test_dgcnn_proteins() + # test_dgcnn_imdb_binary() - test_dgcnn_proteins() - test_dgcnn_imdb_binary() + # test_patchy_san_mutag() + # test_patchy_san_proteins() - test_patchy_san_mutag() - test_patchy_san_proteins() + # test_hgpsl_proteins() - test_hgpsl_proteins() + # test_sagpool_mutag() + # test_sagpool_proteins() - test_sagpool_mutag() - test_sagpool_proteins() + test_pna_proteins() From dd11ba7ba6dfabecd37fe093dcd4641dd87cb2d1 Mon Sep 17 00:00:00 2001 From: CKHui Date: Thu, 17 Dec 2020 17:50:26 +0800 Subject: [PATCH 4/7] add pna test --- tests/tasks/test_graph_classification.py | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/tasks/test_graph_classification.py b/tests/tasks/test_graph_classification.py index c4d4a898..7344954d 100644 --- a/tests/tasks/test_graph_classification.py +++ b/tests/tasks/test_graph_classification.py @@ -285,26 +285,26 @@ def test_pna_proteins(): if __name__ == "__main__": - # test_gin_imdb_binary() - # test_gin_mutag() - # test_gin_proteins() + test_gin_imdb_binary() + test_gin_mutag() + test_gin_proteins() - # test_sortpool_mutag() - # test_sortpool_proteins() + test_sortpool_mutag() + test_sortpool_proteins() - # test_diffpool_mutag() - # test_diffpool_proteins() + test_diffpool_mutag() + test_diffpool_proteins() - # test_dgcnn_proteins() - # test_dgcnn_imdb_binary() + test_dgcnn_proteins() + test_dgcnn_imdb_binary() - # test_patchy_san_mutag() - # test_patchy_san_proteins() + test_patchy_san_mutag() + test_patchy_san_proteins() - # test_hgpsl_proteins() + test_hgpsl_proteins() - # test_sagpool_mutag() - # test_sagpool_proteins() + test_sagpool_mutag() + test_sagpool_proteins() test_pna_proteins() From b8d5596054b3e814e29c2d655ee950b80002aaea Mon Sep 17 00:00:00 2001 From: CKHui Date: Fri, 18 Dec 2020 17:32:25 +0800 Subject: [PATCH 5/7] refactor pna --- cogdl/models/nn/pyg_pna.py | 34 ++++++++---------------- examples/gnn_models/pna.py | 7 +---- tests/tasks/test_graph_classification.py | 1 - 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/cogdl/models/nn/pyg_pna.py b/cogdl/models/nn/pyg_pna.py index c72ca742..b0f96e07 100644 --- a/cogdl/models/nn/pyg_pna.py +++ b/cogdl/models/nn/pyg_pna.py @@ -1,15 +1,10 @@ import random - import torch import torch.nn as nn import torch.nn.functional as F - -from torch.nn import ModuleList, Embedding -from torch.nn import Sequential, ReLU, Linear - +from torch.nn import ModuleList, Embedding, Sequential, ReLU, Linear from .. import BaseModel, register_model from cogdl.data import DataLoader - from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool @register_model("pyg_pna") @@ -29,12 +24,10 @@ def add_args(parser): parser.add_argument("--towers", type=int, default=5) parser.add_argument("--post_layers", type=int, default=1) parser.add_argument("--edge_dim", type=int, default=None) - parser.add_argument("--aggregators", type=str, nargs="+", default=['mean', 'min', 'max', 'std']) parser.add_argument("--scalers", type=str, nargs="+", default=['identity', 'amplification', 'attenuation']) - parser.add_argument("--divide_input", action='store_true', default=False) - + parser.add_argument("--batch-size", type=int, default=20) parser.add_argument("--train-ratio", type=float, default=0.7) parser.add_argument("--test-ratio", type=float, default=0.1) @@ -58,7 +51,6 @@ def build_model_from_args(cls, args): @classmethod def split_dataset(cls, dataset, args): - random.shuffle(dataset) train_size = int(len(dataset) * args.train_ratio) test_size = int(len(dataset) * args.test_ratio) @@ -77,10 +69,8 @@ def __init__(self, num_feature, num_classes, hidden_size, avg_deg, edge_dim=None, aggregators=['mean', 'min', 'max', 'std'], scalers = ['identity', 'amplification', 'attenuation'], - divide_input=False - ): + divide_input=False ): super(PNA, self).__init__() - self.hidden_size = hidden_size self.edge_dim = edge_dim @@ -101,9 +91,14 @@ def __init__(self, num_feature, num_classes, hidden_size, avg_deg, self.convs.append(conv) self.batch_norms.append(BatchNorm(self.hidden_size)) - self.mlp = Sequential(Linear(self.hidden_size, self.hidden_size//2), ReLU(), Linear(self.hidden_size//2, self.hidden_size//4), ReLU(), + self.mlp = Sequential(Linear(self.hidden_size, self.hidden_size//2), + ReLU(), + Linear(self.hidden_size//2, self.hidden_size//4), + ReLU(), Linear(self.hidden_size//4, num_classes)) + self.criterion = nn.CrossEntropyLoss() + def forward(self, b): x = b.x edge_index = b.edge_index @@ -125,12 +120,5 @@ def forward(self, b): out = self.mlp(x) if b.y is not None: - return out, self.loss(out, b.y) - print(loss) - return out, None - - def loss(self, prediction, label): - criterion = nn.CrossEntropyLoss() - loss_n = criterion(prediction, label) - return loss_n - + return out, self.criterion(out, b.y) + return out, None \ No newline at end of file diff --git a/examples/gnn_models/pna.py b/examples/gnn_models/pna.py index 85d5c03e..72ec36c9 100644 --- a/examples/gnn_models/pna.py +++ b/examples/gnn_models/pna.py @@ -1,8 +1,6 @@ import random import numpy as np - import torch - from utils import print_result, set_random_seed, get_dataset from cogdl.tasks import build_task from cogdl.datasets import build_dataset @@ -44,14 +42,12 @@ def build_default_args_for_graph_classification(dataset): } return build_args_from_dict(args) - def register_func(name): def register_func_name(func): DATASET_REGISTRY[name] = func return func return register_func_name - @register_func("proteins") def proteins_config(args): return args @@ -68,10 +64,9 @@ def run(dataset_name): results.append(result) return results - if __name__ == "__main__": datasets = ["proteins"] results = [] for x in datasets: results += run(x) - print_result(results, datasets, "pyg_pna") + print_result(results, datasets, "pyg_pna") \ No newline at end of file diff --git a/tests/tasks/test_graph_classification.py b/tests/tasks/test_graph_classification.py index 7344954d..8e3be392 100644 --- a/tests/tasks/test_graph_classification.py +++ b/tests/tasks/test_graph_classification.py @@ -301,7 +301,6 @@ def test_pna_proteins(): test_patchy_san_mutag() test_patchy_san_proteins() - test_hgpsl_proteins() test_sagpool_mutag() From 6a8e7e6d7f13bedad7d580cd19b9635da9cc94e6 Mon Sep 17 00:00:00 2001 From: CKHui Date: Fri, 18 Dec 2020 18:23:00 +0800 Subject: [PATCH 6/7] refactor --- cogdl/models/nn/pyg_pna.py | 25 ++++++++++++------------ examples/gnn_models/pna.py | 7 ++++++- tests/tasks/test_graph_classification.py | 15 ++++++++------ 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/cogdl/models/nn/pyg_pna.py b/cogdl/models/nn/pyg_pna.py index b0f96e07..61923d40 100644 --- a/cogdl/models/nn/pyg_pna.py +++ b/cogdl/models/nn/pyg_pna.py @@ -7,6 +7,7 @@ from cogdl.data import DataLoader from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool + @register_model("pyg_pna") class PNA(BaseModel): r"""Implements a single convolutional layer of the Principal Neighbourhood Aggregation Networks @@ -31,7 +32,7 @@ def add_args(parser): parser.add_argument("--batch-size", type=int, default=20) parser.add_argument("--train-ratio", type=float, default=0.7) parser.add_argument("--test-ratio", type=float, default=0.1) - + @classmethod def build_model_from_args(cls, args): return cls( @@ -65,11 +66,11 @@ def split_dataset(cls, dataset, args): return train_loader, valid_loader, test_loader def __init__(self, num_feature, num_classes, hidden_size, avg_deg, - layer=4, pre_layers=1,towers=5, post_layers=1, + layer=4, pre_layers=1, towers=5, post_layers=1, edge_dim=None, aggregators=['mean', 'min', 'max', 'std'], - scalers = ['identity', 'amplification', 'attenuation'], - divide_input=False ): + scalers=['identity', 'amplification', 'attenuation'], + divide_input=False): super(PNA, self).__init__() self.hidden_size = hidden_size self.edge_dim = edge_dim @@ -78,7 +79,7 @@ def __init__(self, num_feature, num_classes, hidden_size, avg_deg, emd_side = self.hidden_size // num_feature self.hidden_size = emd_side * num_feature self.node_emb = Embedding(num_feature, emd_side) - if not self.edge_dim is None: + if self.edge_dim is not None: self.edge_emb = Embedding(4, edge_dim) self.convs = ModuleList() @@ -91,11 +92,11 @@ def __init__(self, num_feature, num_classes, hidden_size, avg_deg, self.convs.append(conv) self.batch_norms.append(BatchNorm(self.hidden_size)) - self.mlp = Sequential(Linear(self.hidden_size, self.hidden_size//2), + self.mlp = Sequential(Linear(self.hidden_size, self.hidden_size // 2), ReLU(), - Linear(self.hidden_size//2, self.hidden_size//4), + Linear(self.hidden_size // 2, self.hidden_size // 4), ReLU(), - Linear(self.hidden_size//4, num_classes)) + Linear(self.hidden_size // 4, num_classes)) self.criterion = nn.CrossEntropyLoss() @@ -107,8 +108,8 @@ def forward(self, b): n = x.shape[0] x = self.node_emb(x.long()) - x = x.reshape([n,-1]) - + x = x.reshape([n, -1]) + if self.edge_dim is None: edge_attr = None else: @@ -117,8 +118,8 @@ def forward(self, b): for conv, batch_norm in zip(self.convs, self.batch_norms): x = F.relu(batch_norm(conv(x, edge_index, edge_attr))) x = global_add_pool(x, batch_h) - out = self.mlp(x) + out = self.mlp(x) if b.y is not None: return out, self.criterion(out, b.y) - return out, None \ No newline at end of file + return out, None diff --git a/examples/gnn_models/pna.py b/examples/gnn_models/pna.py index 72ec36c9..bbd1f92c 100644 --- a/examples/gnn_models/pna.py +++ b/examples/gnn_models/pna.py @@ -8,6 +8,7 @@ DATASET_REGISTRY = {} + def build_default_args_for_graph_classification(dataset): cpu = not torch.cuda.is_available() args = { @@ -42,16 +43,19 @@ def build_default_args_for_graph_classification(dataset): } return build_args_from_dict(args) + def register_func(name): def register_func_name(func): DATASET_REGISTRY[name] = func return func return register_func_name + @register_func("proteins") def proteins_config(args): return args + def run(dataset_name): args = build_default_args_for_graph_classification(dataset_name) args = DATASET_REGISTRY[dataset_name](args) @@ -64,9 +68,10 @@ def run(dataset_name): results.append(result) return results + if __name__ == "__main__": datasets = ["proteins"] results = [] for x in datasets: results += run(x) - print_result(results, datasets, "pyg_pna") \ No newline at end of file + print_result(results, datasets, "pyg_pna") diff --git a/tests/tasks/test_graph_classification.py b/tests/tasks/test_graph_classification.py index a417b465..a0fba30d 100644 --- a/tests/tasks/test_graph_classification.py +++ b/tests/tasks/test_graph_classification.py @@ -100,6 +100,7 @@ def add_sagpool_args(args): args.pooling_layer_type = "gcnconv" return args + def add_pna_args(args): args.hidden_size = 60 args.avg_deg = 3 @@ -107,15 +108,16 @@ def add_pna_args(args): args.train_ratio = 0.7 args.test_ratio = 0.1 args.pooling_ratio = 0.5 - args.pre_layers = 1 - args.towers = 1 - args.post_layers = 1 - args.edge_dim = None - args.aggregators = ['mean'] + args.pre_layers = 1 + args.towers = 1 + args.post_layers = 1 + args.edge_dim = None + args.aggregators = ['mean'] args.scalers = ['identity'] - args.divide_input = False + args.divide_input = False return args + def test_gin_mutag(): args = get_default_args() args = add_gin_args(args) @@ -273,6 +275,7 @@ def test_sagpool_proteins(): ret = task.train() assert ret["Acc"] > 0 + def test_pna_proteins(): args = get_default_args() args = add_pna_args(args) From 12fcd2b30ab68d2c2cc330e3744191722993934c Mon Sep 17 00:00:00 2001 From: CKHui Date: Fri, 18 Dec 2020 20:51:22 +0800 Subject: [PATCH 7/7] refactor --- cogdl/models/nn/pyg_pna.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cogdl/models/nn/pyg_pna.py b/cogdl/models/nn/pyg_pna.py index 61923d40..36b01c7e 100644 --- a/cogdl/models/nn/pyg_pna.py +++ b/cogdl/models/nn/pyg_pna.py @@ -19,7 +19,6 @@ def add_args(parser): parser.add_argument("--num_classes", type=int) parser.add_argument("--hidden_size", type=int, default=60) parser.add_argument("--avg_deg", type=int, default=1) - parser.add_argument("--layer", type=int, default=4) parser.add_argument("--pre_layers", type=int, default=1) parser.add_argument("--towers", type=int, default=5) @@ -28,7 +27,6 @@ def add_args(parser): parser.add_argument("--aggregators", type=str, nargs="+", default=['mean', 'min', 'max', 'std']) parser.add_argument("--scalers", type=str, nargs="+", default=['identity', 'amplification', 'attenuation']) parser.add_argument("--divide_input", action='store_true', default=False) - parser.add_argument("--batch-size", type=int, default=20) parser.add_argument("--train-ratio", type=float, default=0.7) parser.add_argument("--test-ratio", type=float, default=0.1) @@ -62,7 +60,6 @@ def split_dataset(cls, dataset, args): valid_loader = DataLoader(dataset[train_size:-test_size], batch_size=bs) else: valid_loader = test_loader - return train_loader, valid_loader, test_loader def __init__(self, num_feature, num_classes, hidden_size, avg_deg,