Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DropEdge GCN #89

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions cogdl/models/nn/dropedge_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .. import BaseModel, register_model
from cogdl.utils import add_remaining_self_loops, spmm, spmm_adj, add_self_loops


class GraphConvolution(nn.Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""

def __init__(self, in_features, out_features, bias=True, withloop=False, withbn=False):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
if withloop:
self.self_weight = Parameter(torch.FloatTensor(in_features, out_features))
else:
self.register_parameter("self_weight", None)
if withbn:
self.bn = torch.nn.BatchNorm1d(out_features)
else:
self.register_parameter("bn", None)
if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter("bias", None)
self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.weight.size(1))
self.weight.data.normal_(-stdv, stdv)
if self.self_weight is not None:
stdv = 1. / math.sqrt(self.self_weight.size(1))
self.self_weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.normal_(-stdv, stdv)

def forward(self, input, edge_index, edge_attr=None):
if edge_attr is None:
edge_attr = torch.ones(edge_index.shape[1]).float().to(input.device)
adj = torch.sparse_coo_tensor(
edge_index,
edge_attr,
(input.shape[0], input.shape[0]),
).to(input.device)
support = torch.mm(input, self.weight)
output = torch.sparse.mm(adj, support)
# Self-loop
if self.self_weight is not None:
output = output + torch.mm(input, self.self_weight)
if self.bias is not None:
output = output + self.bias
if self.bn is not None:
output = self.bn(output)
return output

def __repr__(self):
return (
self.__class__.__name__
+ " ("
+ str(self.in_features)
+ " -> "
+ str(self.out_features)
+ ")"
)


def drop_edge(adj, adj_values, dropedge_rate, train):
if train:
import numpy as np
n_edge = adj.shape[1]
_remaining_edges = np.arange(n_edge)
np.random.shuffle(_remaining_edges)
_remaining_edges = np.sort(_remaining_edges[:int((1 - dropedge_rate) * n_edge)])
new_adj = adj[:, _remaining_edges]
new_adj_values = adj_values[_remaining_edges]
else:
new_adj = adj
new_adj_values = adj_values
return new_adj, new_adj_values


def bingge_norm_adj(adj, adj_values, num_nodes):
adj, adj_values = add_self_loops(adj, adj_values, 1, num_nodes)
deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
deg_sqrt = deg.pow(-1 / 2)
adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
row, col = adj[0], adj[1]
mask = row != col
adj_values[row[mask]] += 1
return adj, adj_values


def aug_norm_adj(adj, adj_values, num_nodes):
adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, num_nodes)
deg = spmm(adj, adj_values, torch.ones(num_nodes, 1).to(adj.device)).squeeze()
deg_sqrt = deg.pow(-1 / 2)
adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
return adj, adj_values


def get_normalizer(normalization):
normalizer_dict = dict(AugNorm=aug_norm_adj,
BinggeNorm=bingge_norm_adj)
if not normalization in normalizer_dict:
raise NotImplementedError
return normalizer_dict[normalization]


@register_model("dropedge_gcn")
class DropEdgeGCN(BaseModel):
r"""The GCN model from the `"Semi-Supervised Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper

Args:
num_features (int) : Number of input features.
num_classes (int) : Number of classes.
hidden_size (int) : The dimension of node representation.
dropout (float) : Dropout rate for model training.
"""

@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument("--num-features", type=int)
parser.add_argument("--num-classes", type=int)
parser.add_argument("--hidden-size", type=int, default=64)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--num-layers", type=int, default=0)
parser.add_argument("--withloop", action="store_true", default=False)
parser.add_argument("--withbn", action="store_true", default=False)
parser.add_argument("--dropedge", type=float, default=0.0)
parser.add_argument("--normalization", type=str, default="AugNorm")
# fmt: on

@classmethod
def build_model_from_args(cls, args):
return cls(args.num_features, args.hidden_size, args.num_classes, args.dropout, args.num_layers, args.withloop,
args.withbn, args.dropedge, args.normalization)

def __init__(self, nfeat, nhid, nclass, dropout, num_layers, withloop, withbn, dropedge, normalization):
super(DropEdgeGCN, self).__init__()

# self.gc1 = GraphConvolution(nfeat, nhid)
# self.gc2 = GraphConvolution(nhid, nclass)
self.input_gc = GraphConvolution(nfeat, nhid, withloop=withloop, withbn=withbn)
self.gcs = nn.ModuleList()
for _ in range(num_layers):
self.gcs.append(GraphConvolution(nhid, nhid, withloop=withloop, withbn=withbn))
self.output_gc = GraphConvolution(nhid, nclass, withloop=withloop, withbn=withbn)
self.dropout = dropout
self.dropedge = dropedge # 0 correspond to no dropedge
self.normalization = normalization
# self.nonlinear = nn.SELU()

def forward(self, x, adj):
device = x.device
adj_values = torch.ones(adj.shape[1]).to(device)
original_adj = adj # (2, 9104)
original_adj_values = adj_values # (9104)
original_x_shape = x.shape[0]

adj, adj_values = drop_edge(original_adj, original_adj_values, self.dropedge, self.training)
# Add support to different normalizers
adj, adj_values = get_normalizer(self.normalization)(adj, adj_values, original_x_shape)
'''
adj, adj_values = add_remaining_self_loops(adj, adj_values, 1, original_x_shape)
# print(adj.shape, adj_values.shape) # dropedge=0., (2, 12431), (12431)
deg = spmm(adj, adj_values, torch.ones(original_x_shape, 1).to(device)).squeeze()
# print(max(adj[0]), max(adj[1])) # 3326, 3326
# print(deg.shape) # Tensor, (3327)
# print(max(deg), min(deg)) # 100, 1
# print(adj_values.shape, max(adj_values), min(adj_values)) # (12431), every element is 1
deg_sqrt = deg.pow(-1 / 2)
adj_values = deg_sqrt[adj[1]] * adj_values * deg_sqrt[adj[0]]
# print(adj_values.shape, max(adj_values), min(adj_values)) # (12431), max: 1, min: 0.01
'''

