Commit: add inference scripts
ndnng committed Apr 19, 2024
1 parent 4ea7e92 commit 3c3d80f
Showing 15 changed files with 1,147 additions and 523 deletions.
111 changes: 77 additions & 34 deletions README.md

Download model parameters:
```bash
gdown --fuzzy --folder https://drive.google.com/drive/folders/1rPlzMUPgKLFd_Krk8cGqhEeitWByPOMn?usp=sharing
```

Additionally, TMalign is required to align generated structures.
```bash
export PATH="/path/to/TMalign:$PATH"
```
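A quick way to confirm the binary is visible on your `PATH` (shell sketch):

```bash
# Sanity check: succeeds silently if TMalign is found, otherwise prints a warning.
command -v TMalign || echo "TMalign not found in PATH"
```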

## Sample generation

Generate a single complex structure using ProteinReDiff:

```bash
python generate.py \
--ckpt_path "checkpoints/PRD_ver1.ckpt" \
--output_dir "workdir/inference/example_ProteinReDiff" \
--protein "LSEQLKHCNGILKELLSKKHAAYAWPFYKPVDASALGLHDYHDIIKHPMDLSTVKRKMENRDYRDAQEFAADVRLMFSNCYKYNPPDHDVVAMARKLQDVFEFRYAKMPD" \
--ligand "Cc1ccc2c(c1c3cc(cc4c3nc([nH]4)C5CC5)c6c(noc6C)C)cccn2" \
--num_samples 3 \
--num_steps 1000
```

The sequence can be masked prior to input using the `X` token (`mask_prob` is the fraction of input protein residues to mask, from 0.0 to 1.0):
```bash
python generate.py \
--ckpt_path "checkpoints/PRD_ver1.ckpt" \
--output_dir "workdir/inference/example_ProteinReDiff" \
--protein "LSEQXXXXNGILKELLSKXXXXYAWPFYKPVDASALGLHDYHDIIKXXXXLSTVKRKMENRDYRDXXXXAADVRLMFSNCYKYNPPDHDVVAMARKLQDVFEFRYAKMPD" \
--ligand "Cc1ccc2c(c1c3cc(cc4c3nc([nH]4)C5CC5)c6c(noc6C)C)cccn2" \
--num_samples 4 \
--num_steps 1000 \
--mask_prob 0.0
```
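For reference, this manual `X` masking is equivalent to what the batch scripts below apply internally via `--mask_prob`: a random fraction of residue positions is replaced with `X`. A minimal Python sketch, mirroring the `mask_sequence_by_percent` helper added in `scripts/predict_batch_seq_msk_inp.py`:

```python
import random

def mask_sequence_by_percent(seq: str, percentage: float = 0.2) -> str:
    """Replace a random `percentage` fraction of positions in seq with 'X'."""
    masked_idx = set(random.sample(range(len(seq)), int(len(seq) * percentage)))
    return "".join("X" if i in masked_idx else aa for i, aa in enumerate(seq))

# Example: mask 20% of a short sequence before passing it to generate.py.
print(mask_sequence_by_percent("LSEQLKHCNGILKELLSKKHAAYAWPFYKPVDASALGLHDYHDIIK", 0.2))
```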

Generate structure ensembles without a ligand (use the dummy token `*`):
```bash
python generate.py \
--ckpt_path "checkpoints/PRD_ver1.ckpt" \
--output_dir "workdir/inference/example_ProteinReDiff" \
--protein "LSEQLKHCNGILKELLSKKHAAYAWPFYKPVDASALGLHDYHDIIKHPMDLSTVKRKMENRDYRDAQEFAADVRLMFSNCYKYNPPDHDVVAMARKLQDVFEFRYAKMPD" \
--ligand "*" \
--num_samples 3 \
--num_steps 1000
```

Generate multiple complex structures:
```bash
python -m scripts.predict_batch_strc_msk_inp \
--ckpt_path "checkpoints/PRD_ver1.ckpt" \
--output_dir "workdir/inference/example_ProteinReDiff" \
--fasta "./scripts/test_sequences_from_pdb.fasta" \
--ligand_file "./scripts/scripts.smiles" \
--accelerator "gpu" \
--num_gpus 1 \
--batch_size 1 \
--num_samples 1 \
--mask_prob 0.5 \
--num_steps 1000
```

Alternatively, generate multiple samples per sequence only (`mask_prob` can be adjusted to increase diversity; masking fractions below 0.5 give the best sequences):
```bash
python -m scripts.predict_batch_seq_msk_inp \
--ckpt_path "checkpoints/PRD_ver1.ckpt" \
--output_dir "workdir/inference/example_ProteinReDiff" \
--fasta "./scripts/test_sequences_from_pdb.fasta" \
--ligand_file "./scripts/test_pdb.smiles" \
--accelerator "gpu" \
--num_gpus 1 \
--batch_size 1 \
--num_samples 10 \
--mask_prob 0.3 \
--num_steps 1000
```
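Both batch scripts take protein sequences from a FASTA file (`--fasta`) and ligands from a file of SMILES strings (`--ligand_file`). A hypothetical input pair (file names and contents are illustrative; one SMILES string per line is an assumption based on the parsing code):

```bash
# Illustrative inputs only; substitute your own sequences and ligands.
cat > my_proteins.fasta <<'EOF'
>example_protein
LSEQLKHCNGILKELLSKKHAAYAWPFYKPVDASALGLHDYHDIIK
EOF

cat > my_ligands.smiles <<'EOF'
Cc1ccc2c(c1c3cc(cc4c3nc([nH]4)C5CC5)c6c(noc6C)C)cccn2
EOF
```

Pass these via `--fasta my_proteins.fasta --ligand_file my_ligands.smiles`.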

## Training
```bash
python train.py \
--num_workers 8 \
--batch_size 1 \
--accumulate_grad_batches 8 \
--save_dir "workdir/train/example_ProteinReDiff" \
--single_dim 256 \
--pair_dim 32 \
...
```
where the optional `no_cb_distogram` argument makes the model protein structure-free.
Please modify the `batch_size` and `accumulate_grad_batches` arguments according to your machine(s); with PyTorch Lightning, the effective batch size is `batch_size × accumulate_grad_batches × num_gpus`.

Default values can be used to reproduce the settings used in our paper:

```bash
python train.py \
--training_mode \
--num_gpus 1 \
--num_workers 30 \
--batch_size 2 \
--accumulate_grad_batches 10 \
--save_dir "workdir/train/example_ProteinReDiff" \
--single_dim 512 \
--mask_prob 0.15 \
--pair_dim 64 \
--num_steps 2000 \
--num_blocks 4
```
Because GPU runtime is often limited, we provide a `train_from_ckpt.py` script to continue training from a previously saved checkpoint:

```bash
python train_from_ckpt.py \
--training_mode \
--num_gpus 1 \
--num_workers 30 \
--batch_size 2 \
--accumulate_grad_batches 8 \
--no_cb_distogram \
--save_dir "$save_dir" \
--single_dim 512 \
--mask_prob 0.15 \
--pair_dim 64 \
--num_steps 1000 \
--num_blocks 4 \
--trained_ckpt "checkpoints/PRD_ver1.ckpt"
```
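Here `$save_dir` is a shell variable you must set before running, e.g. (illustrative value, matching the earlier examples):

```bash
save_dir="workdir/train/reproduce_ProteinReDiff"
```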

## Acknowledgements
6 changes: 6 additions & 0 deletions scripts/README.md
### Inference Scripts

- `predict_batch_strc_msk_inp.py`: predicts a batch of complex structures.
- `predict_batch_seq_msk_inp.py`: predicts a batch of sequences only, without TM-align comparison.


File renamed without changes.
File renamed without changes.
File renamed without changes.
131 changes: 103 additions & 28 deletions test/predict_batch_seq.py → scripts/predict_batch_seq_msk_inp.py
from argparse import ArgumentParser
from operator import itemgetter
from pathlib import Path
from typing import Iterable, List, Union, Tuple, Any

import numpy as np
import random
import pytorch_lightning as pl
import torch
from rdkit import Chem
...
    proteins_to_pdb_file,
)
from dpl.tmalign import run_tmalign

torch.multiprocessing.set_start_method('fork')

RESIDUE_TYPES_MASK = RESIDUE_TYPES + ["<mask>"]

# The ESM-2 model and batch converter are loaded lazily and cached at module level.
esm_model = None
esm_batch_converter = None

def load_esm_model(accelerator):
    """Load the ESM-2 650M model once and cache it for later calls."""
    global esm_model, esm_batch_converter
    if esm_model is None or esm_batch_converter is None:
        esm_model, esm_alphabet = torch.hub.load(
            "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
        )
        if accelerator == "gpu":
            esm_model.cuda().eval()
        else:
            esm_model.eval()
        esm_batch_converter = esm_alphabet.get_batch_converter()

def compute_residue_esm(protein: Protein, accelerator: str) -> torch.Tensor:
    """Compute per-residue ESM-2 embeddings for each chain of `protein`."""
    load_esm_model(accelerator)

    data = []
    for chain, _ in itertools.groupby(protein.chain_index):
        sequence = "".join(
            [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]]
        )
        data.append(("", sequence))
    batch_tokens = esm_batch_converter(data)[2]
    if accelerator == "gpu":
        batch_tokens = batch_tokens.cuda()
    with torch.inference_mode():
        results = esm_model(batch_tokens, repr_layers=[esm_model.num_layers])
        token_representations = results["representations"][esm_model.num_layers].cpu()
def proteins_from_fasta(fasta_file: Union[str, Path]):
    ...
    return proteins, names

def proteins_from_fasta_with_mask(fasta_file: Union[str, Path], mask_percent: float = 0.0):
    """Parse a FASTA file, masking a random fraction of each sequence with 'X'."""
    names = []
    proteins = []
    sequences = []
    with open(fasta_file, "r") as f:
        for line in f:
            if line.startswith(">"):
                name = line.lstrip(">").rstrip("\n").replace(" ", "_")
                names.append(name)
            elif line not in ["\n", "\r\n"]:
                sequence = line.rstrip("\n")
                sequence = mask_sequence_by_percent(sequence, mask_percent)
                protein = protein_from_sequence(sequence)
                proteins.append(protein)
                sequences.append(sequence)

    return proteins, names, sequences

def parse_ligands(ligand_input: Union[str, Path, list]):
    ligands = []
    if isinstance(ligand_input, list):
        ...

def update_seq(
    ...
):
    ...
    protein = dataclasses.replace(protein, aatype=aatype)
    return protein

def mask_sequence_by_percent(seq, percentage=0.2):
    """Replace a random `percentage` fraction of residues in `seq` with 'X'."""
    aa_to_replace = set(random.sample(range(len(seq)), int(len(seq) * percentage)))
    output_aa = ["X" if idx in aa_to_replace else char for idx, char in enumerate(seq)]
    masked_seq = "".join(output_aa)
    return masked_seq

def main(args):
    # Seed every run with a fresh random seed so repeated runs give new samples.
    pl.seed_everything(np.random.randint(999999999), workers=True)

    # Remove any existing output directory so results start clean.
    if os.path.exists(args.output_dir):
        ...
    model = ProteinReDiffModel.load_from_checkpoint(
        args.ckpt_path, num_steps=args.num_steps
    )
    # Inference configuration: disable training mode and set the masking rate.
    model.training_mode = False
    model.mask_prob = args.mask_prob

    # Inputs
    proteins, names, masked_sequences = proteins_from_fasta_with_mask(args.fasta, args.mask_prob)

    # Record the masked input sequences next to the generated samples.
    with open(args.output_dir / "masked_sequences.fasta", "w") as f:
        for i, (name, seq) in enumerate(zip(names, masked_sequences)):
            f.write(">{}_sample_{}\n".format(name, i % args.num_samples))
            f.write("{}\n".format(seq))

    if args.ligand_file is None:
        ligand_input = ["*"] * len(names)
        ligands = parse_ligands(ligand_input)
    else:
        ...
    # warnings.warn("Too many atoms (> 400). May take a long time for sample generation.")

    datas = []
    for name, protein, ligand in zip(names, proteins, ligands):
        data = {
            **ligand_to_data(ligand),
            **protein_to_data(protein, residue_esm=compute_residue_esm(protein, args.accelerator)),
        }
        # Replicate each protein-ligand pair so num_samples structures are drawn.
        datas.extend([data] * args.num_samples)

    # Generate samples
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.num_gpus,
        default_root_dir=args.output_dir,
        max_epochs=-1,
        strategy="ddp",
    )
    results = trainer.predict(
        model,
        dataloaders=DataLoader(
            InferenceDataset(datas, args.num_samples * len(names)),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
        ),
    )

    # Each prediction is a (positions, probabilities) pair; this script only
    # needs the sequence probabilities.
    probabilities = [s[1] for s in results]

    ...
if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=torch.get_num_threads())
    parser.add_argument("--num_steps", type=int, default=64)
    parser.add_argument("--mask_prob", type=float, default=0.3)
    ...