-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata.py
69 lines (54 loc) · 1.78 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from itertools import count
from collections import defaultdict as ddict
import numpy as np
import torch as th
def parse_seperator(line, length, sep='\t'):
    """Split one text line into ``length`` fields followed by an integer weight.

    A line with exactly ``length`` fields gets weight 1. A line with one extra
    trailing field uses that field, converted with ``int``, as the weight.

    Args:
        line: raw input line; surrounding whitespace is stripped first.
        length: number of fields expected before the optional weight.
        sep: separator handed to ``str.split``.

    Returns:
        tuple: the ``length`` fields (as strings) followed by the weight.

    Raises:
        RuntimeError: if the line has neither ``length`` nor ``length + 1`` fields.
        ValueError: if the trailing weight field is not an integer literal.
    """
    stripped = line.strip()
    fields = stripped.split(sep)
    n_fields = len(fields)
    if n_fields not in (length, length + 1):
        raise RuntimeError(f'Malformed input ({stripped})')
    if n_fields == length:
        return (*fields, 1)
    # One extra column: it is the edge weight.
    *head, weight = fields
    return (*head, int(weight))
def parse_tsv(line, length=2):
    """Parse a tab-separated line by delegating to ``parse_seperator``."""
    tab = '\t'
    return parse_seperator(line, length, sep=tab)
def parse_space(line, length=2):
    """Parse a single-space-separated line by delegating to ``parse_seperator``."""
    space = ' '
    return parse_seperator(line, length, sep=space)
def iter_line(fname, fparse, length=2, comment='#'):
    """Yield the parsed record of every data line in a text file.

    Lines that are empty or whitespace-only are skipped, as are lines that
    start with ``comment`` (pass ``''`` or ``None`` to disable comment
    skipping). Every other line is handed to ``fparse(line, length=length)``;
    ``None`` results are dropped and everything else is yielded.

    Args:
        fname: path of the file to read.
        fparse: callable ``(line, length=...)`` turning one line into a record.
        length: forwarded unchanged to ``fparse``.
        comment: line prefix marking a comment line.

    Yields:
        Each non-``None`` value returned by ``fparse``.
    """
    with open(fname, 'r') as fin:
        for line in fin:
            # Blank lines carry no record; with the default length the
            # bundled parsers (parse_tsv / parse_space) raise RuntimeError on
            # them, so a stray empty line would abort the whole load.
            if not line.strip():
                continue
            if comment and line.startswith(comment):
                continue
            tpl = fparse(line, length=length)
            if tpl is not None:
                yield tpl
def intmap_to_list(d):
    """Invert a ``{value: index}`` mapping into a list where ``arr[index] == value``.

    Args:
        d: mapping whose values are the list positions ``0 .. len(d) - 1``.

    Returns:
        list: of length ``len(d)`` with each key of ``d`` stored at its index.

    Raises:
        ValueError: if some position is left unfilled (e.g. duplicate indices).
        IndexError: if an index is ``>= len(d)``.
    """
    # A private sentinel (rather than None) marks unfilled slots, so that
    # None is allowed as a legitimate value.
    missing = object()
    arr = [missing] * len(d)
    for v, i in d.items():
        arr[i] = v
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if any(x is missing for x in arr):
        raise ValueError('intmap_to_list: indices must cover 0..len(d)-1 exactly')
    return arr
def slurp(fin, fparse=parse_tsv, symmetrize=False):
    """Load a weighted edge list and assign consecutive integer ids to node names.

    Args:
        fin: path of the edge file, read through ``iter_line`` with length 2.
        fparse: line parser returning ``(src, dst, weight)``; defaults to
            ``parse_tsv``.
        symmetrize: if true, also append the reversed copy of every edge.

    Returns:
        (idx, objects): ``idx`` is an int64 tensor of shape ``(num_edges, 3)``
        whose rows are ``(src_id, dst_id, weight)``. ``objects[k]`` is the name
        that received id ``k``. Self-loops (``src == dst``) are dropped before
        their names are registered.
    """
    ecount = count()
    # Looking up an unseen name assigns it the next id from the counter.
    enames = ddict(ecount.__next__)
    subs = []
    for i, j, w in iter_line(fin, fparse, length=2):
        if i == j:
            continue
        subs.append((enames[i], enames[j], w))
        if symmetrize:
            subs.append((enames[j], enames[i], w))
    # `np.int` was removed in NumPy 1.24. int64 matches torch's index dtype
    # (torch.long) on every platform. The reshape keeps the (N, 3) shape even
    # when the file yields no edges (np.array([]) would otherwise be 1-D).
    idx = th.from_numpy(np.array(subs, dtype=np.int64).reshape(-1, 3))
    # freeze defaultdicts after training data and convert to arrays
    objects = intmap_to_list(dict(enames))
    print(f'slurp: objects={len(objects)}, edges={len(idx)}')
    return idx, objects