# Input layer
x = self.input_gc(x, adj, adj_values)
x = F.relu(x)
x = F.dropout(x, self.dropout, training=self.training)
# Mid layers
for idx, gc_layer in enumerate(self.gcs):
x = gc_layer(x, adj, adj_values)
x = F.relu(x)
x = F.dropout(x, self.dropout, training=self.training)
# Output layer
x = self.output_gc(x, adj, adj_values)
return F.log_softmax(x, dim=-1)

def loss(self, data):
return F.nll_loss(
self.forward(data.x, data.edge_index)[data.train_mask],
data.y[data.train_mask],
)

def predict(self, data):
return self.forward(data.x, data.edge_index)
100 changes: 100 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
name: cogdl
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _pytorch_select=0.2=gpu_0
- blas=1.0=mkl
- ca-certificates=2020.10.14=0
- certifi=2020.6.20=pyhd3eb1b0_3
- cffi=1.14.3=py36h261ae71_2
- cudatoolkit=10.0.130=0
- cudnn=7.6.5=cuda10.0_0
- intel-openmp=2020.2=254
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- mkl=2020.2=256
- mkl-service=2.3.0=py36he904b0f_0
- mkl_fft=1.2.0=py36h23d657b_0
- mkl_random=1.1.1=py36h0573a6f_0
- ncurses=6.2=he6710b0_1
- ninja=1.10.1=py36hfd86e86_0
- numpy=1.19.2=py36h54aff64_0
- numpy-base=1.19.2=py36hfa32c7d_0
- openssl=1.1.1h=h7b6447c_0
- pip=20.2.4=py36h06a4308_0
- pycparser=2.20=py_2
- python=3.6.12=hcff3b4d_2
- pytorch=1.3.1=cuda100py36h53c1284_0
- readline=8.0=h7b6447c_0
- setuptools=50.3.1=py36h06a4308_1
- six=1.15.0=py36h06a4308_0
- sqlite=3.33.0=h62c20be_0
- tk=8.6.10=hbc83047_0
- wheel=0.35.1=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- pip:
- alembic==1.4.3
- ase==3.20.1
- attrs==20.3.0
- cached-property==1.5.2
- chardet==3.0.4
- cliff==3.5.0
- cmaes==0.7.0
- cmd2==1.4.0
- colorama==0.4.4
- colorlog==4.6.2
- cycler==0.10.0
- decorator==4.4.2
- gensim==3.8.3
- googledrivedownloader==0.4
- grave==0.0.3
- h5py==3.1.0
- idna==2.10
- importlib-metadata==3.0.0
- isodate==0.6.0
- jinja2==2.11.2
- joblib==0.17.0
- kiwisolver==1.3.1
- littleutils==0.2.2
- llvmlite==0.34.0
- mako==1.1.3
- markupsafe==1.1.1
- matplotlib==3.3.3
- networkx==2.5
- numba==0.51.2
- ogb==1.2.3
- optuna==2.3.0
- outdated==0.2.0
- packaging==20.4
- pandas==1.1.4
- pbr==5.5.1
- pillow==8.0.1
- prettytable==0.7.2
- pyparsing==2.4.7
- pyperclip==1.8.1
- python-dateutil==2.8.1
- python-editor==1.0.4
- pytz==2020.4
- pyyaml==5.3.1
- rdflib==5.0.0
- requests==2.25.0
- scikit-learn==0.23.2
- scipy==1.5.4
- smart-open==3.0.0
- sqlalchemy==1.3.20
- stevedore==3.2.2
- tabulate==0.8.7
- texttable==1.6.3
- threadpoolctl==2.1.0
- torch-geometric==1.6.1
- tqdm==4.53.0
- urllib3==1.26.2
- wcwidth==0.2.5
- zipp==3.4.0
prefix: /home/yunfei/miniconda3/envs/cogdl

1 change: 1 addition & 0 deletions match.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
node_classification:
- model:
- gcn
- dropedge_gcn
- gat
- drgat
- grand
Expand Down
21 changes: 21 additions & 0 deletions tests/tasks/test_node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ def test_gcn_cora():
args.task = "node_classification"
args.dataset = "cora"
args.model = "gcn"
args.num_layers = 0
args.withloop = False
args.withbn = False
args.dropedge = 0.0
args.normalization = "AugNorm"
task = build_task(args)
ret = task.train()
assert 0 <= ret["Acc"] <= 1


def test_dropedge_gcn_cora():
args = get_default_args()
args.task = "node_classification"
args.dataset = "cora"
args.model = "dropedge_gcn"
args.num_layers = 0
args.withloop = False
args.withbn = False
args.dropedge = 0.0
args.normalization = "AugNorm"
task = build_task(args)
ret = task.train()
assert 0 <= ret["Acc"] <= 1
Expand Down Expand Up @@ -338,6 +358,7 @@ def test_gpt_gnn_cora():

if __name__ == "__main__":
test_gcn_cora()
test_dropedge_gcn_cora()
test_gat_cora()
test_mlp_pubmed()
test_mixhop_citeseer()
Expand Down