From a80ffad465bc660e3326b891ab4af964d9fc1c7a Mon Sep 17 00:00:00 2001
From: Yunfei Li
Date: Sat, 12 Dec 2020 18:30:18 +0800
Subject: [PATCH 1/4] Dropedge works

---
 cogdl/models/nn/gcn.py | 85 ++++++++++++++++++++++++++++++++----------
 1 file changed, 65 insertions(+), 20 deletions(-)

diff --git a/cogdl/models/nn/gcn.py b/cogdl/models/nn/gcn.py
index 2ca5abc5..cc6af790 100644
--- a/cogdl/models/nn/gcn.py
+++ b/cogdl/models/nn/gcn.py
@@ -40,7 +40,7 @@ def forward(self, input, edge_index, edge_attr=None):
             (input.shape[0], input.shape[0]),
         ).to(input.device)
         support = torch.mm(input, self.weight)
-        output = torch.spmm(adj, support)
+        output = torch.sparse.mm(adj, support)
         if self.bias is not None:
             return output + self.bias
         else:
@@ -57,6 +57,21 @@ def __repr__(self):
         )
 
 
+def drop_edge(adj, adj_values, dropedge_rate, train):
+    if train:
+        import numpy as np
+        n_edge = adj.shape[1]
+        _remaining_edges = np.arange(n_edge)
+        np.random.shuffle(_remaining_edges)
+        _remaining_edges = np.sort(_remaining_edges[: int((1 - dropedge_rate) * n_edge)])
+        new_adj = adj[:, _remaining_edges]
+        new_adj_values = adj_values[_remaining_edges]
+    else:
+        new_adj = adj
+        new_adj_values = adj_values
+    return new_adj, new_adj_values
+
+
 @register_model("gcn")
 class TKipfGCN(BaseModel):
     r"""The GCN model from the `"Semi-Supervised Classification with Graph Convolutional Networks"
@@ -77,39 +92,69 @@ def add_args(parser):
         parser.add_argument("--num-classes", type=int)
         parser.add_argument("--hidden-size", type=int, default=64)
         parser.add_argument("--dropout", type=float, default=0.5)
+        parser.add_argument("--num-layers", type=int, default=2)
+        parser.add_argument("--dropedge", type=float, default=0.0)
         # fmt: on
 
     @classmethod
     def build_model_from_args(cls, args):
-        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout)
+        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.dropedge)
 
-    def __init__(self, nfeat, nhid, nclass, dropout):
+    def __init__(self, nfeat, nhid, nclass, dropout, num_layers, dropedge):
         super(TKipfGCN, self).__init__()
 
-        self.gc1 = GraphConvolution(nfeat, nhid)
-        self.gc2 = GraphConvolution(nhid, nclass)
+        # self.gc1 = GraphConvolution(nfeat, nhid)
+        # self.gc2 = GraphConvolution(nhid, nclass)
+        self.gcs = nn.ModuleList()
+        self.gcs.append(GraphConvolution(nfeat, nhid))
+        for _ in range(num_layers - 2):
+            self.gcs.append(GraphConvolution(nhid, nhid))
+        self.gcs.append(GraphConvolution(nhid, nclass))
         self.dropout = dropout
+        self.dropedge = dropedge  # 0 corresponds to no dropedge
         # self.nonlinear = nn.SELU()
 
     def forward(self, x, adj):
         device = x.device
         adj_values = torch.ones(adj.shape[1]).to(device)
-        adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
-        deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
-        deg_sqrt = deg.pow(-1/2)
-        adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-
-        x = F.dropout(x, self.dropout, training=self.training)
-        x = F.relu(self.gc1(x, adj, adj_values))
-        # h1 = x
-        x = F.dropout(x, self.dropout, training=self.training)
-        x = self.gc2(x, adj, adj_values)
-
-        # x = F.relu(x)
-        # x = torch.sigmoid(x)
-        # return x
-        # h2 = x
+        original_adj = adj
+        original_adj_values = adj_values
+        original_x_shape = x.shape[0]
+        for idx, gc_layer in enumerate(self.gcs):
+            adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
+            adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
+            deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
+            deg_sqrt = deg.pow(-1 / 2)
+            adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = gc_layer(x, adj, adj_values)
+            if idx != len(self.gcs) - 1:
+                x = F.relu(x)
         return F.log_softmax(x, dim=-1)
+
+        # # TODO: implement drop edge here.
+        # adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
+        # adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
+        # deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
+        # deg_sqrt = deg.pow(-1/2)
+        # adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+        #
+        # x = F.dropout(x, self.dropout, training=self.training)
+        # x = F.relu(self.gc1(x, adj, adj_values))
+        # # h1 = x
+        # x = F.dropout(x, self.dropout, training=self.training)
+        # adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
+        # adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
+        # deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
+        # deg_sqrt = deg.pow(-1 / 2)
+        # adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+        # x = self.gc2(x, adj, adj_values)
+        #
+        # # x = F.relu(x)
+        # # x = torch.sigmoid(x)
+        # # return x
+        # # h2 = x
+        # return F.log_softmax(x, dim=-1)
 
     def loss(self, data):
         return F.nll_loss(

From af9bc14da9373ece497e49b353dbebb2ec91e9b6 Mon Sep 17 00:00:00 2001
From: Yunfei Li
Date: Mon, 14 Dec 2020 17:29:59 +0800
Subject: [PATCH 2/4] Add support for some tricks

---
 cogdl/models/nn/gcn.py | 131 +++++++++++++++++++++++++++--------------
 environment.yml        | 100 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 186 insertions(+), 45 deletions(-)
 create mode 100644 environment.yml

diff --git a/cogdl/models/nn/gcn.py b/cogdl/models/nn/gcn.py
index cc6af790..369bab9a 100644
--- a/cogdl/models/nn/gcn.py
+++ b/cogdl/models/nn/gcn.py
@@ -6,7 +6,7 @@
 from torch.nn.parameter import Parameter
 
 from .. import BaseModel, register_model
-from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj
+from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj, add_self_loops
 
 
 class GraphConvolution(nn.Module):
@@ -14,11 +14,19 @@ class GraphConvolution(nn.Module):
     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
     """
 
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, bias=True, withloop=False, withbn=False):
         super(GraphConvolution, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if withloop:
+            self.self_weight = Parameter(torch.FloatTensor(in_features, out_features))
+        else:
+            self.register_parameter("self_weight", None)
+        if withbn:
+            self.bn = torch.nn.BatchNorm1d(out_features)
+        else:
+            self.register_parameter("bn", None)
         if bias:
             self.bias = Parameter(torch.FloatTensor(out_features))
         else:
@@ -28,6 +36,9 @@ def __init__(self, in_features, out_features, bias=True, withloop=False, withbn
     def reset_parameters(self):
         stdv = 1.0 / math.sqrt(self.weight.size(1))
         self.weight.data.normal_(-stdv, stdv)
+        if self.self_weight is not None:
+            stdv = 1.0 / math.sqrt(self.self_weight.size(1))
+            self.self_weight.data.uniform_(-stdv, stdv)
         if self.bias is not None:
             self.bias.data.normal_(-stdv, stdv)
 
@@ -41,10 +52,14 @@ def forward(self, input, edge_index, edge_attr=None):
         ).to(input.device)
         support = torch.mm(input, self.weight)
         output = torch.sparse.mm(adj, support)
+        # Self-loop
+        if self.self_weight is not None:
+            output = output + torch.mm(input, self.self_weight)
         if self.bias is not None:
-            return output + self.bias
-        else:
-            return output
+            output = output + self.bias
+        if self.bn is not None:
+            output = self.bn(output)
+        return output
 
     def __repr__(self):
         return (
@@ -72,6 +87,33 @@ def drop_edge(adj, adj_values, dropedge_rate, train):
     return new_adj, new_adj_values
 
 
+def bingge_norm_adj(adj, adj_values, num_nodes):
+    adj, adj_values = add_self_loops(adj, adj_values, 1, num_nodes)
+    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
+    deg_sqrt = deg.pow(-1 / 2)
+    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+    row, col = adj[0], adj[1]
+    mask = row == col  # self-loop entries
+    adj_values[mask] += 1  # normalize(A + I) + I
+    return adj, adj_values
+
+
+def aug_norm_adj(adj, adj_values, num_nodes):
+    adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, num_nodes)
+    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
+    deg_sqrt = deg.pow(-1 / 2)
+    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+    return adj, adj_values
+
+
+def get_normalizer(normalization):
+    normalizer_dict = dict(AugNorm=aug_norm_adj,
+                           BinggeNorm=bingge_norm_adj)
+    if normalization not in normalizer_dict:
+        raise NotImplementedError("unknown normalization: " + normalization)
+    return normalizer_dict[normalization]
+
+
 @register_model("gcn")
 class TKipfGCN(BaseModel):
     r"""The GCN model from the `"Semi-Supervised Classification with Graph Convolutional Networks"
@@ -93,68 +135,67 @@ def add_args(parser):
         parser.add_argument("--hidden-size", type=int, default=64)
         parser.add_argument("--dropout", type=float, default=0.5)
         parser.add_argument("--num-layers", type=int, default=2)
+        parser.add_argument("--withloop", action="store_true", default=False)
+        parser.add_argument("--withbn", action="store_true", default=False)
         parser.add_argument("--dropedge", type=float, default=0.0)
+        parser.add_argument("--normalization", type=str, default="AugNorm")
         # fmt: on
 
     @classmethod
     def build_model_from_args(cls, args):
-        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.dropedge)
+        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.withloop,
+                   args.withbn, args.dropedge, args.normalization)
 
-    def __init__(self, nfeat, nhid, nclass, dropout, num_layers, dropedge):
+    def __init__(self, nfeat, nhid, nclass, dropout, num_layers, withloop, withbn, dropedge, normalization):
         super(TKipfGCN, self).__init__()
 
         # self.gc1 = GraphConvolution(nfeat, nhid)
         # self.gc2 = GraphConvolution(nhid, nclass)
+        self.input_gc = GraphConvolution(nfeat, nhid, withloop=withloop, withbn=withbn)
         self.gcs = nn.ModuleList()
-        self.gcs.append(GraphConvolution(nfeat, nhid))
-        for _ in range(num_layers - 2):
-            self.gcs.append(GraphConvolution(nhid, nhid))
-        self.gcs.append(GraphConvolution(nhid, nclass))
+        for _ in range(num_layers):
+            self.gcs.append(GraphConvolution(nhid, nhid, withloop=withloop, withbn=withbn))
+        self.output_gc = GraphConvolution(nhid, nclass, withloop=withloop, withbn=withbn)
         self.dropout = dropout
         self.dropedge = dropedge  # 0 corresponds to no dropedge
+        self.normalization = normalization
         # self.nonlinear = nn.SELU()
 
     def forward(self, x, adj):
         device = x.device
         adj_values = torch.ones(adj.shape[1]).to(device)
-        original_adj = adj
-        original_adj_values = adj_values
+        original_adj = adj  # (2, 9104)
+        original_adj_values = adj_values  # (9104)
         original_x_shape = x.shape[0]
+
+        adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
+        # Support different adjacency normalizers
+        adj, adj_values = get_normalizer(self.normalization)(adj, adj_values, original_x_shape)
+        '''
+        adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
+        # print(adj.shape, adj_values.shape)  # dropedge=0., (2, 12431), (12431)
+        deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
+        # print(max(adj[0]), max(adj[1]))  # 3326, 3326
+        # print(deg.shape)  # Tensor, (3327)
+        # print(max(deg), min(deg))  # 100, 1
+        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), every element is 1
+        deg_sqrt = deg.pow(-1 / 2)
+        adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), max: 1, min: 0.01
+        '''
+
+        # Input layer
+        x = self.input_gc(x, adj, adj_values)
+        x = F.relu(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        # Mid layers
         for idx, gc_layer in enumerate(self.gcs):
-            adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
-            adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
-            deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
-            deg_sqrt = deg.pow(-1 / 2)
-            adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-            x = F.dropout(x, self.dropout, training=self.training)
             x = gc_layer(x, adj, adj_values)
-            if idx != len(self.gcs) - 1:
-                x = F.relu(x)
+            x = F.relu(x)
+            x = F.dropout(x, self.dropout, training=self.training)
+        # Output layer
+        x = self.output_gc(x, adj, adj_values)
         return F.log_softmax(x, dim=-1)
-
-        # # TODO: implement drop edge here.
-        # adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
-        # adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
-        # deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
-        # deg_sqrt = deg.pow(-1/2)
-        # adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-        #
-        # x = F.dropout(x, self.dropout, training=self.training)
-        # x = F.relu(self.gc1(x, adj, adj_values))
-        # # h1 = x
-        # x = F.dropout(x, self.dropout, training=self.training)
-        # adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
-        # adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
-        # deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
-        # deg_sqrt = deg.pow(-1 / 2)
-        # adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-        # x = self.gc2(x, adj, adj_values)
-        #
-        # # x = F.relu(x)
-        # # x = torch.sigmoid(x)
-        # # return x
-        # # h2 = x
-        # return F.log_softmax(x, dim=-1)
 
     def loss(self, data):
         return F.nll_loss(
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 00000000..90fd8ba3
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,100 @@
+name: cogdl
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _pytorch_select=0.2=gpu_0
+  - blas=1.0=mkl
+  - ca-certificates=2020.10.14=0
+  - certifi=2020.6.20=pyhd3eb1b0_3
+  - cffi=1.14.3=py36h261ae71_2
+  - cudatoolkit=10.0.130=0
+  - cudnn=7.6.5=cuda10.0_0
+  - intel-openmp=2020.2=254
+  - ld_impl_linux-64=2.33.1=h53a641e_7
+  - libedit=3.1.20191231=h14c3975_1
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - mkl=2020.2=256
+  - mkl-service=2.3.0=py36he904b0f_0
+  - mkl_fft=1.2.0=py36h23d657b_0
+  - mkl_random=1.1.1=py36h0573a6f_0
+  - ncurses=6.2=he6710b0_1
+  - ninja=1.10.1=py36hfd86e86_0
+  - numpy=1.19.2=py36h54aff64_0
+  - numpy-base=1.19.2=py36hfa32c7d_0
+  - openssl=1.1.1h=h7b6447c_0
+  - pip=20.2.4=py36h06a4308_0
+  - pycparser=2.20=py_2
+  - python=3.6.12=hcff3b4d_2
+  - pytorch=1.3.1=cuda100py36h53c1284_0
+  - readline=8.0=h7b6447c_0
+  - setuptools=50.3.1=py36h06a4308_1
+  - six=1.15.0=py36h06a4308_0
+  - sqlite=3.33.0=h62c20be_0
+  - tk=8.6.10=hbc83047_0
+  - wheel=0.35.1=pyhd3eb1b0_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - pip:
+    - alembic==1.4.3
+    - ase==3.20.1
+    - attrs==20.3.0
+    - cached-property==1.5.2
+    - chardet==3.0.4
+    - cliff==3.5.0
+    - cmaes==0.7.0
+    - cmd2==1.4.0
+    - colorama==0.4.4
+    - colorlog==4.6.2
+    - cycler==0.10.0
+    - decorator==4.4.2
+    - gensim==3.8.3
+    - googledrivedownloader==0.4
+    - grave==0.0.3
+    - h5py==3.1.0
+    - idna==2.10
+    - importlib-metadata==3.0.0
+    - isodate==0.6.0
+    - jinja2==2.11.2
+    - joblib==0.17.0
+    - kiwisolver==1.3.1
+    - littleutils==0.2.2
+    - llvmlite==0.34.0
+    - mako==1.1.3
+    - markupsafe==1.1.1
+    - matplotlib==3.3.3
+    - networkx==2.5
+    - numba==0.51.2
+    - ogb==1.2.3
+    - optuna==2.3.0
+    - outdated==0.2.0
+    - packaging==20.4
+    - pandas==1.1.4
+    - pbr==5.5.1
+    - pillow==8.0.1
+    - prettytable==0.7.2
+    - pyparsing==2.4.7
+    - pyperclip==1.8.1
+    - python-dateutil==2.8.1
+    - python-editor==1.0.4
+    - pytz==2020.4
+    - pyyaml==5.3.1
+    - rdflib==5.0.0
+    - requests==2.25.0
+    - scikit-learn==0.23.2
+    - scipy==1.5.4
+    - smart-open==3.0.0
+    - sqlalchemy==1.3.20
+    - stevedore==3.2.2
+    - tabulate==0.8.7
+    - texttable==1.6.3
+    - threadpoolctl==2.1.0
+    - torch-geometric==1.6.1
+    - tqdm==4.53.0
+    - urllib3==1.26.2
+    - wcwidth==0.2.5
+    - zipp==3.4.0
+prefix: /home/yunfei/miniconda3/envs/cogdl
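Note: drop_edge, added in PATCH 1/4 and reused unchanged afterwards, keeps a uniformly random (1 - dropedge_rate) fraction of edges during training and returns the graph untouched at evaluation time. A minimal self-contained sketch of that behavior on a toy graph (the toy edge_index below is illustrative, not taken from the patches):

    import numpy as np
    import torch

    def drop_edge(adj, adj_values, dropedge_rate, train):
        # Keep a random (1 - dropedge_rate) fraction of edges while training;
        # at eval time the full graph is used.
        if not train:
            return adj, adj_values
        n_edge = adj.shape[1]
        keep = np.sort(np.random.permutation(n_edge)[: int((1 - dropedge_rate) * n_edge)])
        return adj[:, keep], adj_values[keep]

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # 4 directed edges over 3 nodes
    edge_values = torch.ones(edge_index.shape[1])
    new_index, new_values = drop_edge(edge_index, edge_values, 0.5, train=True)
    print(new_index.shape)  # torch.Size([2, 2]) -- half of the edges survive

Because a fresh subgraph is sampled on every call (once per layer in PATCH 1/4, once per forward pass from PATCH 2/4 on), the symmetric normalization has to be recomputed inside forward on the sampled edges rather than precomputed once on the full graph.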
From 1ca7340764db3a0337e29ba928eb979112f2a3c6 Mon Sep 17 00:00:00 2001
From: Yunfei Li
Date: Mon, 14 Dec 2020 19:26:24 +0800
Subject: [PATCH 3/4] Test dropedge gcn

---
 tests/tasks/test_node_classification.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py
index bd384883..a6d6c8aa 100644
--- a/tests/tasks/test_node_classification.py
+++ b/tests/tasks/test_node_classification.py
@@ -27,6 +27,11 @@ def test_gcn_cora():
     args.task = "node_classification"
     args.dataset = "cora"
     args.model = "gcn"
+    args.num_layers = 0
+    args.withloop = False
+    args.withbn = False
+    args.dropedge = 0.0
+    args.normalization = "AugNorm"
     task = build_task(args)
     ret = task.train()
     assert 0 <= ret["Acc"] <= 1

From c8643e51d8aef811f47537944f13b7c49ab5443b Mon Sep 17 00:00:00 2001
From: Yunfei Li
Date: Wed, 16 Dec 2020 16:08:38 +0800
Subject: [PATCH 4/4] Refactor

---
 cogdl/models/nn/dropedge_gcn.py         | 207 ++++++++++++++++++++++++
 cogdl/models/nn/gcn.py                  | 130 +++------------
 match.yml                               |   1 +
 tests/tasks/test_node_classification.py |  16 ++
 4 files changed, 246 insertions(+), 108 deletions(-)
 create mode 100644 cogdl/models/nn/dropedge_gcn.py

diff --git a/cogdl/models/nn/dropedge_gcn.py b/cogdl/models/nn/dropedge_gcn.py
new file mode 100644
index 00000000..42724290
--- /dev/null
+++ b/cogdl/models/nn/dropedge_gcn.py
@@ -0,0 +1,207 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from .. import BaseModel, register_model
+from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj, add_self_loops
+
+
+class GraphConvolution(nn.Module):
+    """
+    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
+    """
+
+    def __init__(self, in_features, out_features, bias=True, withloop=False, withbn=False):
+        super(GraphConvolution, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if withloop:
+            self.self_weight = Parameter(torch.FloatTensor(in_features, out_features))
+        else:
+            self.register_parameter("self_weight", None)
+        if withbn:
+            self.bn = torch.nn.BatchNorm1d(out_features)
+        else:
+            self.register_parameter("bn", None)
+        if bias:
+            self.bias = Parameter(torch.FloatTensor(out_features))
+        else:
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1.0 / math.sqrt(self.weight.size(1))
+        self.weight.data.normal_(-stdv, stdv)
+        if self.self_weight is not None:
+            stdv = 1.0 / math.sqrt(self.self_weight.size(1))
+            self.self_weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.normal_(-stdv, stdv)
+
+    def forward(self, input, edge_index, edge_attr=None):
+        if edge_attr is None:
+            edge_attr = torch.ones(edge_index.shape[1]).float().to(input.device)
+        adj = torch.sparse_coo_tensor(
+            edge_index,
+            edge_attr,
+            (input.shape[0], input.shape[0]),
+        ).to(input.device)
+        support = torch.mm(input, self.weight)
+        output = torch.sparse.mm(adj, support)
+        # Self-loop
+        if self.self_weight is not None:
+            output = output + torch.mm(input, self.self_weight)
+        if self.bias is not None:
+            output = output + self.bias
+        if self.bn is not None:
+            output = self.bn(output)
+        return output
+
+    def __repr__(self):
+        return (
+            self.__class__.__name__
+            + " ("
+            + str(self.in_features)
+            + " -> "
+            + str(self.out_features)
+            + ")"
+        )
+
+
+def drop_edge(adj, adj_values, dropedge_rate, train):
+    if train:
+        import numpy as np
+        n_edge = adj.shape[1]
+        _remaining_edges = np.arange(n_edge)
+        np.random.shuffle(_remaining_edges)
+        _remaining_edges = np.sort(_remaining_edges[: int((1 - dropedge_rate) * n_edge)])
+        new_adj = adj[:, _remaining_edges]
+        new_adj_values = adj_values[_remaining_edges]
+    else:
+        new_adj = adj
+        new_adj_values = adj_values
+    return new_adj, new_adj_values
+
+
+def bingge_norm_adj(adj, adj_values, num_nodes):
+    adj, adj_values = add_self_loops(adj, adj_values, 1, num_nodes)
+    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
+    deg_sqrt = deg.pow(-1 / 2)
+    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+    row, col = adj[0], adj[1]
+    mask = row == col  # self-loop entries
+    adj_values[mask] += 1  # normalize(A + I) + I
+    return adj, adj_values
+
+
+def aug_norm_adj(adj, adj_values, num_nodes):
+    adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, num_nodes)
+    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
+    deg_sqrt = deg.pow(-1 / 2)
+    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+    return adj, adj_values
+
+
+def get_normalizer(normalization):
+    normalizer_dict = dict(AugNorm=aug_norm_adj,
+                           BinggeNorm=bingge_norm_adj)
+    if normalization not in normalizer_dict:
+        raise NotImplementedError("unknown normalization: " + normalization)
+    return normalizer_dict[normalization]
+
+
+@register_model("dropedge_gcn")
+class DropEdgeGCN(BaseModel):
+    r"""The GCN model from the `"Semi-Supervised Classification with Graph Convolutional Networks"
+    <https://arxiv.org/abs/1609.02907>`_ paper
+
+    Args:
+        num_features (int) : Number of input features.
+        num_classes (int) : Number of classes.
+        hidden_size (int) : The dimension of node representation.
+        dropout (float) : Dropout rate for model training.
+    """
+
+    @staticmethod
+    def add_args(parser):
+        """Add model-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument("--num-features", type=int)
+        parser.add_argument("--num-classes", type=int)
+        parser.add_argument("--hidden-size", type=int, default=64)
+        parser.add_argument("--dropout", type=float, default=0.5)
+        parser.add_argument("--num-layers", type=int, default=0)
+        parser.add_argument("--withloop", action="store_true", default=False)
+        parser.add_argument("--withbn", action="store_true", default=False)
+        parser.add_argument("--dropedge", type=float, default=0.0)
+        parser.add_argument("--normalization", type=str, default="AugNorm")
+        # fmt: on
+
+    @classmethod
+    def build_model_from_args(cls, args):
+        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.withloop,
+                   args.withbn, args.dropedge, args.normalization)
+
+    def __init__(self, nfeat, nhid, nclass, dropout, num_layers, withloop, withbn, dropedge, normalization):
+        super(DropEdgeGCN, self).__init__()
+
+        # self.gc1 = GraphConvolution(nfeat, nhid)
+        # self.gc2 = GraphConvolution(nhid, nclass)
+        self.input_gc = GraphConvolution(nfeat, nhid, withloop=withloop, withbn=withbn)
+        self.gcs = nn.ModuleList()
+        for _ in range(num_layers):
+            self.gcs.append(GraphConvolution(nhid, nhid, withloop=withloop, withbn=withbn))
+        self.output_gc = GraphConvolution(nhid, nclass, withloop=withloop, withbn=withbn)
+        self.dropout = dropout
+        self.dropedge = dropedge  # 0 corresponds to no dropedge
+        self.normalization = normalization
+        # self.nonlinear = nn.SELU()
+
+    def forward(self, x, adj):
+        device = x.device
+        adj_values = torch.ones(adj.shape[1]).to(device)
+        original_adj = adj  # (2, 9104)
+        original_adj_values = adj_values  # (9104)
+        original_x_shape = x.shape[0]
+
+        adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
+        # Support different adjacency normalizers
+        adj, adj_values = get_normalizer(self.normalization)(adj, adj_values, original_x_shape)
+        '''
+        adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
+        # print(adj.shape, adj_values.shape)  # dropedge=0., (2, 12431), (12431)
+        deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
+        # print(max(adj[0]), max(adj[1]))  # 3326, 3326
+        # print(deg.shape)  # Tensor, (3327)
+        # print(max(deg), min(deg))  # 100, 1
+        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), every element is 1
+        deg_sqrt = deg.pow(-1 / 2)
+        adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
+        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), max: 1, min: 0.01
+        '''
+
+        # Input layer
+        x = self.input_gc(x, adj, adj_values)
+        x = F.relu(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        # Mid layers
+        for idx, gc_layer in enumerate(self.gcs):
+            x = gc_layer(x, adj, adj_values)
+            x = F.relu(x)
+            x = F.dropout(x, self.dropout, training=self.training)
+        # Output layer
+        x = self.output_gc(x, adj, adj_values)
+        return F.log_softmax(x, dim=-1)
+
+    def loss(self, data):
+        return F.nll_loss(
+            self.forward(data.x, data.edge_index)[data.train_mask],
+            data.y[data.train_mask],
+        )
+
+    def predict(self, data):
+        return self.forward(data.x, data.edge_index)
diff --git a/cogdl/models/nn/gcn.py b/cogdl/models/nn/gcn.py
index 369bab9a..2ca5abc5 100644
--- a/cogdl/models/nn/gcn.py
+++ b/cogdl/models/nn/gcn.py
@@ -6,7 +6,7 @@
 from torch.nn.parameter import Parameter
 
 from .. import BaseModel, register_model
-from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj, add_self_loops
+from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj
 
 
 class GraphConvolution(nn.Module):
@@ -14,19 +14,11 @@ class GraphConvolution(nn.Module):
     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
     """
 
-    def __init__(self, in_features, out_features, bias=True, withloop=False, withbn=False):
+    def __init__(self, in_features, out_features, bias=True):
         super(GraphConvolution, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
-        if withloop:
-            self.self_weight = Parameter(torch.FloatTensor(in_features, out_features))
-        else:
-            self.register_parameter("self_weight", None)
-        if withbn:
-            self.bn = torch.nn.BatchNorm1d(out_features)
-        else:
-            self.register_parameter("bn", None)
         if bias:
             self.bias = Parameter(torch.FloatTensor(out_features))
         else:
@@ -36,9 +28,6 @@ def __init__(self, in_features, out_features, bias=True, withloop=False, withbn
     def reset_parameters(self):
         stdv = 1.0 / math.sqrt(self.weight.size(1))
         self.weight.data.normal_(-stdv, stdv)
-        if self.self_weight is not None:
-            stdv = 1.0 / math.sqrt(self.self_weight.size(1))
-            self.self_weight.data.uniform_(-stdv, stdv)
         if self.bias is not None:
             self.bias.data.normal_(-stdv, stdv)
 
@@ -51,15 +40,11 @@ def forward(self, input, edge_index, edge_attr=None):
             (input.shape[0], input.shape[0]),
         ).to(input.device)
         support = torch.mm(input, self.weight)
-        output = torch.sparse.mm(adj, support)
-        # Self-loop
-        if self.self_weight is not None:
-            output = output + torch.mm(input, self.self_weight)
+        output = torch.spmm(adj, support)
         if self.bias is not None:
-            output = output + self.bias
-        if self.bn is not None:
-            output = self.bn(output)
-        return output
+            return output + self.bias
+        else:
+            return output
 
     def __repr__(self):
         return (
@@ -72,48 +57,6 @@ def __repr__(self):
         )
 
 
-def drop_edge(adj, adj_values, dropedge_rate, train):
-    if train:
-        import numpy as np
-        n_edge = adj.shape[1]
-        _remaining_edges = np.arange(n_edge)
-        np.random.shuffle(_remaining_edges)
-        _remaining_edges = np.sort(_remaining_edges[: int((1 - dropedge_rate) * n_edge)])
-        new_adj = adj[:, _remaining_edges]
-        new_adj_values = adj_values[_remaining_edges]
-    else:
-        new_adj = adj
-        new_adj_values = adj_values
-    return new_adj, new_adj_values
-
-
-def bingge_norm_adj(adj, adj_values, num_nodes):
-    adj, adj_values = add_self_loops(adj, adj_values, 1, num_nodes)
-    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
-    deg_sqrt = deg.pow(-1 / 2)
-    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-    row, col = adj[0], adj[1]
-    mask = row == col  # self-loop entries
-    adj_values[mask] += 1  # normalize(A + I) + I
-    return adj, adj_values
-
-
-def aug_norm_adj(adj, adj_values, num_nodes):
-    adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, num_nodes)
-    deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
-    deg_sqrt = deg.pow(-1 / 2)
-    adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-    return adj, adj_values
-
-
-def get_normalizer(normalization):
-    normalizer_dict = dict(AugNorm=aug_norm_adj,
-                           BinggeNorm=bingge_norm_adj)
-    if normalization not in normalizer_dict:
-        raise NotImplementedError("unknown normalization: " + normalization)
-    return normalizer_dict[normalization]
-
-
 @register_model("gcn")
 class TKipfGCN(BaseModel):
     r"""The GCN model from the `"Semi-Supervised Classification with Graph Convolutional Networks"
@@ -134,67 +77,38 @@ def add_args(parser):
         parser.add_argument("--num-classes", type=int)
         parser.add_argument("--hidden-size", type=int, default=64)
         parser.add_argument("--dropout", type=float, default=0.5)
-        parser.add_argument("--num-layers", type=int, default=2)
-        parser.add_argument("--withloop", action="store_true", default=False)
-        parser.add_argument("--withbn", action="store_true", default=False)
-        parser.add_argument("--dropedge", type=float, default=0.0)
-        parser.add_argument("--normalization", type=str, default="AugNorm")
         # fmt: on
 
     @classmethod
     def build_model_from_args(cls, args):
-        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.withloop,
-                   args.withbn, args.dropedge, args.normalization)
+        return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout)
 
-    def __init__(self, nfeat, nhid, nclass, dropout, num_layers, withloop, withbn, dropedge, normalization):
+    def __init__(self, nfeat, nhid, nclass, dropout):
         super(TKipfGCN, self).__init__()
 
-        # self.gc1 = GraphConvolution(nfeat, nhid)
-        # self.gc2 = GraphConvolution(nhid, nclass)
-        self.input_gc = GraphConvolution(nfeat, nhid, withloop=withloop, withbn=withbn)
-        self.gcs = nn.ModuleList()
-        for _ in range(num_layers):
-            self.gcs.append(GraphConvolution(nhid, nhid, withloop=withloop, withbn=withbn))
-        self.output_gc = GraphConvolution(nhid, nclass, withloop=withloop, withbn=withbn)
+        self.gc1 = GraphConvolution(nfeat, nhid)
+        self.gc2 = GraphConvolution(nhid, nclass)
         self.dropout = dropout
-        self.dropedge = dropedge  # 0 corresponds to no dropedge
-        self.normalization = normalization
         # self.nonlinear = nn.SELU()
 
     def forward(self, x, adj):
         device = x.device
         adj_values = torch.ones(adj.shape[1]).to(device)
-        original_adj = adj  # (2, 9104)
-        original_adj_values = adj_values  # (9104)
-        original_x_shape = x.shape[0]
-
-        adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
-        # Support different adjacency normalizers
-        adj, adj_values = get_normalizer(self.normalization)(adj, adj_values, original_x_shape)
-        '''
-        adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
-        # print(adj.shape, adj_values.shape)  # dropedge=0., (2, 12431), (12431)
-        deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
-        # print(max(adj[0]), max(adj[1]))  # 3326, 3326
-        # print(deg.shape)  # Tensor, (3327)
-        # print(max(deg), min(deg))  # 100, 1
-        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), every element is 1
-        deg_sqrt = deg.pow(-1 / 2)
+        adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, x.shape[0])
+        deg = spmm(adj, adj_values, torch.ones(x.shape[0], 1).to(device)).squeeze()
+        deg_sqrt = deg.pow(-1/2)
         adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
-        # print(adj_values.shape, max(adj_values), min(adj_values))  # (12431), max: 1, min: 0.01
-        '''
 
-        # Input layer
-        x = self.input_gc(x, adj, adj_values)
-        x = F.relu(x)
         x = F.dropout(x, self.dropout, training=self.training)
-        # Mid layers
-        for idx, gc_layer in enumerate(self.gcs):
-            x = gc_layer(x, adj, adj_values)
-            x = F.relu(x)
-            x = F.dropout(x, self.dropout, training=self.training)
-        # Output layer
-        x = self.output_gc(x, adj, adj_values)
+        x = F.relu(self.gc1(x, adj, adj_values))
+        # h1 = x
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.gc2(x, adj, adj_values)
+
+        # x = F.relu(x)
+        # x = torch.sigmoid(x)
+        # return x
+        # h2 = x
         return F.log_softmax(x, dim=-1)
 
     def loss(self, data):
diff --git a/match.yml b/match.yml
index e224b6bb..06abd2ed 100644
--- a/match.yml
+++ b/match.yml
@@ -1,6 +1,7 @@
 node_classification:
   - model:
     - gcn
+    - dropedge_gcn
    - gat
     - drgat
     - grand
diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py
index a6d6c8aa..59828f9a 100644
--- a/tests/tasks/test_node_classification.py
+++ b/tests/tasks/test_node_classification.py
@@ -37,6 +37,21 @@ def test_gcn_cora():
     assert 0 <= ret["Acc"] <= 1
 
 
+def test_dropedge_gcn_cora():
+    args = get_default_args()
+    args.task = "node_classification"
+    args.dataset = "cora"
+    args.model = "dropedge_gcn"
+    args.num_layers = 0
+    args.withloop = False
+    args.withbn = False
+    args.dropedge = 0.0
+    args.normalization = "AugNorm"
+    task = build_task(args)
+    ret = task.train()
+    assert 0 <= ret["Acc"] <= 1
+
+
 def test_gat_cora():
     args = get_default_args()
     args.task = "node_classification"
@@ -343,6 +358,7 @@ def test_gpt_gnn_cora():
 
 if __name__ == "__main__":
     test_gcn_cora()
+    test_dropedge_gcn_cora()
     test_gat_cora()
     test_mlp_pubmed()
     test_mixhop_citeseer()
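Note: both normalizers in this series build the symmetrically normalized adjacency D^(-1/2) (A + I) D^(-1/2) that GCN expects; BinggeNorm additionally adds the identity back after normalizing, i.e. normalize(A + I) + I. A minimal sketch of the AugNorm computation using scatter_add_ in place of cogdl's spmm-based degree trick (the function name and toy graph below are illustrative, not part of the patches):

    import torch

    def aug_norm_values(edge_index, edge_values, num_nodes):
        # Degree of each node, counting the self-loops that are assumed to
        # have been added already (as add_remaining_self_loops does).
        deg = torch.zeros(num_nodes).scatter_add_(0, edge_index[0], edge_values)
        deg_inv_sqrt = deg.pow(-0.5)
        # Scale each edge value by deg^{-1/2} of both endpoints.
        return deg_inv_sqrt[edge_index[0]] * edge_values * deg_inv_sqrt[edge_index[1]]

    # Triangle graph with a self-loop on every node: each degree is 3.
    ei = torch.tensor([[0, 1, 1, 2, 2, 0, 0, 1, 2],
                       [1, 0, 2, 1, 0, 2, 0, 1, 2]])
    ev = torch.ones(ei.shape[1])
    print(aug_norm_values(ei, ev, 3))  # every entry equals 1/3

With dropedge=0.0 and num_layers=0 (the settings exercised by test_dropedge_gcn_cora), dropedge_gcn reduces to a plain two-layer GCN (input_gc plus output_gc), which is why the new test mirrors test_gcn_cora's accuracy assertion.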