Commit
[Bugfix] Fix sample adj (#301)
* Fix sample_adj & docs
cenyk1230 authored Nov 5, 2021
1 parent e9f9e05 commit 096f36b
Showing 29 changed files with 34 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,12 +1,15 @@
 [submodule "third_party/dgNN"]
+	ignore = dirty
 	path = third_party/dgNN
 	url = https://github.com/dgSPARSE/dgNN
 	branch = main
 [submodule "third_party/actnn"]
+	ignore = dirty
 	path = third_party/actnn
 	url = https://github.com/ucbrise/actnn
 	branch = main
 [submodule "third_party/fastmoe"]
+	ignore = dirty
 	path = third_party/fastmoe
 	url = https://github.com/laekov/fastmoe
 	branch = master
3 changes: 0 additions & 3 deletions cogdl/data/data.py
@@ -782,13 +782,10 @@ def sample_adj(self, batch, size=-1, replace=True):
         if sample_adj_c is not None:
             if not torch.is_tensor(batch):
                 batch = torch.tensor(batch, dtype=torch.long)
-            # (row_ptr, col_indices, nodes, edges) = sample_adj_c(self._adj.row_ptr, self._adj.col, batch, size, replace)
             (row_ptr, col_indices, nodes, edges) = sample_adj_c(
                 self._adj.row_indptr, self.col_indices, batch, size, replace
             )
         else:
-            if not (batch[1:] > batch[:-1]).all():
-                batch = batch.sort()[0]
             if torch.is_tensor(batch):
                 batch = batch.cpu().numpy()
             if self.__is_train__ and self._adj_train is not None:
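The fix reads the sampler's inputs from the CSR arrays row_indptr/col_indices instead of the stale row_ptr/col pair, and drops the batch-sorting step from the Python fallback. For intuition, here is a pure-PyTorch sketch of CSR neighbor sampling; sample_neighbors is a hypothetical stand-in for the compiled sample_adj_c kernel, not cogdl's actual implementation:

    import torch

    def sample_neighbors(row_indptr, col_indices, batch, size, replace=True):
        # Hypothetical pure-PyTorch stand-in for the compiled sample_adj_c kernel:
        # for each seed node, draw up to `size` neighbors from its CSR slice.
        sampled = []
        for node in batch.tolist():
            start, end = row_indptr[node].item(), row_indptr[node + 1].item()
            neighbors = col_indices[start:end]
            if size < 0 or neighbors.numel() == 0:
                chosen = neighbors  # take all neighbors (or none)
            elif replace:
                chosen = neighbors[torch.randint(neighbors.numel(), (size,))]
            else:
                chosen = neighbors[torch.randperm(neighbors.numel())[:size]]
            sampled.append(chosen)
        return sampled

    # Toy CSR graph: 0 -> {1, 2}, 1 -> {2}, 2 -> {0}
    row_indptr = torch.tensor([0, 2, 3, 4])
    col_indices = torch.tensor([1, 2, 2, 0])
    print(sample_neighbors(row_indptr, col_indices, torch.tensor([0, 1]), size=1))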
2 changes: 1 addition & 1 deletion cogdl/layers/deepergcn_layer.py
@@ -118,7 +118,7 @@ def forward(self, graph, x):
 
 class ResGNNLayer(nn.Module):
     """
-    Implementation of DeeperGCN in paper `"DeeperGCN: All You Need to Train Deeper GCNs"` <https://arxiv.org/abs/2006.07739>
+    Implementation of DeeperGCN in paper `"DeeperGCN: All You Need to Train Deeper GCNs" <https://arxiv.org/abs/2006.07739>`_
     Parameters
     -----------
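This and the matching docstring edits elsewhere in the commit fix the reST external-hyperlink syntax: the URL must sit inside the backticks, and the role must end with a trailing underscore. A minimal illustration on a hypothetical class:

    class Example:
        """Implementation of Foo in paper `"Foo: A Method" <https://example.org/foo>`_.

        Sphinx only renders the span as a link when the URL sits inside the
        backticks and the closing backtick is followed by an underscore; the
        old form, with the URL outside the backticks, renders literally.
        """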
2 changes: 1 addition & 1 deletion cogdl/layers/disengcn_layer.py
@@ -6,7 +6,7 @@
 
 class DisenGCNLayer(nn.Module):
     """
-    Implementation of "Disentangled Graph Convolutional Networks" <http://proceedings.mlr.press/v97/ma19a.html>.
+    Implementation of `"Disentangled Graph Convolutional Networks" <http://proceedings.mlr.press/v97/ma19a.html>`_.
     """
 
     def __init__(self, in_feats, out_feats, K, iterations, tau=1.0, activation="leaky_relu"):
3 changes: 0 additions & 3 deletions cogdl/models/emb/deepwalk.py
@@ -58,9 +58,6 @@ def __init__(self, dimension, walk_length, walk_num, window_size, worker, iterat
         self.worker = worker
         self.iteration = iteration
 
-    def train(self, graph, embedding_model_creator=Word2Vec, return_dict=False):
-        return self.forward(graph, embedding_model_creator, return_dict)
-
     def forward(self, graph, embedding_model_creator=Word2Vec, return_dict=False):
         nx_g = graph.to_networkx()
         self.G = nx_g
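The same three-line deletion repeats across the embedding models below: the train() alias is removed so models are invoked through nn.Module.__call__, which dispatches to forward() and stops shadowing PyTorch's built-in Module.train(mode). A sketch of the new calling convention (hyperparameter values are illustrative):

    import torch
    from cogdl.data import Graph
    from cogdl.models.emb.deepwalk import DeepWalk

    model = DeepWalk(dimension=32, walk_length=10, walk_num=4,
                     window_size=5, worker=1, iteration=1)
    graph = Graph(edge_index=(torch.LongTensor([0, 1]), torch.LongTensor([1, 2])))

    embeddings = model(graph)  # __call__ -> forward(); previously model.train(graph)
    model.train()              # the PyTorch train/eval mode switch is usable again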
3 changes: 0 additions & 3 deletions cogdl/models/emb/dngr.py
@@ -123,9 +123,6 @@ def get_emb(self, matrix):
         emb_matrix = preprocessing.normalize(emb_matrix, "l2")
         return emb_matrix
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict=return_dict)
-
     def forward(self, graph, return_dict=False):
         device = "cuda" if torch.cuda.is_available() and not self.cpu else "cpu"
         G = graph.to_networkx()
3 changes: 0 additions & 3 deletions cogdl/models/emb/gatne.py
@@ -107,9 +107,6 @@ def __init__(
         self.schema = schema
         self.multiplicity = True
 
-    def train(self, network_data):
-        return self.forward(network_data)
-
     def forward(self, network_data):
         device = "cpu" if not torch.cuda.is_available() else "cuda"
         all_walks = generate_walks(network_data, self.walk_num, self.walk_length, schema=self.schema)
3 changes: 0 additions & 3 deletions cogdl/models/emb/grarep.py
@@ -32,9 +32,6 @@ def __init__(self, dimension, step):
         self.dimension = dimension
         self.step = step
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         self.G = graph.to_networkx()
         self.num_node = self.G.number_of_nodes()
3 changes: 0 additions & 3 deletions cogdl/models/emb/hin2vec.py
@@ -173,9 +173,6 @@ def __init__(self, hidden_dim, walk_length, walk_num, batch_size, hop, negative,
         self.cpu = cpu
 
     def forward(self, data):
-        return self.train(data)
-
-    def train(self, data):
         device = "cpu" if not torch.cuda.is_available() or self.cpu else "cuda"
         G = nx.DiGraph()
         row, col = data.edge_index
3 changes: 0 additions & 3 deletions cogdl/models/emb/hope.py
@@ -32,9 +32,6 @@ def __init__(self, dimension, beta):
         self.dimension = dimension
         self.beta = beta
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         r"""The author claim that Katz has superior performance in related tasks
         S_katz = (M_g)^-1 * M_l = (I - beta*A)^-1 * beta*A = (I - beta*A)^-1 * (I - (I -beta*A))
3 changes: 0 additions & 3 deletions cogdl/models/emb/line.py
@@ -65,9 +65,6 @@ def __init__(self, dimension, walk_length, walk_num, negative, batch_size, alpha
         self.init_alpha = alpha
         self.order = order
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         # run LINE algorithm, 1-order, 2-order or 3(1-order + 2-order)
         nx_g = graph.to_networkx()
3 changes: 0 additions & 3 deletions cogdl/models/emb/metapath2vec.py
@@ -64,9 +64,6 @@ def __init__(self, dimension, walk_length, walk_num, window_size, worker, iterat
         self.node_type = None
 
     def forward(self, data):
-        return self.train(data)
-
-    def train(self, data):
         G = nx.DiGraph()
         row, col = data.edge_index
         G.add_edges_from(list(zip(row.numpy(), col.numpy())))
3 changes: 0 additions & 3 deletions cogdl/models/emb/netmf.py
@@ -40,9 +40,6 @@ def __init__(self, dimension, window_size, rank, negative, is_large=False):
         self.negative = negative
         self.is_large = is_large
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         nx_g = graph.to_networkx()
         A = sp.csr_matrix(nx.adjacency_matrix(nx_g))
3 changes: 0 additions & 3 deletions cogdl/models/emb/netsmf.py
@@ -56,9 +56,6 @@ def __init__(self, dimension, window_size, negative, num_round, worker):
         self.worker = worker
         self.num_round = num_round
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         self.G = graph.to_networkx()
         node2id = dict([(node, vid) for vid, node in enumerate(self.G.nodes())])
3 changes: 0 additions & 3 deletions cogdl/models/emb/node2vec.py
@@ -69,9 +69,6 @@ def __init__(self, dimension, walk_length, walk_num, window_size, worker, iterat
         self.p = p
         self.q = q
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         G = graph.to_networkx()
         self.G = G
3 changes: 0 additions & 3 deletions cogdl/models/emb/prone.py
@@ -43,9 +43,6 @@ def __init__(self, dimension, step, mu, theta):
         self.mu = mu
         self.theta = theta
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph: Graph, return_dict=False):
         nx_g = graph.to_networkx()
         self.matrix0 = sp.csr_matrix(nx.adjacency_matrix(nx_g))
3 changes: 0 additions & 3 deletions cogdl/models/emb/pte.py
@@ -59,9 +59,6 @@ def __init__(self, dimension, walk_length, walk_num, negative, batch_size, alpha
         self.init_alpha = alpha
 
     def forward(self, data):
-        return self.train(data)
-
-    def train(self, data):
         G = nx.DiGraph()
         row, col = data.edge_index
         G.add_edges_from(list(zip(row.numpy(), col.numpy())))
3 changes: 0 additions & 3 deletions cogdl/models/emb/sdne.py
@@ -107,9 +107,6 @@ def __init__(self, hidden_size1, hidden_size2, droput, alpha, beta, nu1, nu2, ep
         self.lr = lr
         self.cpu = cpu
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         device = "cuda" if torch.cuda.is_available() and not self.cpu else "cpu"
         G = graph.to_networkx()
3 changes: 0 additions & 3 deletions cogdl/models/emb/spectral.py
@@ -28,9 +28,6 @@ def __init__(self, hidden_size):
         super(Spectral, self).__init__()
         self.dimension = hidden_size
 
-    def train(self, graph, return_dict=False):
-        return self.forward(graph, return_dict)
-
     def forward(self, graph, return_dict=False):
         nx_g = graph.to_networkx()
         matrix = nx.normalized_laplacian_matrix(nx_g).todense()
19 changes: 19 additions & 0 deletions cogdl/models/nn/deepergcn.py
@@ -7,6 +7,25 @@
 
 
 class DeeperGCN(BaseModel):
+    """Implementation of DeeperGCN in paper `"DeeperGCN: All You Need to Train Deeper GCNs" <https://arxiv.org/abs/2006.07739>`_
+
+    Args:
+        in_feat (int): the dimension of input features
+        hidden_size (int): the dimension of hidden representation
+        out_feat (int): the dimension of output features
+        num_layers (int): the number of layers
+        activation (str, optional): activation function. Defaults to "relu".
+        dropout (float, optional): dropout rate. Defaults to 0.0.
+        aggr (str, optional): aggregation function. Defaults to "max".
+        beta (float, optional): a coefficient for aggregation function. Defaults to 1.0.
+        p (float, optional): a coefficient for aggregation function. Defaults to 1.0.
+        learn_beta (bool, optional): whether beta is learnable. Defaults to False.
+        learn_p (bool, optional): whether p is learnable. Defaults to False.
+        learn_msg_scale (bool, optional): whether message scale is learnable. Defaults to True.
+        use_msg_norm (bool, optional): use message norm or not. Defaults to False.
+        edge_attr_size (int, optional): the dimension of edge features. Defaults to None.
+    """
+
     @staticmethod
     def add_args(parser):
         # fmt: off
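With the documented arguments, constructing the model directly might look like the following sketch; the keyword names follow the new docstring, but treating __init__ as accepting them all as keywords is an assumption:

    from cogdl.models.nn.deepergcn import DeeperGCN

    model = DeeperGCN(
        in_feat=128,      # input feature dimension
        hidden_size=256,
        out_feat=40,      # e.g. number of classes
        num_layers=14,
        dropout=0.1,
        aggr="max",       # default aggregation per the docstring
        learn_beta=False,
        use_msg_norm=False,
    )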
2 changes: 1 addition & 1 deletion cogdl/models/nn/gcnii.py
@@ -10,7 +10,7 @@
 
 class GCNII(BaseModel):
     """
-    Implementation of GCNII in paper `"Simple and Deep Graph Convolutional Networks"` <https://arxiv.org/abs/2007.02133>.
+    Implementation of GCNII in paper `"Simple and Deep Graph Convolutional Networks" <https://arxiv.org/abs/2007.02133>`_.
     Parameters
     -----------
2 changes: 1 addition & 1 deletion cogdl/models/nn/gdc_gcn.py
@@ -21,7 +21,7 @@ class GDC_GCN(BaseModel):
     t (float) : Heat polynomial filter param
     k (int) : Top k nodes retained during sparsification.
     eps (float) : Threshold for clipping.
-    gdc_type (str) : "none", "ppr", "heat"
+    gdc_type (str) : "none", "ppr", "heat"
     """
 
     @staticmethod
2 changes: 1 addition & 1 deletion cogdl/models/nn/m3s.py
@@ -37,7 +37,7 @@ def __init__(self, num_features, hidden_size, num_classes, dropout):
         self.gcn1 = GCNLayer(num_features, hidden_size)
         self.gcn2 = GCNLayer(hidden_size, num_classes)
 
-    def get_embeddings(self, graph):
+    def embed(self, graph):
         graph.sym_norm()
         h = graph.x
         h = self.gcn1(graph, h)
2 changes: 1 addition & 1 deletion cogdl/models/nn/moe_gcn.py
@@ -126,7 +126,7 @@ def __init__(
         self.act = get_activation(activation)
         self.final_cls = nn.Linear(hidden_size, out_feats)
 
-    def get_embeddings(self, graph):
+    def embed(self, graph):
         graph.sym_norm()
         h = graph.x
         for i in range(self.num_layers - 1):
2 changes: 1 addition & 1 deletion cogdl/pipelines.py
@@ -184,7 +184,7 @@ def __call__(self, edge_index, x=None, edge_weight=None):
             edge_index = (edge_index[:, 0], edge_index[:, 1])
             data = Graph(edge_index=edge_index, edge_weight=edge_weight)
             self.model = build_model(self.args)
-            embeddings = self.model.train(data)
+            embeddings = self.model(data)
         elif self.method_type == "gnn":
             num_nodes = edge_index.max().item() + 1
             if x is None:
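With this change the embedding pipeline also invokes the model through __call__. A minimal usage sketch, assuming the "generate-emb" pipeline name and "prone" model id from cogdl's pipeline API:

    import numpy as np
    from cogdl import pipeline

    generator = pipeline("generate-emb", model="prone")
    edge_index = np.array([[0, 1], [1, 2], [2, 0]])
    embeddings = generator(edge_index)  # internally calls self.model(data) now
    print(embeddings.shape)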
@@ -22,7 +22,7 @@ def __init__(self, model, hidden_size=200):
         self.hidden_size = hidden_size
 
     def train_step(self, batch):
-        embeddings = self.model.train(batch)
+        embeddings = self.model(batch)
         embeddings = np.hstack((embeddings, batch.x.numpy()))
 
         return embeddings
2 changes: 1 addition & 1 deletion cogdl/wrappers/model_wrapper/node_classification/m3s_mw.py
@@ -36,7 +36,7 @@ def pre_stage(self, stage, data_w: DataWrapper):
         num_nodes = graph.num_nodes
 
         with torch.no_grad():
-            emb = self.model.get_embeddings(graph)
+            emb = self.model.embed(graph)
 
         confidence_ranking = np.zeros([num_classes, num_nodes], dtype=int)
         kmeans = KMeans(n_clusters=self.num_clusters, random_state=0).fit(emb)
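Downstream, M3S clusters these inference-time embeddings with KMeans to build per-cluster confidence rankings, which is why the extraction hook is kept separate from nn.Module.train(). A standalone sketch of that pattern, with a toy embedding matrix standing in for self.model.embed(graph):

    import numpy as np
    import torch
    from sklearn.cluster import KMeans

    with torch.no_grad():
        emb = torch.randn(100, 16).numpy()  # placeholder for model.embed(graph)

    kmeans = KMeans(n_clusters=4, random_state=0).fit(emb)
    # Distance to the assigned centroid can serve as a confidence score.
    dist = np.linalg.norm(emb - kmeans.cluster_centers_[kmeans.labels_], axis=1)
    print(kmeans.labels_[:10], dist[:10])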
2 changes: 0 additions & 2 deletions docs/source/conf.py
@@ -99,8 +99,6 @@ def find_version(filename):
 source_suffix = [".rst", ".md"]
 # source_suffix = ".rst"
 
-autodoc_mock_imports = ["torch"]
-
 # The master toctree document.
 master_doc = "index"
6 changes: 3 additions & 3 deletions tests/models/emb/test_deepwalk.py
@@ -69,7 +69,7 @@ def test_will_return_computed_embeddings_for_simple_fully_connected_graph():
     args = get_args()
     model: DeepWalk = DeepWalk.build_model_from_args(args)
     graph = Graph(edge_index=(torch.LongTensor([0]), torch.LongTensor([1])))
-    trained = model.train(graph, creator)
+    trained = model(graph, creator)
     assert len(trained) == 2
     np.testing.assert_array_equal(trained[0], embed_1)
     np.testing.assert_array_equal(trained[1], embed_2)
@@ -79,7 +79,7 @@ def test_will_return_computed_embeddings_for_simple_graph():
     args = get_args()
     model: DeepWalk = DeepWalk.build_model_from_args(args)
     graph = Graph(edge_index=(torch.LongTensor([0, 1]), torch.LongTensor([1, 2])))
-    trained = model.train(graph, creator)
+    trained = model(graph, creator)
     assert len(trained) == 3
     np.testing.assert_array_equal(trained[0], embed_1)
    np.testing.assert_array_equal(trained[1], embed_2)
@@ -97,7 +97,7 @@ def creator_mocked(walks, size, window, min_count, sg, workers, iter):
         captured_walks_no.append(len(walks))
         return creator(walks, size, window, min_count, sg, workers, iter)
 
-    model.train(graph, creator_mocked)
+    model(graph, creator_mocked)
     assert captured_walks_no[0] == args.walk_num * graph.num_nodes
