Merge pull request #405 from corochann/add_cgcnn_megnet
Add cgcnn megnet
Showing 42 changed files with 1,755 additions and 134 deletions.
@@ -0,0 +1,22 @@
from chainer_chemistry.dataset.converters.concat_mols import concat_mols  # NOQA
from chainer_chemistry.dataset.converters.megnet_converter import megnet_converter  # NOQA
from chainer_chemistry.dataset.converters.cgcnn_converter import cgcnn_converter  # NOQA

converter_method_dict = {
    'ecfp': concat_mols,
    'nfp': concat_mols,
    'nfp_gwm': concat_mols,
    'ggnn': concat_mols,
    'ggnn_gwm': concat_mols,
    'gin': concat_mols,
    'gin_gwm': concat_mols,
    'schnet': concat_mols,
    'weavenet': concat_mols,
    'relgcn': concat_mols,
    'rsgcn': concat_mols,
    'rsgcn_gwm': concat_mols,
    'relgat': concat_mols,
    'gnnfilm': concat_mols,
    'megnet': megnet_converter,
    'cgcnn': cgcnn_converter
}
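The dictionary above is what lets training code pick the right batch converter from a method-name string. A minimal lookup sketch follows; 'megnet' is just an example key, and the import path assumes this file is the __init__ of the chainer_chemistry.dataset.converters package, as the imports above suggest:

# Minimal lookup sketch ('megnet' is only an example method name).
from chainer_chemistry.dataset.converters import converter_method_dict

converter = converter_method_dict['megnet']
# `converter` is megnet_converter here and can be passed wherever Chainer
# expects a batch converter, e.g. the `converter` argument of
# chainer.training.updaters.StandardUpdater.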
@@ -0,0 +1,36 @@
import numpy

import chainer
from chainer import functions
from chainer.dataset.convert import to_device


@chainer.dataset.converter()
def cgcnn_converter(batch, device=None, padding=None):
    """CGCNN converter"""
    if len(batch) == 0:
        raise ValueError("batch is empty")

    atom_feat, nbr_feat, nbr_idx = [], [], []
    batch_atom_idx, target = [], []
    current_idx = 0
    xp = device.xp
    for element in batch:
        atom_feat.append(element[0])
        nbr_feat.append(element[1])
        nbr_idx.append(element[2] + current_idx)
        target.append(element[3])
        n_atom = element[0].shape[0]
        atom_idx = numpy.arange(n_atom) + current_idx
        batch_atom_idx.append(atom_idx)
        current_idx += n_atom

    atom_feat = to_device(device, functions.concat(atom_feat, axis=0).data)
    nbr_feat = to_device(device, functions.concat(nbr_feat, axis=0).data)
    # Always keep batch_atom_idx as a numpy array on the host,
    # because it is a list of variable-length arrays.
    batch_atom_idx = numpy.array(batch_atom_idx)
    nbr_idx = to_device(device, functions.concat(nbr_idx, axis=0).data)
    target = to_device(device, xp.asarray(target))
    result = (atom_feat, nbr_feat, batch_atom_idx, nbr_idx, target)
    return result
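To see what this converter produces, here is a synthetic usage sketch. The data is entirely invented (feature widths, neighbor counts, targets, and equal atom counts per crystal); only the per-example tuple layout (atom_feat, nbr_feat, nbr_idx, target) and the need to pass a real Chainer device (the function reads device.xp) come from the code above.

import numpy
import chainer

def make_example(n_atom, n_feat=4, n_nbr=3, nbr_feat_dim=5):
    # Invented shapes: per-atom features, per-neighbor features,
    # per-atom neighbor indices, and a scalar target.
    atom_feat = numpy.random.rand(n_atom, n_feat).astype(numpy.float32)
    nbr_feat = numpy.random.rand(n_atom, n_nbr, nbr_feat_dim).astype(numpy.float32)
    nbr_idx = numpy.random.randint(0, n_atom, (n_atom, n_nbr)).astype(numpy.int32)
    target = numpy.array([1.0], dtype=numpy.float32)
    return atom_feat, nbr_feat, nbr_idx, target

batch = [make_example(3), make_example(3)]
device = chainer.get_device(-1)  # CPU device; its .xp attribute is numpy
atom_feat, nbr_feat, batch_atom_idx, nbr_idx, target = cgcnn_converter(
    batch, device=device)
# atom_feat.shape == (6, 4): atoms of both crystals stacked along axis 0
# nbr_idx.shape == (6, 3): neighbor indices, offset by 3 for the second crystal
# batch_atom_idx == [[0, 1, 2], [3, 4, 5]]: which rows belong to which crystal

With crystals of different sizes, batch_atom_idx becomes a ragged array of per-crystal indices, which is why it is kept as a host-side numpy array rather than sent to the device.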
chainer_chemistry/dataset/converters.py → ...emistry/dataset/converters/concat_mols.py (138 changes: 69 additions & 69 deletions)
@@ -1,69 +1,69 @@
import chainer


@chainer.dataset.converter()
def concat_mols(batch, device=None, padding=0):
    """Concatenates a list of molecules into array(s).

    This function converts an "array of tuples" into a "tuple of arrays".
    Specifically, given a list of examples each of which consists of
    a list of elements, this function first makes an array
    by taking the element in the same position from each example
    and concatenates them along the newly-inserted first axis
    (called `batch dimension`) into one array.
    It repeats this for all positions and returns the resulting arrays.

    The output type depends on the type of examples in ``batch``.
    For instance, consider each example consists of two arrays ``(x, y)``.
    Then, this function concatenates ``x`` 's into one array, and ``y`` 's
    into another array, and returns a tuple of these two arrays. Another
    example: consider each example is a dictionary of two entries whose keys
    are ``'x'`` and ``'y'``, respectively, and values are arrays. Then, this
    function concatenates ``x`` 's into one array, and ``y`` 's into another
    array, and returns a dictionary with two entries ``x`` and ``y`` whose
    values are the concatenated arrays.

    When the arrays to concatenate have different shapes, the behavior depends
    on the ``padding`` value. If ``padding`` is ``None``, it raises an error.
    Otherwise, it builds an array of the minimum shape that the
    contents of all arrays can be substituted to. The padding value is then
    used to fill the extra elements of the resulting arrays.

    The current implementation is identical to
    :func:`~chainer.dataset.concat_examples` of Chainer, except the default
    value of the ``padding`` option is changed to ``0``.

    .. admonition:: Example

       >>> import numpy
       >>> from chainer_chemistry.dataset.converters import concat_mols
       >>> x0 = numpy.array([1, 2])
       >>> x1 = numpy.array([4, 5, 6])
       >>> dataset = [x0, x1]
       >>> results = concat_mols(dataset)
       >>> print(results)
       [[1 2 0]
        [4 5 6]]

    .. seealso:: :func:`chainer.dataset.concat_examples`

    Args:
        batch (list):
            A list of examples. This is typically given by a dataset
            iterator.
        device (int):
            Device ID to which each array is sent. Negative value
            indicates the host memory (CPU). If it is omitted, all arrays are
            left in the original device.
        padding:
            Scalar value for extra elements. If this is ``None``,
            an error is raised on shape mismatch. Otherwise, an array of
            minimum dimensionalities that can accommodate all arrays is
            created, and elements outside of the examples are padded by this
            value.

    Returns:
        Array, a tuple of arrays, or a dictionary of arrays:
        The type depends on the type of each example in the batch.
    """
    return chainer.dataset.concat_examples(batch, device, padding=padding)
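As a complement to the doctest above, a short sketch of the tuple case described in the docstring; the values are arbitrary:

import numpy
from chainer_chemistry.dataset.converters import concat_mols

dataset = [
    (numpy.array([1, 2]), numpy.array([0.1], dtype=numpy.float32)),
    (numpy.array([3, 4, 5]), numpy.array([0.2], dtype=numpy.float32)),
]
x, y = concat_mols(dataset)
# x is padded with the default padding value 0:
#   [[1 2 0]
#    [3 4 5]]
# y has shape (2, 1)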
@@ -0,0 +1,41 @@
import chainer
from chainer.dataset.convert import to_device


@chainer.dataset.converter()
def megnet_converter(batch, device=None, padding=0):
    """MEGNet converter"""
    if len(batch) == 0:
        raise ValueError("batch is empty")

    atom_feat, pair_feat, global_feat, target = [], [], [], []
    atom_idx, pair_idx, start_idx, end_idx = [], [], [], []
    batch_size = len(batch)
    current_atom_idx = 0
    for i in range(batch_size):
        element = batch[i]
        n_atom = element[0].shape[0]
        n_pair = element[1].shape[0]
        atom_feat.extend(element[0])
        pair_feat.extend(element[1])
        global_feat.append(element[2])
        atom_idx.extend([i] * n_atom)
        pair_idx.extend([i] * n_pair)
        start_idx.extend(element[3][0] + current_atom_idx)
        end_idx.extend(element[3][1] + current_atom_idx)
        target.append(element[4])
        current_atom_idx += n_atom

    xp = device.xp
    atom_feat = to_device(device, xp.asarray(atom_feat))
    pair_feat = to_device(device, xp.asarray(pair_feat))
    global_feat = to_device(device, xp.asarray(global_feat))
    atom_idx = to_device(device, xp.asarray(atom_idx))
    pair_idx = to_device(device, xp.asarray(pair_idx))
    start_idx = to_device(device, xp.asarray(start_idx))
    end_idx = to_device(device, xp.asarray(end_idx))
    target = to_device(device, xp.asarray(target))
    result = (atom_feat, pair_feat, global_feat, atom_idx, pair_idx,
              start_idx, end_idx, target)

    return result
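A synthetic usage sketch for this converter as well. The feature sizes are invented; only the per-example tuple layout (atom_feat, pair_feat, global_feat, (start_idx, end_idx), target) and the device handling follow from the code above.

import numpy
import chainer

def make_example(n_atom, n_pair, atom_dim=16, pair_dim=8, global_dim=2):
    # Invented feature sizes, for illustration only.
    atom_feat = numpy.random.rand(n_atom, atom_dim).astype(numpy.float32)
    pair_feat = numpy.random.rand(n_pair, pair_dim).astype(numpy.float32)
    global_feat = numpy.random.rand(global_dim).astype(numpy.float32)
    start_idx = numpy.random.randint(0, n_atom, n_pair).astype(numpy.int32)
    end_idx = numpy.random.randint(0, n_atom, n_pair).astype(numpy.int32)
    target = numpy.array([0.5], dtype=numpy.float32)
    return atom_feat, pair_feat, global_feat, (start_idx, end_idx), target

batch = [make_example(3, 6), make_example(4, 8)]
device = chainer.get_device(-1)  # CPU; device.xp is numpy
(atom_feat, pair_feat, global_feat, atom_idx, pair_idx,
 start_idx, end_idx, target) = megnet_converter(batch, device=device)
# atom_feat.shape == (7, 16), pair_feat.shape == (14, 8)
# atom_idx / pair_idx map each atom / pair back to its graph (0 or 1)
# start_idx / end_idx are offset so they index into the stacked atom_feat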
chainer_chemistry/dataset/preprocessors/cgcnn_preprocessor.py (59 changes: 59 additions & 0 deletions)
@@ -0,0 +1,59 @@
from logging import getLogger
import numpy
import os
import shutil

from chainer.dataset import download

from chainer_chemistry.dataset.utils import GaussianDistance
from chainer_chemistry.dataset.preprocessors.mol_preprocessor import MolPreprocessor  # NOQA
from chainer_chemistry.utils import load_json

download_url = 'https://raw.githubusercontent.com/txie-93/cgcnn/master/data/sample-regression/atom_init.json'  # NOQA
file_name_atom_init_json = 'atom_init.json'

_root = 'pfnet/chainer/cgcnn'


def get_atom_init_json_filepath(download_if_not_exist=True):
    """Construct a filepath which stores atom_init.json.

    This method checks whether the file exists, and downloads it if
    necessary.

    Args:
        download_if_not_exist (bool): If `True`, download the dataset
            if it is not downloaded yet.

    Returns (str): file path for atom_init.json
    """
    cache_root = download.get_dataset_directory(_root)
    cache_path = os.path.join(cache_root, file_name_atom_init_json)
    if not os.path.exists(cache_path) and download_if_not_exist:
        logger = getLogger(__name__)
        logger.info('Downloading atom_init.json...')
        download_file_path = download.cached_download(download_url)
        shutil.copy(download_file_path, cache_path)
    return cache_path


class CGCNNPreprocessor(MolPreprocessor):
    """CGCNNPreprocessor

    Args:
        For Molecule: TODO
    """

    def __init__(self, max_num_nbr=12, max_radius=8, expand_dim=40):
        super(CGCNNPreprocessor, self).__init__()

        self.max_num_nbr = max_num_nbr
        self.max_radius = max_radius
        self.gdf = GaussianDistance(centers=numpy.linspace(0, 8, expand_dim))
        feat_dict = load_json(get_atom_init_json_filepath())
        self.atom_features = {int(key): numpy.array(value,
                                                    dtype=numpy.float32)
                              for key, value in feat_dict.items()}

    def get_input_features(self, mol):
        raise NotImplementedError()
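Two notes on the class above: instantiating CGCNNPreprocessor() downloads and caches atom_init.json on first use via get_atom_init_json_filepath, and get_input_features is deliberately left unimplemented for RDKit Mol inputs in this file. The GaussianDistance helper itself is not shown in this diff, so the following standalone function is only an assumed sketch of the usual Gaussian distance expansion it performs; the width value is an assumption, while the centers match numpy.linspace(0, 8, expand_dim) from __init__.

import numpy

def gaussian_expand(distances, centers=numpy.linspace(0, 8, 40), width=0.5):
    # Hypothetical stand-in for GaussianDistance: expand each scalar
    # distance d into len(centers) features exp(-(d - mu_k)**2 / width**2).
    d = numpy.asarray(distances, dtype=numpy.float32)
    return numpy.exp(-((d[..., None] - centers) ** 2) / width ** 2)

feats = gaussian_expand([1.2, 3.4])  # shape (2, 40), matching expand_dim=40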