diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3c0af59 --- /dev/null +++ b/.gitignore @@ -0,0 +1,210 @@ +# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python + +# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) +results +resources \ No newline at end of file diff --git a/README.md b/README.md index 244a3e3..176880d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,150 @@ -# TPP_MOL_DPO +# Target Product Profile-Guided Drug Design Using Multi-objective Direct Preference Optimization -Data and code will be available soon. + +Sejeong Park\*, Jungwoo Park\*, Donghyeon Lee, Sunkyu Kim, Jaewoo Kang (\* indicates equal contribution) + +
+![Figure 1](./figures/figure-1.png)
+
+Abstract: *Fragment-Based Drug Design (FBDD) offers unique advantages in exploring chemical space and optimizing lead compounds. However, existing FBDD approaches often struggle to efficiently meet complex Target Product Profile (TPP) requirements due to the significant time and financial investments typically required in traditional workflows. This study introduces a novel TPP-guided fragment-based generative model that integrates Direct Preference Optimization (DPO) with sequence-based fragment generation. The model transforms multi-objective optimization into a preference learning task, simultaneously optimizing multiple molecular properties from the earliest design stages. Key innovations include In-Batch DPO for computational efficiency and a multi-objective learning strategy balancing diverse molecular properties. Case studies across various therapeutic targets demonstrate significant improvements in generating diverse drug candidates optimized for binding affinity, synthetic accessibility, drug-likeness, and ADMET properties, potentially accelerating the discovery of novel therapeutics for challenging targets.*
+
+This repository contains the code to reproduce the experiments in the paper.
+
+## Prerequisites
+
+All experiments in this work were conducted on [TPU-v3-8](https://cloud.google.com/tpu/docs/v3). For research purposes, you can apply to [the TRC program](https://sites.research.google/trc/about/) to receive free TPU quota. To create a TPU VM instance, run the command below:
+```bash
+$ gcloud compute tpus tpu-vm create tpu-name \
+    --zone=europe-west4-a \
+    --accelerator-type=v3-8 \
+    --version=tpu-vm-base
+```
+Now you can access the TPU VM through SSH:
+```bash
+$ gcloud compute tpus tpu-vm ssh tpu-name --zone=europe-west4-a
+```
+
+## Requirements
+After preparing the TPU instance, install Miniconda:
+```bash
+$ wget https://repo.anaconda.com/miniconda/Miniconda3-py310_24.1.2-0-Linux-x86_64.sh
+$ bash Miniconda3-py310_24.1.2-0-Linux-x86_64.sh -b -u
+```
+Then install the requirements via pip:
+```bash
+$ pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+$ pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+$ pip install -U flax optax chex webdataset wandb fsspec gcsfs transformers sentencepiece tiktoken omegaconf safe-mol pandas==2.0.0 admet-ai
+```
+
+## Getting Started
+
+### Prepare a Preference Dataset
+
+To support various molecule generation methods, such as de novo generation and scaffold decoration, we use the [SAFE-GPT](https://huggingface.co/datamol-io/safe-gpt) model with the [SAFE](https://arxiv.org/abs/2310.10773) representation.
+To apply preference optimization to the model, we first generate molecules and then evaluate their molecular properties.
+Unlike conventional preference optimization methods, we only need the molecular properties instead of direct comparisons or other preference signals.
+
+We found that de novo generated molecules suffice to optimize the model's entire chemical space, thereby also improving molecules for scaffold decoration and scaffold morphing.
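+
+Since every training example is stored as a SAFE string, it helps to see how SAFE relates to SMILES. Below is a minimal round-trip sketch with the `safe` package (API per the SAFE documentation; the exact fragment ordering of the encoded string may vary across versions):
+```python
+import safe as sf
+
+# A SAFE string is a fragment-reordered variant of a SMILES string, so it can
+# be decoded back to a standard molecule without any extra vocabulary.
+smiles = "c1ccc(OCC(=O)O)cc1"  # illustrative molecule (phenoxyacetic acid)
+safe_str = sf.encode(smiles)
+decoded = sf.decode(safe_str, canonical=True, ignore_errors=True)
+print(safe_str, "->", decoded)
+```
+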
+You can follow [this tutorial](https://safe-docs.datamol.io/stable/tutorials/design-with-safe.html#de-novo-generation) to generate molecules yourself or use the command below:
+```bash
+$ python scripts/generate_de_novo_samples.py \
+    --num-samples 100000 \
+    --batch-size 256 \
+    --max-length 128 \
+    --output safe-de-novo-dataset.csv
+```
+This script automatically evaluates the necessary properties, including the SA score, QED, and ADMET scores from [ADMET-AI](https://github.com/swansonk14/admet_ai). We recommend running this script in a GPU environment for faster generation.
+
+If you want to optimize the docking score as well, run [AutoDock-GPU](https://github.com/ccsb-scripps/AutoDock-GPU) to calculate the binding affinity of each molecule and then combine it with the training dataset.
+Refer to [scripts/convert_smiles_to_pdbqt.py](scripts/convert_smiles_to_pdbqt.py) for SMILES-to-PDBQT conversion, and then run AutoDock with your target protein.
+After running the AutoDock program, you will find a `.dlg` output file containing:
+```
+AutoDock-GPU version: v1.5.3-73-gf5cf6ffdd0c5b3f113d5cc424fabee51df04da7e
+
+**********************************************************
+**   AutoDock-GPU AUTODOCKTOOLS-COMPATIBLE DLG FILE    **
+**********************************************************
+
+[...]
+
+    RMSD TABLE
+    __________
+
+
+_______________________________________________________________________
+     |      |      |           |         |                 |
+Rank | Sub- | Run  | Binding   | Cluster | Reference       | Grep
+     | Rank |      | Energy    | RMSD    | RMSD            | Pattern
+_____|______|______|___________|_________|_________________|___________
+   1      1     19      -6.63      0.00        13.74          RANKING
+   1      2      6      -6.63      0.11        13.79          RANKING
+   1      3     12      -6.62      0.09        13.76          RANKING
+   1      4     14      -6.61      0.23        13.79          RANKING
+   1      5     10      -6.60      0.19        13.78          RANKING
+   1      6      5      -6.44      0.57        13.69          RANKING
+   1      7      2      -6.44      0.60        13.73          RANKING
+   1      8     11      -6.43      0.58        13.75          RANKING
+   1      9     20      -6.42      0.56        13.78          RANKING
+   1     10     15      -6.36      0.58        13.72          RANKING
+   1     11     13      -6.22      0.62        13.82          RANKING
+   1     12     18      -5.85      0.79        13.57          RANKING
+   2      1      9      -6.18      0.00        11.90          RANKING
+   2      2     16      -5.90      0.93        11.71          RANKING
+   2      3      8      -5.90      0.94        11.72          RANKING
+   2      4     17      -5.85      0.97        11.67          RANKING
+   3      1      4      -5.73      0.00        11.13          RANKING
+   3      2      1      -5.60      1.52        11.86          RANKING
+   3      3      3      -5.54      1.54        11.67          RANKING
+   4      1      7      -5.53      0.00        11.92          RANKING
+
+Run time 0.225 sec
+Idle time 0.161 sec
+```
+Take the binding energy of the top-ranked pose (`-6.63` in the example above) as the docking score and merge it into the training dataset under a column name such as `DS_7O2I`; a pandas sketch of this merge appears after the preset list below.
+
+### Optimize Molecular Properties
+Now we can optimize the SAFE-GPT model with our preference dataset.
+There are three configuration presets:
+- [config/safe-dpo-simple-20ep.sh](config/safe-dpo-simple-20ep.sh): Simple averaging for multi-objective preference optimization.
+- [config/safe-dpo-moco-20ep.sh](config/safe-dpo-moco-20ep.sh): Balanced multi-objective preference optimization.
+- [config/safe-dpo-moco-20ep-pref.sh](config/safe-dpo-moco-20ep-pref.sh): Balanced multi-objective preference optimization with user preferences.
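+
+Before launching a run, make sure any docking scores have been merged into the dataset as described in the previous section. A minimal pandas sketch of that merge (the energies and the `DS_7O2I` column name are illustrative; parsing the `.dlg` files into per-ligand energies is left to your pipeline):
+```python
+import pandas as pd
+
+# Hypothetical rank-1 binding energies, keyed by the row index used when
+# exporting ligands in scripts/convert_smiles_to_pdbqt.py.
+docking = pd.Series({0: -6.63, 1: -5.91}, name="DS_7O2I")
+
+dataset = pd.read_csv("safe-de-novo-dataset.csv")
+dataset = dataset.join(docking)  # align the docking scores on the row index
+dataset.to_csv("safe-dpo-full-dataset-94k.csv", index=False)
+```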
+
+Using these presets, you can run an experiment with various scenarios and property combinations:
+```bash
+### Simple Averaging ###
+bash config/safe-dpo-simple-20ep.sh safe-dpo-simple-20ep-8P1Q_SAScore_QED ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1"
+bash config/safe-dpo-simple-20ep.sh safe-dpo-simple-20ep-8P1Q_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-simple-20ep.sh safe-dpo-simple-20ep-8P1Q_SAScore_QED_hERG ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0"
+bash config/safe-dpo-simple-20ep.sh safe-dpo-simple-20ep-8P1Q_SAScore_QED_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-simple-20ep.sh safe-dpo-simple-20ep-8P1Q_SAScore_QED_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+
+### Balanced ###
+bash config/safe-dpo-moco-20ep.sh safe-dpo-moco-nopref-20ep-8P1Q_SAScore_QED ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1"
+bash config/safe-dpo-moco-20ep.sh safe-dpo-moco-nopref-20ep-8P1Q_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-moco-20ep.sh safe-dpo-moco-nopref-20ep-8P1Q_SAScore_QED_hERG ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0"
+bash config/safe-dpo-moco-20ep.sh safe-dpo-moco-nopref-20ep-8P1Q_SAScore_QED_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-moco-20ep.sh safe-dpo-moco-nopref-20ep-8P1Q_SAScore_QED_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+
+### Balanced w/ Preferences ###
+bash config/safe-dpo-moco-20ep-pref.sh safe-dpo-moco-20ep-8P1Q_SAScore_QED ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1"
+bash config/safe-dpo-moco-20ep-pref.sh safe-dpo-moco-20ep-8P1Q_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-moco-20ep-pref.sh safe-dpo-moco-20ep-8P1Q_SAScore_QED_hERG ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0"
+bash config/safe-dpo-moco-20ep-pref.sh safe-dpo-moco-20ep-8P1Q_SAScore_QED_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+bash config/safe-dpo-moco-20ep-pref.sh safe-dpo-moco-20ep-8P1Q_SAScore_QED_hERG_CYPs ./safe-dpo-full-dataset-94k.csv "DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0 CYP1A2_Veith:min:0 CYP2C9_Veith:min:0 CYP2C19_Veith:min:0 CYP2D6_Veith:min:0 CYP3A4_Veith:min:0"
+```
+Each script requires three arguments:
+- The first argument is the experiment name.
+- The second argument is the path to the training dataset constructed in the section above.
+- The third argument specifies the target molecular properties in the format `[column]:[min/max]:[pref]`, where `[pref]` is the relative preference strength of that property.
+
+For instance, `DS_8P1Q:min:5 SAScore:min:1 QED:max:1 hERG:min:0` minimizes the docking score against protein 8P1Q, the SA score, and hERG channel blocking, while maximizing the QED drug-likeness score.
+With the preset [config/safe-dpo-moco-20ep-pref.sh](config/safe-dpo-moco-20ep-pref.sh), the balanced loss weights will focus more on the docking score since its preference strength is 5.
+
+## Citation
+```bibtex
+@misc{park2024tppmoldpo,
+    title={Target Product Profile-Guided Drug Design Using Multi-objective Direct Preference Optimization},
+    author={Sejeong Park and Jungwoo Park and Donghyeon Lee and Sunkyu Kim and Jaewoo Kang},
+    year={2024},
+}
+```
diff --git a/config/safe-dpo-moco-20ep-pref.sh b/config/safe-dpo-moco-20ep-pref.sh
new file mode 100644
index 0000000..a543d21
--- /dev/null
+++ b/config/safe-dpo-moco-20ep-pref.sh
@@ -0,0 +1,30 @@
+python3 src/main.py \
+    --dataset $2 \
+    --batch-size 128 \
+    --num-workers 32 \
+    --target-columns $3 \
+    --max-length 128 \
+    --penalty-beta 0.1 \
+    --eval-metrics SAScore QED logP plogP Validity Uniqueness IntDiv hERG CYP1A2_Veith CYP2C9_Veith CYP2C19_Veith CYP2D6_Veith CYP3A4_Veith \
+    --use-moco \
+    --jacmom 0.99 \
+    --lammom 0.5 \
+    --lamreg 0.5 \
+    --learning-rate 5e-6 \
+    --weight-decay 0.01 \
+    --adam-b1 0.9 \
+    --adam-b2 0.999 \
+    --adam-eps 1e-8 \
+    --clip-grad 1.0 \
+    --warmup-ratio 0.1 \
+    --epochs 20 \
+    --log-interval 50 \
+    --eval-interval 2 \
+    --eval-batches 4 \
+    --split-seed 0 \
+    --shuffle-seed 0 \
+    --project chemgpt-pref-opt \
+    --name $1 \
+    --ipaddr $(curl -s ifconfig.me) \
+    --hostname $(hostname) \
+    --output-dir ./results/
diff --git a/config/safe-dpo-moco-20ep.sh b/config/safe-dpo-moco-20ep.sh
new file mode 100644
index 0000000..6cc9d0d
--- /dev/null
+++ b/config/safe-dpo-moco-20ep.sh
@@ -0,0 +1,30 @@
+python3 src/main.py \
+    --dataset $2 \
+    --batch-size 128 \
+    --num-workers 32 \
+    --target-columns $3 \
+    --max-length 128 \
+    --penalty-beta 0.1 \
+    --eval-metrics SAScore QED logP plogP Validity Uniqueness IntDiv hERG CYP1A2_Veith CYP2C9_Veith CYP2C19_Veith CYP2D6_Veith CYP3A4_Veith \
+    --use-moco \
+    --jacmom 0.99 \
+    --lammom 0.5 \
+    --lamreg 0.0 \
+    --learning-rate 5e-6 \
+    --weight-decay 0.01 \
+    --adam-b1 0.9 \
+    --adam-b2 0.999 \
+    --adam-eps 1e-8 \
+    --clip-grad 1.0 \
+    --warmup-ratio 0.1 \
+    --epochs 20 \
+    --log-interval 50 \
+    --eval-interval 2 \
+    --eval-batches 4 \
+    --split-seed 0 \
+    --shuffle-seed 0 \
+    --project chemgpt-pref-opt \
+    --name $1 \
+    --ipaddr $(curl -s ifconfig.me) \
+    --hostname $(hostname) \
+    --output-dir ./results/
diff --git a/config/safe-dpo-simple-20ep.sh b/config/safe-dpo-simple-20ep.sh
new file mode 100644
index 0000000..ee3af2b
--- /dev/null
+++ b/config/safe-dpo-simple-20ep.sh
@@ -0,0 +1,29 @@
+python3 src/main.py \
+    --dataset $2 \
+    --batch-size 128 \
+    --num-workers 32 \
+    --target-columns $3 \
+    --max-length 128 \
+    --penalty-beta 0.1 \
+    --eval-metrics SAScore QED logP plogP Validity Uniqueness IntDiv hERG CYP1A2_Veith CYP2C9_Veith CYP2C19_Veith CYP2D6_Veith CYP3A4_Veith \
+    --jacmom 0.99 \
+    --lammom 0.5 \
+    --lamreg 0.0 \
+    --learning-rate 5e-6 \
+    --weight-decay 0.01 \
+    --adam-b1 0.9 \
+    --adam-b2 0.999 \
+    --adam-eps 1e-8 \
+    --clip-grad 1.0 \
+    --warmup-ratio 0.1 \
+    --epochs 20 \
+    --log-interval 50 \
+    --eval-interval 2 \
+    --eval-batches 4 \
+    --split-seed 0 \
+    --shuffle-seed 0 \
+    --project chemgpt-pref-opt \
+    --name $1 \
+    --ipaddr $(curl -s ifconfig.me) \
+    --hostname $(hostname) \
+    --output-dir ./results/
diff --git a/figures/figure-1.png b/figures/figure-1.png
new file mode 100644
index 0000000..ed4c638
Binary files /dev/null and b/figures/figure-1.png differ
diff --git a/scripts/convert_smiles_to_pdbqt.py b/scripts/convert_smiles_to_pdbqt.py
new file mode 100644
index 0000000..d87dd08
--- /dev/null
+++ b/scripts/convert_smiles_to_pdbqt.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import argparse
+import multiprocessing as mp
+import os
+import shutil
+from functools import partial
+
+import pandas as pd
+import tqdm
+from rdkit import Chem
+from rdkit.Chem import rdDistGeom, rdForceFieldHelpers
+
+
+def convert_smiles_to_pdbqt(args, output_dir: str = "./"):
+    try:
+        with Chem.SDWriter(f"{output_dir}/ligand_{args[0]}.sdf") as writer:
+            mol = Chem.MolFromSmiles(args[1])
+            if mol is None:
+                raise ValueError(f"invalid SMILES: {args[1]}")
+
+            # Embed a 3D conformer with ETKDGv3 and relax it with UFF.
+            mol = Chem.AddHs(mol)
+            etkdgv3 = rdDistGeom.ETKDGv3()
+            rdDistGeom.EmbedMolecule(mol, etkdgv3)
+
+            try:
+                rdForceFieldHelpers.UFFOptimizeMolecule(mol)
+            except Exception:
+                print(f"{args[1]} UFF optimization failed")
+
+            writer.write(mol)
+
+        # Convert the SDF file to PDBQT with Meeko's mk_prepare_ligand.py.
+        os.system(
+            f"mk_prepare_ligand.py"
+            f" -i {output_dir}/ligand_{args[0]}.sdf"
+            f" -o {output_dir}/ligand_{args[0]}.pdbqt"
+        )
+    except Exception:
+        print(f"ligand_prep failed: {args[1]}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", default="smiles.csv")
+    parser.add_argument("--output-dir", default="ligands")
+    parser.add_argument("--use-index", action="store_true", default=False)
+    parser.add_argument("--subset", default="0:1")
+    args = parser.parse_args()
+
+    shutil.rmtree(args.output_dir, ignore_errors=True)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Take every `subset_cnt`-th row starting at `subset_idx` so that multiple
+    # processes can convert disjoint shards of the dataset in parallel.
+    subset_idx, subset_cnt = map(int, args.subset.split(":"))
+
+    dataset = pd.read_csv(args.dataset, index_col=0 if args.use_index else None)
+    dataset = dataset.iloc[subset_idx::subset_cnt]
+    with mp.Pool() as pool:
+        convert_fn = partial(convert_smiles_to_pdbqt, output_dir=args.output_dir)
+        it = pool.imap_unordered(convert_fn, zip(dataset.index, dataset["smiles"]))
+        list(tqdm.tqdm(it, total=len(dataset)))
diff --git a/scripts/generate_de_novo_samples.py b/scripts/generate_de_novo_samples.py
new file mode 100644
index 0000000..2663d5e
--- /dev/null
+++ b/scripts/generate_de_novo_samples.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import argparse
+import multiprocessing as mp
+
+import pandas as pd
+import safe as sf
+import torch
+import tqdm
+from admet_ai import ADMETModel
+from rdkit import Chem, rdBase
+from rdkit.Contrib.SA_Score import sascorer  # type: ignore
+from transformers import GPT2LMHeadModel
+
+rdBase.DisableLog("rdApp.*")
+
+
+class Evaluator:
+    def __init__(self):
+        self.admet_ai = ADMETModel()
+
+    def __call__(self, smiles_list: list[str]) -> pd.DataFrame:
+        mols = [Chem.MolFromSmiles(x) for x in smiles_list]
+        sa_scores = [sascorer.calculateScore(m) for m in mols]
+        max_ring = [max(map(len, m.GetRingInfo().AtomRings() or [[]])) for m in mols]
+
+        admet = self.admet_ai.predict(smiles_list)
+        admet["SAScore"] = sa_scores
+        admet["CycleScore"] = [max(x - 6, 0) for x in max_ring]
+        admet["plogP"] = admet["logP"] - admet["SAScore"] - admet["CycleScore"]
+        return admet
+
+
+def decode_smiles_from_valid_safe(safe: str) -> str | None:
+    smiles = sf.decode(safe, canonical=False, ignore_errors=True)
+    if smiles and Chem.MolFromSmiles(smiles) is not None:
+        return smiles
+    return None
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-samples", type=int, default=100000)
+    parser.add_argument("--batch-size", type=int, default=256)
+    parser.add_argument("--max-length", type=int, default=128)
+    parser.add_argument("--output", default="safe-de-novo-dataset.csv")
+    args = parser.parse_args()
+
+    model = GPT2LMHeadModel.from_pretrained("datamol-io/safe-gpt")
+    model = model.bfloat16().cuda().eval().requires_grad_(False)
+    tokenizer = sf.SAFETokenizer.from_pretrained("datamol-io/safe-gpt").get_pretrained()
+
+    inputs = torch.tensor([[tokenizer.bos_token_id]] * args.batch_size, device="cuda")
+    generated, tqdm_bar = [], tqdm.trange(args.num_samples, desc="Generation")
+
+    with mp.Pool() as pool:
+        while len(generated) < args.num_samples:
+            outputs = model.generate(
+                inputs,
+                do_sample=True,
+                temperature=1.0,
+                max_length=args.max_length,
+            )
+            outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            # Decode each sampled SAFE string and keep only the valid molecules.
+            smiles_list = pool.map(decode_smiles_from_valid_safe, outputs)
+
+            for safe_str, smiles in zip(outputs, smiles_list):
+                if smiles:
+                    generated.append({"safe": safe_str, "smiles": smiles})
+                    tqdm_bar.update()
+
+    generated = pd.DataFrame(generated)
+    admet = Evaluator()(generated["smiles"])
+    admet = pd.merge(generated, admet, left_on="smiles", right_index=True, how="outer")
+    admet.to_csv(args.output, index=False)
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..13ab3b6
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,17 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203, W503
+
+[isort]
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+use_parentheses = True
+ensure_newline_before_comments = True
+line_length = 88
+float_to_top = True
+src_paths = */src, experimental/*/src
+
+[tool:pytest]
+testpaths =
+    tests
\ No newline at end of file
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000..b315f42
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import argparse
+import os
+from dataclasses import dataclass
+
+import fsspec
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader, Dataset
+from transformers import PreTrainedTokenizerBase
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@dataclass
+class MyDataset(Dataset):
+    texts: list[str]
+    labels: np.ndarray
+    tokenizer: PreTrainedTokenizerBase
+    max_length: int = 128
+
+    def __len__(self) -> int:
+        return len(self.texts)
+
+    def __getitem__(self, i: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        encoding = self.tokenizer(
+            self.texts[i],
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+        )
+        return (
+            torch.tensor(encoding["input_ids"], dtype=torch.int32),
+            torch.tensor(encoding["attention_mask"], dtype=torch.int32),
+            torch.tensor(self.labels[i], dtype=torch.float32),
+        )
+
+
+def create_train_dataloader(
+    args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase
+) -> DataLoader:
+    with fsspec.open(args.dataset) as fp:
+        data = pd.read_csv(fp)
+
+    # Flip the sign of `min` objectives so that larger labels are always better.
+    labels = []
+    for target in args.target_columns:
+        name, direction = target.split(":")[:2]
+        labels.append(data[name] * (1 if direction == "max" else -1))
+
+    dataset = MyDataset(
+        texts=data["safe"].tolist(),
+        labels=np.stack(labels, axis=1),
+        tokenizer=tokenizer,
+        max_length=args.max_length,
+    )
+    return DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        drop_last=True,
+        generator=torch.Generator().manual_seed(args.shuffle_seed),
+        persistent_workers=True,
+    )
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..0e8fd05
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+import argparse
+import os
+import warnings
+
+import flax.jax_utils
+import jax
+import numpy as np
+import pandas as pd
+import safe as sf
+import tqdm
+import wandb
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard, shard_prng_key
+from flax.training.train_state import TrainState
+from rdkit import Chem
+from transformers import FlaxGPT2LMHeadModel, PreTrainedTokenizerBase
+from utils import AverageMeter, Evaluator
+
+from dataset import create_train_dataloader
+from training import create_train_state, generation_step, training_step
+
+warnings.filterwarnings("ignore")
+
+
+def evaluate(
+    args: argparse.Namespace,
+    step: int,
+    epoch: int,
+    state: TrainState,
+    evaluator: Evaluator,
+    tokenizer: PreTrainedTokenizerBase,
+    eval_batches: int,
+):
+    tokens = np.array([[tokenizer.bos_token_id]] * args.batch_size, dtype=np.int32)
+    tokens = shard(tokens)
+
+    generated, average_meter = [], AverageMeter()
+    for i in tqdm.trange(eval_batches, desc="Generation", dynamic_ncols=True):
+        rng = shard_prng_key(epoch + jax.random.PRNGKey(i))
+        preds, metrics = generation_step(state, tokens, rng, args.max_length)
+        average_meter.update(**unreplicate(metrics))
+
+        for pred in jax.device_get(preds).reshape(-1, *preds.shape[2:]):
+            pred = pred[np.cumsum(pred == tokenizer.eos_token_id) < 1]
+            generated.append(tokenizer.decode(pred.tolist(), skip_special_tokens=True))
+
+    # Evaluate the generated molecules, taking the representation format into
+    # account. After computing validity and uniqueness of the generated samples,
+    # the invalid and duplicated samples are removed for ADMET-AI evaluation.
+    smiles = [sf.decode(x, canonical=False, ignore_errors=True) for x in generated]
+    mols = pd.DataFrame((generated, smiles)).T
+    mols.columns = ["safe", "smiles"]
+
+    mols["Validity"] = mols["smiles"].notnull()
+    mols["Uniqueness"] = mols["smiles"].nunique() / mols["smiles"].notnull().sum()
+    mols = mols[mols.smiles.notnull()].drop_duplicates("smiles")
+
+    admet = evaluator([x for x in smiles if x and Chem.MolFromSmiles(x)])
+    admet = pd.merge(mols, admet, left_on="smiles", right_index=True, how="outer")
+
+    # Compute the average metrics of the evaluated results and log them with the
+    # validation scores.
+    metrics = average_meter.summary(prefix="valid/")
+    for name in args.eval_metrics:
+        metrics[f"valid/{name}"] = admet[name].mean()
+    metrics["epoch"] = epoch
+    wandb.log(metrics, step)
+
+
+def main(args: argparse.Namespace):
+    evaluator = Evaluator()
+    model = FlaxGPT2LMHeadModel.from_pretrained("datamol-io/safe-gpt", from_pt=True)
+    tokenizer = sf.SAFETokenizer.from_pretrained("datamol-io/safe-gpt").get_pretrained()
+    dataloader = create_train_dataloader(args, tokenizer)
+
+    state = create_train_state(args, model, steps_per_epoch=len(dataloader))
+    state = flax.jax_utils.replicate(state)
+
+    wandb.init(name=args.name, project=args.project, config=args)
+    average_meter, step = AverageMeter(use_latest=["learning_rate"]), 0
+
+    # Before training, we will evaluate the initial performance of the model.
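+    # This initial evaluation provides a pre-training baseline for the validity,
+    # uniqueness, and ADMET metrics that are logged to wandb.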
+ evaluate( + args=args, + step=0, + epoch=0, + state=state, + evaluator=evaluator, + tokenizer=tokenizer, + eval_batches=args.eval_batches, + ) + + for epoch in range(args.epochs): + for batch in tqdm.tqdm(dataloader, desc=f"Epoch {epoch}", dynamic_ncols=True): + state, metrics = training_step(state, shard(jax.tree.map(np.array, batch))) + average_meter.update(**unreplicate(metrics)) + step += 1 + + if args.log_interval > 0 and step % args.log_interval == 0: + metrics = average_meter.summary(prefix="train/") + metrics["epoch"] = step / len(dataloader) + wandb.log(metrics, step) + + if ( + args.eval_interval > 0 + and (epoch + 1) % args.eval_interval == 0 + or epoch == args.epochs - 1 + ): + evaluate( + args=args, + step=step, + epoch=epoch + 1, + state=state, + evaluator=evaluator, + tokenizer=tokenizer, + eval_batches=args.eval_batches, + ) + + os.makedirs(args.output_dir, exist_ok=True) + model.params = unreplicate(state.params["act"]) + model.save_pretrained(os.path.join(args.output_dir, args.name)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset") + parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--num-workers", type=int, default=16) + + parser.add_argument("--target-columns", nargs="+") + parser.add_argument("--max-length", type=int, default=128) + parser.add_argument("--penalty-beta", type=float, default=0.1) + parser.add_argument("--eval-metrics", nargs="+", default=["Validity"]) + + parser.add_argument("--use-moco", action="store_true", default=False) + parser.add_argument("--jacmom", type=float, default=0.99) + parser.add_argument("--lammom", type=float, default=0.5) + parser.add_argument("--lamreg", type=float, default=0.1) + + parser.add_argument("--learning-rate", type=float, default=5e-5) + parser.add_argument("--weight-decay", type=float, default=0.01) + parser.add_argument("--adam-b1", type=float, default=0.9) + parser.add_argument("--adam-b2", type=float, default=0.999) + parser.add_argument("--adam-eps", type=float, default=1e-8) + parser.add_argument("--clip-grad", type=float, default=0.0) + parser.add_argument("--warmup-ratio", type=float, default=0.1) + + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--log-interval", type=int, default=50) + parser.add_argument("--eval-interval", type=int, default=1) + parser.add_argument("--eval-batches", type=int, default=4) + parser.add_argument("--split-seed", type=int, default=0) + parser.add_argument("--shuffle-seed", type=int, default=0) + + parser.add_argument("--project") + parser.add_argument("--name") + parser.add_argument("--ipaddr") + parser.add_argument("--hostname") + parser.add_argument("--output-dir", default="./") + main(parser.parse_args()) diff --git a/src/training.py b/src/training.py new file mode 100644 index 0000000..ef062f9 --- /dev/null +++ b/src/training.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import argparse +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +from chex import Array, ArrayTree, PRNGKey +from flax.training import train_state +from transformers import FlaxPreTrainedModel +from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModule + +from utils import FlaxGPT2LMHeadModelWrapper + + +class TrainState(train_state.TrainState): + jacbuf: Array + jacmom: Array + lambuf: Array | None + lammom: Array + lamreg: Array + lampref: Array + + +class TrainModule(nn.Module): + 
act: FlaxGPT2LMHeadModule  # trainable policy model
+    ref: FlaxGPT2LMHeadModule  # frozen reference model for the DPO penalty
+    penalty_beta: float = 0.1
+
+    def _compute_logprobs(self, model: nn.Module, tokens: Array, mask: Array) -> Array:
+        logits = model(tokens, mask, jnp.cumsum(mask, axis=-1) - 1).logits
+        logprobs = nn.log_softmax(logits[:, :-1, :].astype(jnp.float32))
+        return jnp.take_along_axis(logprobs, tokens[:, 1:, None], axis=-1)[..., 0]
+
+    def __call__(self, tokens: Array, mask: Array, labels: Array) -> ArrayTree:
+        logp_act = self._compute_logprobs(self.act, tokens, mask)
+        logp_ref = self._compute_logprobs(self.ref, tokens, mask)
+        logits = (mask[:, 1:] * (logp_act - logp_ref)).sum(-1)
+
+        # Gather the log-probability differences between the policy and reference
+        # models, together with the ground-truth reward of each sequence, across
+        # all devices.
+        logits, labels = jax.lax.all_gather((logits, labels), "batch", tiled=True)
+        logits = logits[:, None] - logits[None, :]
+        sign = jnp.sign(labels[:, None] - labels[None, :])
+
+        loss = -nn.log_sigmoid(sign * self.penalty_beta * logits[:, :, None])
+        loss = loss.mean((0, 1))
+
+        accuracy = jnp.sign(logits[:, :, None]) == sign
+        accuracy = (accuracy * jnp.abs(sign)).sum() / jnp.abs(sign).sum()
+        return {"loss": loss, "accuracy": accuracy}
+
+    def generate(
+        self, tokens: Array, sample_rng: PRNGKey, max_length: int
+    ) -> tuple[Array, ArrayTree]:
+        outputs = FlaxGPT2LMHeadModelWrapper(self.act.config, self.act).generate(
+            tokens,
+            prng_key=sample_rng,
+            params=self.act.variables["params"],
+            do_sample=True,
+            temperature=1.0,
+            max_length=max_length,
+        )
+        outputs = outputs.sequences
+        mask = jnp.cumsum(outputs == self.act.config.eos_token_id, axis=-1) < 1
+
+        logp_act = self._compute_logprobs(self.act, outputs, mask)
+        logp_ref = self._compute_logprobs(self.ref, outputs, mask)
+        logp_diff = mask[:, 1:] * (logp_act - logp_ref)
+        # Skip the prompt positions when estimating the KL divergence.
+        return outputs, {"kld": logp_diff[:, tokens.shape[1] - 1 :].sum(-1).mean()}
+
+
+def get_gradient_slice(grads: ArrayTree, is_jacobian: bool = False) -> Array:
+    last_layer_idx = max(map(int, grads["act"]["transformer"]["h"]))
+    last_layer_grads = grads["act"]["transformer"]["h"][str(last_layer_idx)]
+
+    arrays = [
+        grads["act"]["transformer"]["ln_f"]["scale"],
+        last_layer_grads["ln_1"]["scale"],
+        last_layer_grads["ln_2"]["scale"],
+        last_layer_grads["attn"]["c_attn"]["kernel"],
+        last_layer_grads["attn"]["c_proj"]["kernel"],
+        last_layer_grads["mlp"]["c_fc"]["kernel"],
+        last_layer_grads["mlp"]["c_proj"]["kernel"],
+    ]
+    flatten_arrays = [
+        array.reshape((array.shape[0], -1) if is_jacobian else (-1,))
+        for array in arrays
+    ]
+    return jnp.concatenate(flatten_arrays, axis=-1)
+
+
+@partial(jax.pmap, axis_name="batch", donate_argnums=0)
+def training_step(state: TrainState, batch: ArrayTree) -> tuple[TrainState, ArrayTree]:
+    def jacobian_fn(params: ArrayTree) -> ArrayTree:
+        return state.apply_fn({"params": params}, *batch)["loss"]
+
+    def lambda_optimize_fn(logits: Array) -> Array:
+        grads = (nn.softmax(logits) + state.lamreg * state.lampref) @ jacobian
+        return 0.5 * jnp.square(grads).sum()
+
+    # Compute the task-wise gradients (i.e., the Jacobian matrix) and average them
+    # across the devices using `jax.lax.pmean` since this function is wrapped by
+    # `jax.pmap`.
+    jacobian = jax.jacrev(jacobian_fn)(state.params)
+    jacobian = get_gradient_slice(jacobian, is_jacobian=True)
+    jacobian = jacobian.reshape(jacobian.shape[0], -1)
+    jacobian = jax.lax.pmean(jacobian, axis_name="batch")
+
+    # Apply EMA to the jacobian buffer to estimate the global expectation of the gradients.
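+    # With momentum m = jacmom, E[jacbuf_t] ≈ (1 - m**t) * E[J], so dividing by
+    # (1 - m**(step + 1)) below gives a bias-corrected running estimate of the
+    # Jacobian, mirroring Adam-style moment debiasing.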
+    # Note that the actual Jacobian will be bias-corrected using the momentum. Note
+    # also that many implementations normalize gradients and consider directions only.
+    jacbuf = state.jacmom * state.jacbuf + (1 - state.jacmom) * jacobian
+    jacobian = jacbuf / (1 - state.jacmom ** (state.step + 1))
+    jacobian = jacobian / jnp.linalg.norm(jacobian, axis=-1, keepdims=True).mean()
+    # jacobian = jacobian / jnp.linalg.norm(jacobian, axis=-1, keepdims=True)
+
+    # Update the logits of the lambda vector with a manual SGD step. The Jacobian
+    # has already been debiased by EMA, and we use a single update step instead of
+    # an inner descent loop for the lambda logits. Note that we parameterize the
+    # weights with a softmax over the logits, which removes the explicit
+    # probability-simplex constraint.
+    if state.lambuf is not None:
+        lamgrads = jax.grad(lambda_optimize_fn)(state.lambuf)
+        lambuf = state.lambuf - state.lammom * lamgrads
+
+        weights = nn.softmax(lambuf) + state.lamreg * state.lampref
+        weights = weights / (1 + state.lamreg)
+    else:
+        lambuf, weights = None, jnp.ones(jacobian.shape[0]) / jacobian.shape[0]
+
+    def weighted_loss_fn(params: ArrayTree) -> ArrayTree:
+        metrics = state.apply_fn({"params": params}, *batch)
+        metrics["loss"] = weights @ metrics["loss"]
+        return metrics["loss"], metrics
+
+    metrics, grads = jax.value_and_grad(weighted_loss_fn, has_aux=True)(state.params)
+    metrics, grads = jax.lax.pmean((metrics[1], grads), axis_name="batch")
+
+    metrics |= {f"weight{i}": j for i, j in enumerate(weights)}
+    # metrics |= {f"logit{i}": j for i, j in enumerate(lambuf)}
+    # metrics |= {f"grad{i}": j for i, j in enumerate(lamgrads)}
+    state = state.apply_gradients(grads=grads, jacbuf=jacbuf, lambuf=lambuf)
+    return state, metrics | state.opt_state.hyperparams
+
+
+@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3,))
+def generation_step(
+    state: TrainState, tokens: Array, sample_rng: PRNGKey, max_length: int
+) -> tuple[Array, ArrayTree]:
+    outputs, metrics = state.apply_fn(
+        {"params": state.params}, tokens, sample_rng, max_length, method="generate"
+    )
+    return outputs, jax.lax.pmean(metrics, axis_name="batch")
+
+
+def create_train_state(
+    args: argparse.Namespace, model: FlaxPreTrainedModel, steps_per_epoch: int
+) -> TrainState:
+    module = TrainModule(
+        act=FlaxGPT2LMHeadModule(model.config),
+        ref=FlaxGPT2LMHeadModule(model.config),
+        penalty_beta=args.penalty_beta,
+    )
+    params = {"act": model.params, "ref": jax.tree.map(jnp.copy, model.params)}
+
+    jacbuf = get_gradient_slice(params, is_jacobian=False)
+    jacbuf = jnp.zeros((len(args.target_columns), jacbuf.size))
+    lambuf = jnp.zeros(len(args.target_columns))
+
+    # Normalize the user preference strengths (the third field of each target
+    # column) into a probability vector.
+    lampref = np.array([float(x.split(":")[2]) for x in args.target_columns])
+    lampref = lampref / (lampref.sum() + 1e-10)
+
+    # Create the learning rate scheduler and optimizer with gradient clipping. The
+    # learning rate will be recorded at `hyperparams` by `optax.inject_hyperparams`.
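+    # Routing "ref" to optax.set_to_zero() below keeps the reference model frozen,
+    # so only the policy ("act") receives AdamW updates.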
+    @partial(optax.inject_hyperparams, hyperparam_dtype=jnp.float32)
+    def create_optimizer_fn(
+        learning_rate: optax.Schedule,
+    ) -> optax.GradientTransformation:
+        tx = optax.adamw(
+            learning_rate=learning_rate,
+            b1=args.adam_b1,
+            b2=args.adam_b2,
+            eps=args.adam_eps,
+            weight_decay=args.weight_decay,
+            mask=partial(jax.tree.map, lambda x: x.ndim > 1),
+        )
+        if args.clip_grad > 0:
+            tx = optax.chain(optax.clip_by_global_norm(args.clip_grad), tx)
+        return optax.multi_transform(
+            {"act": tx, "ref": optax.set_to_zero()},
+            partial(jax.tree_util.tree_map_with_path, lambda path, _: path[0].key),
+        )
+
+    learning_rate = optax.warmup_cosine_decay_schedule(
+        init_value=0,
+        peak_value=args.learning_rate,
+        decay_steps=(total_steps := args.epochs * steps_per_epoch),
+        warmup_steps=int(args.warmup_ratio * total_steps),
+        end_value=0,
+    )
+    return TrainState.create(
+        apply_fn=module.apply,
+        params=params,
+        tx=create_optimizer_fn(learning_rate),
+        jacbuf=jacbuf,
+        jacmom=args.jacmom,
+        lambuf=lambuf if args.use_moco else None,
+        lammom=args.lammom,
+        lamreg=args.lamreg,
+        lampref=lampref.astype(np.float32),
+    )
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..81422e1
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from collections import defaultdict
+
+import chemfunc
+import numpy as np
+import pandas as pd
+from admet_ai import ADMETModel
+from rdkit import Chem, rdBase
+from rdkit.Contrib.SA_Score import sascorer  # type: ignore
+from transformers import FlaxGPT2LMHeadModel, FlaxPreTrainedModel, GPT2Config
+from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModule
+
+rdBase.DisableLog("rdApp.*")
+
+
+class FlaxGPT2LMHeadModelWrapper(FlaxGPT2LMHeadModel):
+    def __init__(self, config: GPT2Config, module: FlaxGPT2LMHeadModule):
+        FlaxPreTrainedModel.__init__(self, config, module, _do_init=False)
+
+
+class AverageMeter:
+    def __init__(self, use_latest: list[str] = []):
+        self.buffer = defaultdict(list)
+        self.use_latest = use_latest
+
+    def update(self, **kwargs: float):
+        for k, v in kwargs.items():
+            self.buffer[k].append(v)
+
+    def summary(self, prefix: str = "") -> dict[str, float]:
+        buffer = {k: np.array(v) for k, v in self.buffer.items()}
+        self.buffer.clear()
+
+        return {
+            f"{prefix}{k}": v[-1] if k in self.use_latest else np.mean(v)
+            for k, v in buffer.items()
+        }
+
+
+class Evaluator:
+    def __init__(self):
+        self.admet_ai = ADMETModel()
+
+    def __call__(self, smiles_list: list[str]) -> pd.DataFrame:
+        # Calculate additional scores from SMILES.
+        mols = [Chem.MolFromSmiles(x) for x in smiles_list]
+        sa_scores = [sascorer.calculateScore(m) for m in mols]
+        max_ring = [max(map(len, m.GetRingInfo().AtomRings() or [[]])) for m in mols]
+
+        # Calculate internal diversity (IntDiv) from Morgan fingerprints via
+        # pairwise Tanimoto similarity.
+        morgan_fp = np.stack([chemfunc.compute_morgan_fingerprint(m) for m in mols])
+        dot, norm = morgan_fp @ morgan_fp.T, morgan_fp.sum(-1, keepdims=True)
+        tanimoto = dot / (norm + norm.T - dot)
+        intdiv = 1 - (tanimoto.sum(-1) - 1) / (tanimoto.shape[1] - 1)
+
+        # Use the ADMET-AI model to predict ADMET properties from SMILES.
+        admet = self.admet_ai.predict(smiles_list)
+        admet["SAScore"] = sa_scores
+        admet["CycleScore"] = [max(x - 6, 0) for x in max_ring]
+        admet["plogP"] = admet["logP"] - admet["SAScore"] - admet["CycleScore"]
+        admet["IntDiv"] = intdiv
+        return admet