This repository contains the official implementation of the work "AGDIFF: Attention-Enhanced Diffusion for Molecular Geometry Prediction".
AGDIFF introduces a novel approach that enhances diffusion models with attention mechanisms and an improved SchNet architecture, achieving state-of-the-art performance in predicting molecular geometries.
- Attention Mechanisms: Enhances the global and local encoders with attention mechanisms for better feature extraction and integration.
- Improved SchNet Architecture: Incorporates learnable activation functions, adaptive scaling modules, and dual pathway processing to increase model expressiveness.
- Batch Normalization: Stabilizes training and improves convergence for the local encoder.
- Feature Expansion: Extends the MLP Edge Encoder with feature expansion and processing, combining processed features and bond embeddings for more adaptable edge representations.
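To illustrate the attention idea in the list above, here is a toy, dependency-free sketch of scaled dot-product attention. A node's query vector is compared with its neighbors' key vectors, and the resulting softmax weights mix the neighbors' value vectors. This is a generic illustration of the technique, not AGDIFF's encoder code. The function names and the plain-list representation are assumptions chosen for readability.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax (subtracts the max score before exponentiating)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention for a single query (illustrative only).

    In a real model, query/keys/values would be learned projections of node and
    edge features; here they are plain lists supplied by the caller.
    """
    d = len(query)
    # Similarity of the query to every key, scaled by sqrt(d)
    scores = [dot(query, k) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output dimension at a time
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
```

With identical keys the weights are uniform, so `attend([1.0, 0.0], [[1.0, 0.0], [1.0, 0.0]], [[2.0, 4.0], [4.0, 8.0]])` reduces to the plain mean `[3.0, 6.0]`. Distinct keys shift the weight toward the neighbors most similar to the query.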
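Create and activate the conda environment, then install PyTorch Geometric and its extension packages (the wheel index URLs below target PyTorch 2.4.0 with CUDA 12.1; change them if your setup differs):
conda env create -f agdiff.yml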
conda activate agdiff
pip install torch_geometric
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
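The wheel index URL used in the torch-scatter, torch-sparse and torch-cluster commands encodes both the PyTorch version and a CUDA tag. The helper below is a hypothetical convenience function, not part of this repository. It builds that URL from version strings so you can match it to the PyTorch build in your environment, following the torch-<version>+<cuda tag>.html pattern visible in the commands. PyG may not publish an index for every patch release, so check data.pyg.org/whl for the indices that actually exist.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a PyG wheel index URL (hypothetical helper).

    torch_version: e.g. "2.4.0" or "2.4.0+cu121" (local suffix is dropped).
    cuda_version:  e.g. "12.1", or None for a CPU-only PyTorch build.
    """
    # Drop any local version suffix such as "+cu121" from torch.__version__
    base = torch_version.split("+")[0]
    # "12.1" -> "cu121"; a CPU-only build uses the "cpu" tag
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"
```

With PyTorch installed, `pyg_wheel_index(torch.__version__, torch.version.cuda)` returns the URL for the current build. `torch.version.cuda` is `None` on CPU-only builds.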
Once all dependencies are installed, install the package locally in editable mode:
pip install -e .
The preprocessed GEOM datasets provided by GEODIFF can be found in this [Google Drive folder]. After downloading and unzipping, place the dataset in the folder path specified by the dataset variable in the configuration files located at ./configs/*.yml. You may also want to use the pretrained model provided at the same link.
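Before training or sampling, it can help to confirm that the dataset paths in the configuration actually exist. The sketch below is hypothetical. It assumes the YAML config has been loaded into a Python dict (for example with PyYAML's yaml.safe_load) and that the dataset variable is a mapping from split names to file paths. Verify the real layout in ./configs/*.yml before relying on it.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def missing_dataset_files(dataset_cfg: Dict[str, object]) -> List[Tuple[str, str]]:
    """Return (split, path) pairs whose configured path does not exist on disk.

    Hypothetical helper: assumes `dataset_cfg` maps split names (e.g. "train",
    "val", "test") to file paths; non-string values are ignored.
    """
    return [
        (split, path)
        for split, path in dataset_cfg.items()
        if isinstance(path, str) and not Path(path).exists()
    ]


# Example usage (requires PyYAML; the "dataset" key layout is an assumption):
#   import yaml
#   with open("./configs/qm9_default.yml") as f:
#       cfg = yaml.safe_load(f)
#   print(missing_dataset_files(cfg["dataset"]))
```

An empty list means every configured path was found.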
The official raw GEOM dataset is also available [here].
AGDIFF's training details and hyper-parameters are provided in the config files (./configs/*.yml). Feel free to tune these parameters as needed.
To train the model on QM9 or GEOM-Drugs, use the corresponding command:
python scripts/train.py ./configs/qm9_default.yml
python scripts/train.py ./configs/drugs_default.yml
Model checkpoints, configuration YAML files, and training logs will be saved in the directory specified by the --logdir argument of train.py.
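The sampling command in the next step takes a checkpoint path of the form ./logs/path/to/checkpoints/${iter}.pt. The hypothetical helper below picks the most recent checkpoint from such a folder. It assumes checkpoints are saved as <iteration>.pt inside a checkpoints subfolder of the log directory, as that path suggests. It is a convenience sketch, not part of the repository.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the <iter>.pt file with the highest iteration number, or None.

    Hypothetical helper: assumes checkpoint files are named by iteration,
    e.g. 1000.pt, 20000.pt; other files in the folder are ignored.
    """
    candidates = [p for p in Path(ckpt_dir).glob("*.pt") if p.stem.isdigit()]
    # Compare numerically: as strings, "3000" would sort after "20000"
    return max(candidates, key=lambda p: int(p.stem), default=None)
```

Sorting by the integer value of the file stem, rather than by name or modification time, keeps the choice tied to training progress even if files are copied or touched later.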
To generate conformations for all or part of the test set, use:
python scripts/test.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml \
--start_idx 0 --end_idx 200
Here, start_idx and end_idx specify the range of test-set molecules to use. To reproduce the paper's results, set start_idx to 0 and end_idx to 200. All sampling-related hyper-parameters can be set in test.py. For the QM9 model specifically, adding --w_global 0.3 empirically gives slightly better results.
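If you prefer to split the test range across several jobs, the sketch below builds test.py command lines using only the flags documented above. The helper names are hypothetical. The sketch assumes end_idx is exclusive (a half-open range), which you should confirm in test.py. Check where test.py writes its outputs before running chunks side by side, because combining their results is not covered here.

```python
from typing import List, Optional, Tuple


def chunk_ranges(start: int, end: int, n_chunks: int) -> List[Tuple[int, int]]:
    """Split [start, end) into n_chunks contiguous, nearly equal index ranges."""
    if n_chunks <= 0:
        raise ValueError("n_chunks must be positive")
    size, extra = divmod(end - start, n_chunks)
    ranges, lo = [], start
    for i in range(n_chunks):
        # The first `extra` chunks absorb the remainder, one index each
        hi = lo + size + (1 if i < extra else 0)
        ranges.append((lo, hi))
        lo = hi
    return ranges


def build_test_cmd(ckpt: str, config: str, start: int, end: int,
                   w_global: Optional[float] = None) -> List[str]:
    """Assemble a scripts/test.py invocation (hypothetical helper)."""
    cmd = ["python", "scripts/test.py", ckpt, config,
           "--start_idx", str(start), "--end_idx", str(end)]
    if w_global is not None:
        # Suggested above for the QM9 model: --w_global 0.3
        cmd += ["--w_global", str(w_global)]
    return cmd
```

For example, `chunk_ranges(0, 200, 4)` gives `[(0, 50), (50, 100), (100, 150), (150, 200)]`, and each pair can be passed to `build_test_cmd` and then to `subprocess.run(cmd, check=True)`.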
We also provide an example of conformation generation for a specific molecule (alanine dipeptide) in the examples folder. To generate conformations for alanine dipeptide, use:
python examples/test_alanine_dipeptide.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml ./examples/alanine_dipeptide.pdb
After generating conformations, evaluate them on the benchmark tasks as follows.
Calculate COV (coverage) and MAT (matching) scores on the GEOM datasets with:
python scripts/evaluation/eval_covmat.py path/to/samples/sample_all.pkl
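COV and MAT are computed from an RMSD matrix between the reference and generated conformers of each molecule. The sketch below computes the recall (-R) and precision (-P) variants for a single molecule from a precomputed matrix. It is not the code in eval_covmat.py and skips the RMSD alignment step itself. The threshold delta is a parameter because its value is dataset-specific, so check eval_covmat.py for the value it uses.

```python
from statistics import mean
from typing import Dict, List


def cov_mat(rmsd: List[List[float]], delta: float) -> Dict[str, float]:
    """COV/MAT for one molecule (illustrative sketch).

    rmsd[i][j] is the RMSD between reference conformer i and generated
    conformer j; delta is the coverage threshold in the same units (Å).
    """
    # Recall side: best-matching generated conformer for each reference
    min_per_ref = [min(row) for row in rmsd]
    # Precision side: best-matching reference for each generated conformer
    min_per_gen = [min(col) for col in zip(*rmsd)]
    return {
        "COV-R": sum(d <= delta for d in min_per_ref) / len(min_per_ref),
        "MAT-R": mean(min_per_ref),
        "COV-P": sum(d <= delta for d in min_per_gen) / len(min_per_gen),
        "MAT-P": mean(min_per_gen),
    }
```

Higher COV and lower MAT are better. Per-molecule values are then typically summarized (for example by mean and median) across the test molecules.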
Our implementation builds on GEODIFF, PyTorch, PyG, and SchNet.
If you use our code or method in your work, please consider citing the following:
@misc{wyzykowskiAGDIFFAttentionEnhancedDiffusion2024,
title = {{{AGDIFF}}: {{Attention-Enhanced Diffusion}} for {{Molecular Geometry Prediction}}},
shorttitle = {{{AGDIFF}}},
author = {Wyzykowski, Andr{\'e} Brasil Vieira and Fathi Niazi, Fatemeh and Dickson, Alex},
year = {2024},
month = oct,
publisher = {ChemRxiv},
doi = {10.26434/chemrxiv-2024-wrvr4},
urldate = {2024-10-09},
archiveprefix = {ChemRxiv},
langid = {english},
keywords = {attention,conformer,diffusion models,generative,GNN,graph neural network,machine learning,structure}
}
Please direct any questions to André Wyzykowski and Alex Dickson.