-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhypergraph_batch.py
151 lines (117 loc) · 6.16 KB
/
hypergraph_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import jax.numpy as jnp
import numpy as np
from scipy.sparse import csr_array
from hypergraph import HyperGraph
def hypergraph_batch(hypergraphs: list[HyperGraph]) -> HyperGraph:
    r"""
    Concatenate a list of HyperGraph instances into a single HyperGraph
    instance for batching purposes.

    The node and hedge indices of each hypergraph are shifted by the number
    of nodes and hedges accumulated from the hypergraphs that precede it in
    the list, so that indices are correlative in the resulting
    super-hypergraph.

    It is assumed that all hypergraphs in the list are comparable, in the
    sense that node and hedge features are of the same length and type, and
    that they all carry the same target keys. It is the responsibility of
    the caller to ensure that this is the case.

    Parameters
    ----------
    hypergraphs
        Non-empty list of hypergraphs to merge, in batch order.

    Returns
    -------
    HyperGraph
        A new hypergraph whose ``n_nodes`` and ``n_hedges`` are the totals
        over the batch and whose ``incidence``, sender/receiver index
        arrays, features, convolutions and weights are the concatenation
        of the inputs. ``batch_node_index`` and ``batch_hedge_index`` hold,
        for every node and hedge, the position in ``hypergraphs`` of the
        hypergraph it came from. ``targets`` maps each target key to a list
        with one entry per hypergraph that defines that key.

    Raises
    ------
    ValueError
        If ``hypergraphs`` is empty.
    """
    # len() rather than truthiness, so any sized sequence is accepted
    if len(hypergraphs) == 0:
        raise ValueError(
            "hypergraph_batch requires at least one hypergraph"
        )

    # index attributes, listed in the order in which they are assigned to
    # the batch, paired with the running offset their entries must be
    # shifted by: "node" for node indices, "hedge" for hedge indices
    index_fields = (
        ("node_receivers", "node"),
        ("node_senders", "node"),
        ("node2hedge_receivers", "hedge"),
        ("node2hedge_senders", "node"),
        ("hedge_receivers", "hedge"),
        ("hedge_senders", "hedge"),
        ("hedge2node_receivers", "node"),
        ("hedge2node_senders", "hedge"),
    )
    # attributes that are concatenated as they are, without any shift
    stacked_fields = (
        "node_features",
        "hedge_features",
        "hedge_convolution",
        "hedge2node_convolution",
        "node_convolution",
        "node2hedge_convolution",
        "weights",
    )

    # the pieces of every index array are collected in a list and
    # concatenated once after the loop, instead of re-copying the growing
    # array on every iteration; each list is seeded with an empty int32
    # array so that the dtype of the result is promoted from int32
    empty = np.array([], dtype=np.int32)
    hedge_parts = [empty]
    node_parts = [empty]
    index_parts = {name: [empty] for name, _ in index_fields}
    stacked_parts = {name: [] for name in stacked_fields}

    # targets is a dictionary with the same keys as individual hypergraphs
    # (assumed to have all the same keys); the value for each key is a
    # list with one entry per hypergraph, in batch order
    targets = {}

    # node_index and hedge_index map nodes and hedges (respectively) in
    # the super-hypergraph to the original hypergraph from which they
    # came; it allows to map individual graphs to the right target when
    # fitting the model
    node_index = []
    hedge_index = []

    n_hedges = 0
    n_nodes = 0
    for n_graph, hgraph in enumerate(hypergraphs):
        node_index += hgraph.n_nodes * [n_graph]
        hedge_index += hgraph.n_hedges * [n_graph]

        # offsets seen by this hypergraph: totals of the previous ones
        offsets = {"node": n_nodes, "hedge": n_hedges}

        # row 0 of the incidence holds hedge indices, row 1 node indices
        hedge_parts.append(np.asarray(hgraph.incidence[0, :]) + n_hedges)
        node_parts.append(np.asarray(hgraph.incidence[1, :]) + n_nodes)

        for name, kind in index_fields:
            shifted = np.asarray(getattr(hgraph, name)) + offsets[kind]
            index_parts[name].append(shifted)

        for name in stacked_fields:
            stacked_parts[name].append(getattr(hgraph, name))

        for key, value in hgraph.targets.items():
            targets.setdefault(key, []).append(value)

        n_hedges += hgraph.n_hedges
        n_nodes += hgraph.n_nodes

    # now we can create a batch hgraph (concatenation of individual
    # hgraphs) with the accumulated arrays:
    # create an "empty" hypergraph and populate it
    batch_hgraph = HyperGraph(incidence=None)
    batch_hgraph.n_hedges = n_hedges
    batch_hgraph.n_nodes = n_nodes
    batch_hgraph.batch_node_index = jnp.array(node_index)
    batch_hgraph.batch_hedge_index = jnp.array(hedge_index)
    batch_hgraph.incidence = jnp.array(
        [np.concatenate(hedge_parts), np.concatenate(node_parts)]
    )
    for name in stacked_fields:
        setattr(batch_hgraph, name, jnp.concatenate(stacked_parts[name]))
    batch_hgraph.targets = targets
    for name, _ in index_fields:
        setattr(
            batch_hgraph, name, jnp.array(np.concatenate(index_parts[name]))
        )
    return batch_hgraph