diff --git a/deepsnap/batch.py b/deepsnap/batch.py index f7565d6..eafc0d4 100644 --- a/deepsnap/batch.py +++ b/deepsnap/batch.py @@ -25,7 +25,9 @@ def __init__(self, batch=None, **kwargs): self.__slices__ = None @staticmethod - def collate(follow_batch=[], transform=None, **kwargs): + def collate(follow_batch=None, transform=None, **kwargs): + if follow_batch is None: + follow_batch = [] return lambda batch: Batch.from_data_list( batch, follow_batch, transform, **kwargs ) @@ -70,6 +72,7 @@ def from_data_list( batch, cumsum = Batch._init_batch_fields(keys, follow_batch) batch.__data_class__ = data_list[0].__class__ batch.batch = [] + num_nodes = None for i, data in enumerate(data_list): # Note: in heterogeneous graph, __inc__ logic is different Batch._collate_dict( diff --git a/deepsnap/dataset.py b/deepsnap/dataset.py index 504ac95..6264d36 100644 --- a/deepsnap/dataset.py +++ b/deepsnap/dataset.py @@ -661,8 +661,8 @@ def _split_transductive( Split the dataset assuming training process is transductive. Args: - split_ratio: number of data splitted into train, validation - (and test) set. + split_ratio: ratio of data to be split into train, validation + (and test) sets. Returns: list: A list of 3 (2) lists of :class:`deepsnap.graph.Graph` @@ -978,11 +978,13 @@ def split( transductive: whether the training process is transductive or inductive. Inductive split is always used for graph-level tasks (self.task == 'graph'). - split_ratio: number of data splitted into train, validation - (and test) set. + split_ratio: ratio of data to be split into train, validation + (and test) sets. These ratios must sum to 1. + If `None` the default splits are 0.8, 0.1, + and 0.1 for train, validation, and test sets respectively. Returns: - list: a list of 3 (2) lists of :class:`deepsnap.graph.Graph` + list: a list of 2 (or 3) lists of :class:`deepsnap.graph.Graph` objects corresponding to train, validation (and test) set. """ if self.graphs is None: @@ -1006,8 +1008,8 @@ def split( for split_ratio_i in split_ratio ): raise TypeError("Split ratio must contain all floats.") - if not all(split_ratio_i > 0 for split_ratio_i in split_ratio): - raise ValueError("Split ratio must contain all positivevalues.") + if not all(0 < split_ratio_i < 1 for split_ratio_i in split_ratio): + raise ValueError("Split ratios must be between 0 and 1.") # store the most recent split types self._split_types = split_types @@ -1019,20 +1021,19 @@ def split( graph.edge_label = graph._edge_label # list of num_splits datasets - dataset_return = [] if transductive: if self.task == "graph": raise ValueError( - "in transductive mode, self.task is graph does not " + "in transductive mode, self.task == `graph` does not " "make sense." ) - dataset_return = ( + return ( self._split_transductive( split_ratio, split_types, shuffle=shuffle ) ) else: - dataset_return = ( + return ( self._split_inductive( split_ratio, split_types, @@ -1040,10 +1041,8 @@ def split( ) ) - return dataset_return - def resample_disjoint(self): - r""" Resample disjoint edge split of message passing and objective links. + r"""Resample disjoint edge split of message passing and objective links. Note that if apply_transform (on the message passing graph) was used before this resampling, it needs to be diff --git a/deepsnap/graph.py b/deepsnap/graph.py index 114fb7e..0599e16 100644 --- a/deepsnap/graph.py +++ b/deepsnap/graph.py @@ -9,10 +9,17 @@ from typing import ( Dict, List, + Optional, Union ) import warnings import deepsnap +import networkx as nx +import snap +import snapx as sx + + +GraphType = Union[nx.Graph, sx.Graph] class Graph(object): @@ -23,13 +30,13 @@ class Graph(object): Args: G (:class:`networkx.classes.graph`): The NetworkX graph object which contains features and labels for the tasks. - **kwargs: keyworded argument list with keys such + **kwargs: keyword argument list with keys such as :obj:`"node_feature"`, :obj:`"node_label"` and corresponding attributes. """ - def __init__(self, G=None, netlib=None, **kwargs): - self.G = G + def __init__(self, G: Optional[GraphType] = None, netlib=None, **kwargs): + self.G: nx.Graph = G if netlib is not None: deepsnap._netlib = netlib keys = [ @@ -303,7 +310,8 @@ def get_num_dims(self, key: str, as_label: bool = False) -> int: Returns the number of dimensions for one graph/node/edge property. Args: - as_label: if as_label, treat the tensor as labels ( + key: the `key` to return the dimension for + as_label: if `as_label`, treat the tensor as labels """ if as_label: # treat as label @@ -1147,7 +1155,7 @@ def split( else: raise ValueError("Unknown task.") - def _split_node(self, split_ratio: float, shuffle: bool = True): + def _split_node(self, split_ratio: List[float], shuffle: bool = True): r""" Split the graph into len(split_ratio) graphs for node prediction. Internally this splits node indices, and the model will only compute @@ -1217,7 +1225,7 @@ def _split_node(self, split_ratio: float, shuffle: bool = True): split_graphs.append(graph_new) return split_graphs - def _split_edge(self, split_ratio: float, shuffle: bool = True): + def _split_edge(self, split_ratio: List[float], shuffle: bool = True): r""" Split the graph into len(split_ratio) graphs for node prediction. Internally this splits node indices, and the model will only compute @@ -1393,14 +1401,14 @@ def split_link_pred( nodes in each split graph. This is only used for transductive link prediction task In this task, different part of graph is observed in train/val/test - Note: this functon will be called twice, + Note: this function will be called twice, if during training, we further split the training graph so that message edges and objective edges are different """ if isinstance(split_ratio, float): split_ratio = [split_ratio, 1 - split_ratio] if len(split_ratio) < 2 or len(split_ratio) > 3: - raise ValueError("Unrecoginzed number of splits") + raise ValueError("Unrecognized number of splits") if self.num_edges < len(split_ratio): raise ValueError( "In _split_link_pred num of edges are smaller than" @@ -1526,7 +1534,7 @@ def split_link_pred( else: return [graph_train, graph_val] - def _edge_subgraph_with_isonodes(self, G, edges): + def _edge_subgraph_with_isonodes(self, G: nx.Graph, edges: List): r""" Generate a new networkx graph with same nodes and their attributes. diff --git a/deepsnap/hetero_graph.py b/deepsnap/hetero_graph.py index 2dc5f99..96bfba9 100644 --- a/deepsnap/hetero_graph.py +++ b/deepsnap/hetero_graph.py @@ -8,6 +8,7 @@ from typing import ( Dict, List, + Optional, Union ) import warnings @@ -120,7 +121,7 @@ def message_types(self): """ return list(self["edge_index"].keys()) - def num_nodes(self, node_type: Union[str, List[str]] = None): + def num_nodes(self, node_type: Optional[Union[str, List[str]]] = None) -> Dict: r""" Return number of nodes for a node type or list of node types.