How do you choose parameters for your DFT calculations?
This package provides machine learning models that predict an optimal k-point density (k-dist) for SCF total energy calculations with plane-wave DFT codes for inorganic 3D materials. All models take as input the structure and/or composition of a compound and output a k-dist value expected to guarantee convergence of the total energy calculation while minimizing computational time.
The package implements multiple machine learning approaches for predicting k-point density:
- Graph Neural Networks (GNNs): CGCNN and ALIGNN models that learn from crystal structures
- Transformer Models: CrabNet for composition-based predictions
- Ensemble Methods: Random Forest, Gradient Boosting, and Histogram Gradient Boosting
The models support both regression and classification tasks, with advanced features for uncertainty quantification including robust regression, quantile regression, and conformal prediction.
Supported models:
- CGCNN - Crystal Graph Convolutional Neural Network (paper)
- ALIGNN - Atomistic Line Graph Neural Network (paper)
- CrabNet - Transformer-based model for composition-based predictions (paper)
- Random Forest - Ensemble method with quantile regression support (scikit-learn, sklearn-quantile)
- Gradient Boosting Trees - XGBoost-style gradient boosting with quantile regression support (scikit-learn implementation)
- Histogram Gradient Boosting - Fast gradient boosting implementation
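For example, quantile-capable boosted models can be instantiated directly in scikit-learn; a sketch with illustrative hyperparameters (X and y stand in for the feature matrix and k-dist targets):

```python
from sklearn.ensemble import GradientBoostingRegressor, HistGradientBoostingRegressor

# Illustrative: one model per quantile bound of a prediction interval
lower = GradientBoostingRegressor(loss="quantile", alpha=0.05)
upper = GradientBoostingRegressor(loss="quantile", alpha=0.95)

# The histogram variant also supports quantile loss (recent scikit-learn, >= 1.3)
median = HistGradientBoostingRegressor(loss="quantile", quantile=0.5)

# lower.fit(X, y); upper.fit(X, y)
# interval = (lower.predict(X_test), upper.predict(X_test))
```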
Atomic features:
- CGCNN features - Standard atomic embeddings from the CGCNN paper
- CGCNN features modified with energy and density cutoffs - In addition to the CGCNN features, the following are added: a one-hot encoding of the energy cutoff, a one-hot encoding of the density cutoff, and the type of pseudopotential. PCA is performed on the features to remove dimensions with no information content.
- mat2vec features - Mat2vec embeddings, developed by Tshitoyan et al. using a skip-gram variant of the Word2vec method trained on 3.3 million scientific abstracts, and originally used in the CrabNet model
- SOAP features - Calculated for structures with all atoms substituted by a single atom type; not used in the final models because they proved ineffective as atomic features
Compound-level features:
- Matminer composition features - Element property, stoichiometry, and valence orbital descriptors
- Matminer structure features - Global symmetry and density descriptors
- JarvisCFID features - JARVIS Crystal Fingerprint features, matminer implementation
- SOAP features - Averaged over all atoms in the structure, calculated with DScribe (see the sketch after this list)
- CGCNN embeddings - Features extracted from a CGCNN model pre-trained on the Materials Project 'is_metal' dataset (Autumn 2025)
- MatSciBert embeddings - Generated from:
- QE SCF input files with k-points section removed, or
- Robocrystallographer structure descriptions
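For the averaged SOAP features above, the DScribe call might look like the following sketch; the species list and hyperparameters are placeholders, and the argument names follow recent DScribe releases:

```python
from ase.io import read
from dscribe.descriptors import SOAP

# Placeholder structure and hyperparameters, purely for illustration
atoms = read("example.cif")
soap = SOAP(
    species=["Si", "O"],    # must cover the elements in the dataset
    r_cut=5.0, n_max=8, l_max=6,
    average="inner",        # average the descriptor over all atoms
)
features = soap.create(atoms)  # one fixed-length vector per structure
```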
Graph construction methods:
- Radius graph - All atoms within a cutoff radius are considered neighbors
- CrystalNN graph - Uses CrystalNN algorithm to identify nearest neighbors based on chemical environment
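For the CrystalNN variant, neighbor lists can be obtained with pymatgen; a minimal sketch (the file name is a placeholder, and this is not necessarily how the package builds its graphs internally):

```python
from pymatgen.analysis.local_env import CrystalNN
from pymatgen.core import Structure

structure = Structure.from_file("example.cif")  # hypothetical file
cnn = CrystalNN()

# Collect edges by querying each site's chemically determined neighbors
edges = []
for i in range(len(structure)):
    for neighbor in cnn.get_nn_info(structure, i):
        edges.append((i, neighbor["site_index"]))
```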
Loss functions for uncertainty quantification:
- RobustL2 Loss - Gaussian distribution-based robust loss
- RobustL1 Loss - Laplace distribution-based robust loss
- StudentT Loss - Student's t-distribution with configurable degrees of freedom
- Quantile Loss - Single quantile prediction
- IntervalScoreLoss - Interval prediction with coverage guarantees
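Generic forms of two of these losses, as a sketch; the package's actual implementations may differ in parameterization:

```python
import torch

def quantile_loss(pred, target, q=0.5):
    # Pinball loss: asymmetric penalty minimized at the q-th quantile
    diff = target - pred
    return torch.mean(torch.maximum(q * diff, (q - 1.0) * diff))

def robust_l2_loss(mean, log_std, target):
    # Gaussian negative log-likelihood (up to a constant): the network
    # predicts both a mean and a log-std, capturing aleatoric uncertainty
    return torch.mean(
        0.5 * torch.exp(-2.0 * log_std) * (target - mean) ** 2 + log_std
    )
```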
For GNN models, atomic features are used as input to the graph neural network, and compound-level features are concatenated to the features produced by the GNN encoder. This hybrid approach enables:
- Transfer learning: Leveraging pre-trained models for feature extraction
- Better structure learning: Addressing limitations of GNNs in learning certain structural features
- Domain knowledge integration: Incorporating metallicity and other important predictors
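A minimal sketch of that concatenation step, with illustrative dimensions (not the repo's actual layer sizes):

```python
import torch
import torch.nn as nn

class HybridReadout(nn.Module):
    """Sketch: pooled GNN embeddings are joined with compound-level
    features before the final regression head. Dimensions illustrative."""

    def __init__(self, gnn_dim=128, compound_dim=32, hidden=64):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(gnn_dim + compound_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),  # predicted k-dist
        )

    def forward(self, gnn_embedding, compound_features):
        h = torch.cat([gnn_embedding, compound_features], dim=-1)
        return self.head(h)
```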
Requirements:
- Python 3.11 or 3.12
- Poetry (for dependency management)
```bash
# Clone the repository
git clone https://github.com/stfc/goldilocks_kpoints.git
cd goldilocks_kpoints

# Create a Python environment
python -m venv .venv
source .venv/bin/activate
```

Install the dependencies of pytorch-geometric as described [here](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html). This is needed because torch_scatter and torch_sparse must be installed from binary wheels using pip and cannot be installed with Poetry.

```bash
# Install dependencies using Poetry
poetry install

# Activate the virtual environment
poetry shell
```

Create a CSV file (id_prop.csv) with two columns:
- Column 0: Sample IDs (corresponding to {id}.cif files)
- Column 1: Target values (k-dist values)
Place your CIF files in the same directory as the CSV file.
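For illustration, such a file can be written with pandas; the IDs and k-dist values below are hypothetical, and a header-less two-column layout (the usual CGCNN convention) is assumed:

```python
import pandas as pd

# Hypothetical IDs and k-dist targets, purely for illustration;
# each ID must have a matching {id}.cif file in the same directory
rows = [("mp-149", 0.15), ("mp-1265", 0.25)]

# Assuming the conventional header-less two-column layout
pd.DataFrame(rows).to_csv("id_prop.csv", header=False, index=False)
```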
Create or modify a configuration file in the configs/ directory. Example configurations:
- configs/cgcnn.yaml - for the CGCNN model
- configs/alignn.yaml - for the ALIGNN model
- configs/crabnet.yaml - for the CrabNet model
- configs/ensembles.yaml - for ensemble models
Train a model:

```bash
python scripts/train.py --config_file configs/cgcnn.yaml
```

Run predictions:

```bash
python scripts/predict.py \
    --config_file configs/cgcnn.yaml \
    --checkpoint_path trained_models/cgcnn/ \
    --output_name output/predictions.csv
```

To perform conformalised quantile regression, first train quantile models using the quantile loss, or QRF (for Random Forest). Then use the notebooks (available for RF and ALIGNN, but easily modified for other models) to calculate conformal corrections to the intervals.
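As a rough illustration of the conformal step those notebooks perform, split-conformal calibration of a quantile interval reduces to one margin computation; the variable names here are ours, not the notebooks':

```python
import numpy as np

def cqr_margin(y_cal, lo_cal, hi_cal, alpha=0.1):
    """Additive margin from split-conformal calibration of a quantile interval.

    y_cal, lo_cal, hi_cal: calibration-set targets and predicted
    lower/upper quantiles. Widening intervals by the returned margin
    targets ~(1 - alpha) coverage.
    """
    # Conformity score: how far each target falls outside [lo, hi]
    scores = np.maximum(lo_cal - y_cal, y_cal - hi_cal)
    n = len(scores)
    # Finite-sample-corrected empirical quantile of the scores
    level = min(np.ceil((n + 1) * (1 - alpha)) / n, 1.0)
    return np.quantile(scores, level, method="higher")

# Usage: q = cqr_margin(y_cal, lo_cal, hi_cal)
# lo_adj, hi_adj = lo_test - q, hi_test + q
```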
Models are typically trained for 300 epochs with:
- Early stopping: Monitors validation loss/metrics
- Stochastic Weight Averaging (SWA): Optional, can be enabled
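In PyTorch Lightning terms, both are typically attached as trainer callbacks; a minimal sketch (the monitored metric, patience, and SWA learning rate are assumptions, not the repo's actual settings):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging

# Illustrative values; the actual settings live in the configs/ files
trainer = Trainer(
    max_epochs=300,
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=20),  # stop when validation stalls
        StochasticWeightAveraging(swa_lrs=1e-3),         # optional SWA
    ],
)
```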
Supported task types:
- Classification: Multi-class classification with class weights
- Regression: Standard mean squared error or mean absolute error
- Robust Regression: Estimates aleatoric uncertainty (predicts mean and std)
- Quantile Regression: Predicts specific quantiles or intervals
The dataset and convergence definition were provided by Junwen Yin.
All data was generated with fixed parameters:
- Code: Quantum Espresso
- Pseudopotentials: SSSP1.3_PBESol_efficiency library
- Energy cutoffs: Recommended values for SSSP1.3_PBESol_efficiency
- Smearing: Cold smearing with degauss=0.01 Ry
- Magnetism: All compounds treated as non-magnetic
A calculation is considered converged if the total energy change across 3 consecutive k-meshes with an increasing number of points is within 1 meV/atom.
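Read literally, this criterion is a sliding-window check over successive meshes; a sketch, with our own function and variable names and "within 1 meV/atom" interpreted as the energy spread across the window:

```python
def is_converged(energies, tol=1e-3):
    """Sketch of the convergence criterion above.

    energies: total energies in eV/atom for successive k-meshes,
    ordered by increasing number of k-points (tol = 1 meV/atom).
    """
    for i in range(len(energies) - 2):
        window = energies[i:i + 3]  # 3 consecutive k-meshes
        if max(window) - min(window) <= tol:
            return True
    return False
```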
```
goldilocks_kpoints/
├── configs/                  # Configuration files for different models
│   ├── cgcnn.yaml
│   ├── alignn.yaml
│   ├── ensembles.yaml
│   └── crabnet.json
├── data/                     # Data directory (CIF files, CSV files)
├── trained_models/           # The place to store trained models
├── outputs/                  # The place to write outputs to
├── embeddings/
│   ├── atom_init_original.json
│   ├── atom_init_with_sssp_cutoffs.json
│   └── mat2vec.json
├── datamodules/              # PyTorch Lightning data modules
│   ├── gnn_datamodule.py
│   ├── crabnet_datamodule.py
│   └── lmdb_dataset.py
├── models/                   # Model implementations
│   ├── cgcnn.py
│   ├── alignn.py
│   ├── crabnet.py
│   ├── ensembles.py
│   └── modelmodule.py
├── utils/                    # Utility functions
│   ├── atom_features_utils.py
│   ├── compound_features_utils.py
│   ├── cgcnn_graph.py
│   ├── crabnet_utils.py
│   ├── alignn_graph.py
│   └── utils.py
├── scripts/                  # Training and prediction scripts
│   ├── train.py
│   └── predict.py
├── notebooks/
│   ├── Data-exploration.ipynb
│   ├── RF-feature-importance.ipynb
│   ├── Surrogate-models.ipynb
│   ├── ALIGNN-CQR.ipynb
│   ├── RF-CQR.ipynb
│   └── Wall-time.ipynb
└── README.md
```
If you use this code in your research, please cite:
```bibtex
@article{goldilocks_kpoints,
  title  = {Automatic generation of input files with optimised k-point meshes for Quantum Espresso self-consistent field single point total energy calculations},
  author = {Elena Patyukova and Junwen Yin and Susmita Basak and Jaehoon Cha and Alin Elena and Gilberto Teobaldi},
  year   = {2025},
  url    = {to be published}
}
```

© 2025 Science and Technology Facilities Council (STFC)
This project is licensed under the Creative Commons Attribution 4.0 International License (CC BY 4.0).