This repository contains the official implementation of the Efficiently Scaled Attention Interatomic Potential (EScAIP), presented at NeurIPS 2024.
Scaling has been a critical factor in improving model performance and generalization across various fields of machine learning. It involves how a model's performance changes with increases in model size or input data, as well as how efficiently computational resources are utilized to support this growth. Despite successes in scaling other types of machine learning models, the study of scaling in Neural Network Interatomic Potentials (NNIPs) remains limited. NNIPs act as surrogate models for ab initio quantum mechanical calculations, predicting the energy and forces between atoms in molecules and materials based on atomic configurations. The dominant paradigm in this field is to incorporate numerous physical domain constraints into the model, such as symmetry constraints like rotational equivariance. We contend that these increasingly complex domain constraints inhibit the scaling ability of NNIPs, and that such strategies are likely to cause model performance to plateau in the long run. In this work, we take an alternative approach and start by systematically studying NNIP scaling properties and strategies. Our findings indicate that scaling the model through attention mechanisms is both efficient and improves model expressivity. These insights motivate us to develop an NNIP architecture designed for scalability: the Efficiently Scaled Attention Interatomic Potential (EScAIP). EScAIP leverages a novel multi-head self-attention formulation within graph neural networks, applying attention to neighbor-level representations. Implemented with highly optimized attention GPU kernels, EScAIP achieves substantial gains in efficiency compared to existing NNIP models: at least a 10x speedup in inference time and 5x less memory usage. EScAIP also achieves state-of-the-art performance on a wide range of datasets, including catalysts (OC20 and OC22), molecules (SPICE), and materials (MPTrj). We emphasize that our approach should be thought of as a philosophy rather than a specific model: it is a proof of concept towards developing general-purpose NNIPs that achieve better expressivity through scaling and continue to scale efficiently with increased computational resources and training data.
Feb 2024 update: We recently discovered that our model relies on a specific older version of the Triton attention kernel for optimal performance, which was recently changed: using newer versions leads to degraded results, especially on the OC20 dataset. We are actively working on tuning the model to be compatible with the latest kernel (see the kernel_update branch) and will provide updates on timelines. If you encounter any issues or have questions about tuning the model on your datasets of interest, please feel free to open an issue or reach out.
Step 1: Install the mamba solver for conda (optional)
conda install mamba -n base -c conda-forge
Step 2: Install the FAIRChem dependencies
wget https://raw.githubusercontent.com/FAIR-Chem/fairchem/main/packages/env.gpu.yml
mamba env create -f env.gpu.yml
conda activate fair-chem
Step 3: Install FairChem core package
git submodule update --init --recursive
pip install -e fairchem/packages/fairchem-core
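A quick sanity check along these lines can confirm that the environment and the editable install are visible (the `fairchem.core` import path follows the fairchem-core package naming; adjust if your install differs):

```python
# Quick sanity check for the fair-chem environment and the editable fairchem-core install.
import torch
import fairchem.core  # installed above via `pip install -e fairchem/packages/fairchem-core`

print("fairchem-core imported from:", fairchem.core.__file__)
print("CUDA available:", torch.cuda.is_available())
```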
Step 4: Install pre-commit hooks (if you want to contribute)
pre-commit install
Use `main.py` to train the models. The usage is similar to FairChem (refer to the FairChem documentation for more details):
Single GPU training:
python main.py --mode train --config-yml {CONFIG} --run-dir {RUNDIR} --timestamp-id {TIMESTAMP} --checkpoint {CHECKPOINT}
Use `start_exp.py` to start a run in the background:
python start_exp.py --config-yml {CONFIG} --cvd {GPU NUM} --run-dir {RUNDIR} --timestamp-id {TIMESTAMP} --checkpoint {CHECKPOINT}
Multi-GPU training (same device):
torchrun --standalone --nproc_per_node={N} main.py --distributed --num-gpus {N} {...}
For distributed training on NERSC, refer to the NERSC distributed training with submitit.
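As a rough illustration of the submitit-based flow (the resource values, the SLURM constraint, and the `train_main` entry point below are placeholders, not what NERSC or this repository actually requires):

```python
# Hedged sketch of submitting a multi-node training job with submitit on a SLURM
# cluster such as NERSC. All resource values and the entry point are placeholders.
import submitit


def train_main():
    # Placeholder: invoke the same training logic that `main.py --mode train` runs.
    ...


executor = submitit.AutoExecutor(folder="submitit_logs")
executor.update_parameters(
    nodes=2,
    tasks_per_node=4,        # e.g. one task per GPU
    gpus_per_node=4,
    cpus_per_task=8,
    timeout_min=12 * 60,
    slurm_constraint="gpu",  # placeholder constraint
)
job = executor.submit(train_main)
print("Submitted job:", job.job_id)
```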
Use `simulate.py` to run simulations with the models. It requires installing the MDSim package:
pip install -e MDsim
A separate config file is needed for the simulation. Some example configs can be found in `configs/s2ef/MD22/datasets/{DATASET}/simulation.yml`.
python simulate.py --simulation_config_yml {SIMULATION CONFIG} --model_dir {CHECKPOINT DIR} --model_config_yml {MODEL CONFIG} --identifier {IDENTIFIER}
For example, to simulate on MD22 DHA split:
python simulate.py --simulation_config_yml configs/s2ef/MD22/datasets/DHA/simulation.yml --model_dir checkpoints/MD22_DHA/ --model_config_yml configs/s2ef/MD22/EScAIP/DHA.yml --identifier test_simulation
After the simulation, the results can be found in the checkpoint directory under the `identifier` folder. Use `scripts/analyze_rollouts_md17_22.py` to analyze the results:
PYTHONPATH=./ python scripts/analyze_rollouts_md17_22.py --md_dir checkpoints/MD22_DHA/md_sim_test_simulation --gt_traj /data/md22/md22_AT-AT.npz --xlim 25
Refer to the MDSim repo for more details.
Refer to the model architecture for more details. A detailed description of model configurations can be found here.
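The key architectural idea, multi-head self-attention applied to neighbor-level representations within a graph neural network, can be pictured with a minimal sketch. This is not the repository's implementation; the module name, shapes, and masking convention below are illustrative assumptions.

```python
# Minimal, illustrative sketch of neighbor-level multi-head self-attention.
# Not the EScAIP implementation; names and shapes are assumptions for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeighborSelfAttention(nn.Module):
    """Attend over each node's (padded) neighbor representations."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x:    (num_nodes, max_neighbors, hidden_dim) neighbor-level features
        # mask: (num_nodes, max_neighbors), True for real neighbors, False for padding
        n, nbr, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (num_nodes, num_heads, max_neighbors, head_dim)
        q = q.view(n, nbr, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(n, nbr, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(n, nbr, self.num_heads, self.head_dim).transpose(1, 2)
        # Broadcast the padding mask over heads and query positions.
        attn_mask = mask[:, None, None, :]
        # Dispatches to fused GPU attention kernels when available.
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        out = out.transpose(1, 2).reshape(n, nbr, d)
        return self.out(out)
```

Because attention runs over a fixed-size, padded neighbor axis for every node, the computation maps directly onto fused attention kernels, which is what enables the efficiency gains described above.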
Some notes on the configs:
- The model uses `torch.compile` for better performance. However, it does not support second-order gradients, so it has to be disabled in order to compute forces as the gradient of the energy.
- The implicit batch needs to be padded to use `torch.compile`. The padded size is controlled by `max_num_nodes_per_batch`.
- The attention kernel is selected by `atten_name`, with the following options (see the sketch below):
  - `math`: the default PyTorch attention; supports all data types and second-order gradients.
  - `memory_efficient`: the memory-efficient PyTorch kernel; supports fp32 and fp16, but no second-order gradients.
  - `flash`: the PyTorch FlashAttention kernel; supports only fp16 and no second-order gradients.
- `use_fp16_backbone` turns the graph attention backbone to fp16 while the output head stays in fp32. For now, we opt to use automatic mixed precision (AMP) instead, as it is more stable.
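How these settings interact can be pictured with a short sketch: the implicit batch is padded to a fixed node count so that `torch.compile` sees static shapes, and the attention backend is pinned explicitly. This is a minimal illustration, not the repository's code; the helper names are made up, while `torch.nn.attention.sdpa_kernel` and `SDPBackend` are the standard mechanism in recent PyTorch versions for selecting the math, memory-efficient, or FlashAttention kernels.

```python
# Illustrative sketch only (not the EScAIP implementation): padding node features
# to a static size for torch.compile and pinning the scaled-dot-product-attention backend.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Mirrors the `atten_name` options described above.
SDPA_BACKENDS = {
    "math": SDPBackend.MATH,                             # all dtypes, second-order gradients
    "memory_efficient": SDPBackend.EFFICIENT_ATTENTION,  # fp32/fp16, no second-order gradients
    "flash": SDPBackend.FLASH_ATTENTION,                 # fp16 only, no second-order gradients
}


def pad_node_features(x: torch.Tensor, max_num_nodes_per_batch: int) -> torch.Tensor:
    """Pad (num_nodes, hidden_dim) features to a fixed node count for static shapes."""
    pad = max_num_nodes_per_batch - x.shape[0]
    return F.pad(x, (0, 0, 0, pad)) if pad > 0 else x


def attend(q, k, v, atten_name: str = "memory_efficient"):
    """Run attention with the requested kernel; q/k/v are (batch, heads, seq, head_dim)."""
    with sdpa_kernel(SDPA_BACKENDS[atten_name]):
        return F.scaled_dot_product_attention(q, k, v)
```

With `atten_name: math`, second-order gradients remain available for gradient-based forces, at the cost of speed and memory.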
We will release the pretrained models from the paper soon (after the model transition is finalized).
If you find this work useful, please consider citing the following:
@inproceedings{
qu2024the,
title={The Importance of Being Scalable: Improving the Speed and Accuracy of Neural Network Interatomic Potentials Across Chemical Domains},
author={Eric Qu and Aditi S. Krishnapriyan},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=Y4mBaZu4vy}
}