-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhypergraph_dataset.py
122 lines (82 loc) · 3.27 KB
/
hypergraph_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import pickle
import random
from glob import escape as glob_escape
from glob import glob
from typing import Optional

from torch.utils.data import Dataset

from hypergraph import HyperGraph
class HyperGraphDataSet(Dataset):
    """
    :Class:

    Data set class to load molecular hypergraph data.

    Each entry of the data set is one file containing a pickled object
    (expected to be a HyperGraph); files are read lazily, one per
    ``__getitem__`` call.
    """

    def __init__(
        self,
        database_dir: Optional[str] = None,
        nMaxEntries: Optional[int] = None,
        seed: int = 42,
        file_extension: str = '.pkl',
        files: Optional[list[str]] = None
    ) -> None:
        """
        Args:
        :param str database_dir: the directory where the data files reside;
            only used when ``files`` is not given.
        :param int nMaxEntries: optionally used to limit the number of
            entries to consider; default (None or 0) is all.
        :param int seed: seed of a private random generator used to choose
            which files to keep when ``nMaxEntries`` limits the data set;
            the same seed and the same set of files give the same selection
            in different runs. The global ``random`` state is not modified.
        :param str file_extension: the extension of files in the database;
            default = .pkl
        :param list[str] files: (optional) explicit list of data files (full
            relative path). By default, with ``database_dir`` and
            ``file_extension``, this class constructs a data set with all the
            matching files (or nMaxEntries of them) contained in
            ``database_dir``. It might be desirable however to split the
            files into two data sets (e.g. training and validation); in this
            case the user provides the list of filenames here. The list is
            copied, so later changes to it do not affect the data set.
        :raises ValueError: if neither ``database_dir`` nor ``files`` is given.
        """
        self.database_dir = database_dir

        if files is None:
            if database_dir is None:
                raise ValueError(
                    "either database_dir or files must be provided"
                )
            # Escape the directory so characters such as '[' or '*' in its
            # name are matched literally. Sort because glob's order depends
            # on the file system, and a stable order is needed for the seeded
            # selection below to be reproducible across runs and machines.
            pattern = os.path.join(
                glob_escape(database_dir), '*' + file_extension
            )
            files = sorted(glob(pattern))
        else:
            # Copy so that a caller mutating their list afterwards cannot
            # desynchronise self.filenames from self.n_structures.
            files = list(files)

        # files contains one entry per item in the database. If nMaxEntries
        # is set (non-zero) and smaller than the number of available files,
        # nMaxEntries files are selected randomly; otherwise all are used.
        if nMaxEntries and nMaxEntries < len(files):
            # A private generator avoids reseeding the global RNG, which
            # would silently reset the random state of the calling program.
            rng = random.Random(seed)
            self.filenames = rng.sample(files, nMaxEntries)
        else:
            self.filenames = files

        self.n_structures = len(self.filenames)

    def __len__(self) -> int:
        """
        :return: the number of entries in the database
        :rtype: int
        """
        return self.n_structures

    def __getitem__(self, idx: int) -> "HyperGraph":
        """
        Load from file the data for entry ``idx`` in the database and return
        the graph read from that file.

        Args:
        :param int idx: the idx'th entry in the database
        :return: the idx'th graph in the database
        :rtype: HyperGraph
        """
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only use this data set with files from a trusted source.
        with open(self.filenames[idx], 'rb') as infile:
            return pickle.load(infile)

    def get_filename(self, idx: int) -> str:
        """
        Returns the cluster data file name.

        :param int idx: the idx'th entry in the database
        :return: the filename containing the structure data corresponding
            to the idx'th entry in the database
        :rtype: str
        """
        return self.filenames[idx]