diff --git a/ProteinReDiff/__pycache__/data.cpython-39.pyc b/ProteinReDiff/__pycache__/data.cpython-39.pyc
index 5cd0978..fbf2292 100644
Binary files a/ProteinReDiff/__pycache__/data.cpython-39.pyc and b/ProteinReDiff/__pycache__/data.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc b/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc
index da39752..63985b3 100644
Binary files a/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc and b/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/features.cpython-39.pyc b/ProteinReDiff/__pycache__/features.cpython-39.pyc
index 13aeb1a..adc084c 100644
Binary files a/ProteinReDiff/__pycache__/features.cpython-39.pyc and b/ProteinReDiff/__pycache__/features.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc b/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc
index 6d7ad1e..7baae5e 100644
Binary files a/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc and b/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/model.cpython-39.pyc b/ProteinReDiff/__pycache__/model.cpython-39.pyc
index baf8add..47d3dc3 100644
Binary files a/ProteinReDiff/__pycache__/model.cpython-39.pyc and b/ProteinReDiff/__pycache__/model.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/modules.cpython-39.pyc b/ProteinReDiff/__pycache__/modules.cpython-39.pyc
index 723001b..aabb68d 100644
Binary files a/ProteinReDiff/__pycache__/modules.cpython-39.pyc and b/ProteinReDiff/__pycache__/modules.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/mol.cpython-39.pyc b/ProteinReDiff/__pycache__/mol.cpython-39.pyc
index e87dd9a..7f14001 100644
Binary files a/ProteinReDiff/__pycache__/mol.cpython-39.pyc and b/ProteinReDiff/__pycache__/mol.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/protein.cpython-39.pyc b/ProteinReDiff/__pycache__/protein.cpython-39.pyc
index ae24225..05e7909 100644
Binary files a/ProteinReDiff/__pycache__/protein.cpython-39.pyc and b/ProteinReDiff/__pycache__/protein.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc b/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc
index d538d40..50faa5a 100644
Binary files a/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc and b/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/utils.cpython-39.pyc b/ProteinReDiff/__pycache__/utils.cpython-39.pyc
index 604c3b9..1fafe55 100644
Binary files a/ProteinReDiff/__pycache__/utils.cpython-39.pyc and b/ProteinReDiff/__pycache__/utils.cpython-39.pyc differ
diff --git a/ProteinReDiff/data.py b/ProteinReDiff/data.py
index 5d6368e..efe3156 100644
--- a/ProteinReDiff/data.py
+++ b/ProteinReDiff/data.py
@@ -167,7 +167,7 @@ def __getitem__(self, index: int) -> Mapping[str, Any]:
         return self.data[index]
 
 
-class PDBbindDataset(Dataset):
+class PDBDataset(Dataset):
     def __init__(self, root_dir: Union[str, Path], pdb_ids: Sequence[str]):
         super().__init__()
         if isinstance(root_dir, str):
@@ -203,7 +203,7 @@ def get_stream(self):
     def __iter__(self):
         return self.get_stream()
 
-class PDBbindDataModule(pl.LightningDataModule):
+class PDBDataModule(pl.LightningDataModule):
     def __init__(
         self,
         data_dir: Union[str, Path] = "data",
@@ -214,7 +214,7 @@ def __init__(
         if isinstance(data_dir, str):
             data_dir = Path(data_dir)
         self.data_dir = data_dir
-        self.cache_dir = data_dir / "PDBBind_processed_cache"
+        self.cache_dir = data_dir / "PDB_processed_cache"
         self.batch_size = batch_size
         self.num_workers = num_workers
@@ -231,7 +231,7 @@ def setup(self, stage: Optional[str] = None) -> None:
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.train_pdb_ids),
+            PDBDataset(self.cache_dir, self.train_pdb_ids),
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
@@ -242,7 +242,7 @@ def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.val_pdb_ids),
+            PDBDataset(self.cache_dir, self.val_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
@@ -251,7 +251,7 @@ def test_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.test_pdb_ids),
+            PDBDataset(self.cache_dir, self.test_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
diff --git a/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc b/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc
index 985840b..f5ec957 100644
Binary files a/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc and b/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc differ
diff --git a/README.md b/README.md
index 442d158..ca0d0d5 100644
--- a/README.md
+++ b/README.md
@@ -121,7 +121,7 @@ python -m scripts.predict_batch_seq_msk_inp \
 
 Download the PDBbind dataset from https://zenodo.org/record/6408497 and unzip it.
-Move the resulting PDBBind_processed directory to data/.
+Move the resulting PDBBind_processed directory to ./data/.
 Preprocess the dataset:
 ```bash
@@ -140,7 +140,7 @@ python train.py \
     --num_blocks 4
 ```
 
-Please modify the batch_size and accumulate_grad_batches arguments according to your machine(s).
+Please adjust the batch_size, gpus, and accumulate_grad_batches arguments to suit your machine(s). Also, use the data_dir flag for the directory containing your preprocessed data (e.g., "./data") and save_dir for the directory where training log files are saved.
 Default values can be used to reproduce the settings used in our paper:
diff --git a/generate.py b/generate.py
index c2cffda..a0534f3 100644
--- a/generate.py
+++ b/generate.py
@@ -1,3 +1,13 @@
+"""
+Adapted from Nakata, S., Mori, Y. & Tanaka, S.
+End-to-end protein–ligand complex structure generation with diffusion-based generative models.
+BMC Bioinformatics 24, 233 (2023).
+https://doi.org/10.1186/s12859-023-05354-5
+
+Repository: https://github.com/shuyana/DiffusionProteinLigand
+
+"""
+
 import dataclasses
 import itertools
 import warnings
@@ -18,6 +28,7 @@
 from ProteinReDiff.mol import get_mol_positions, mol_from_file, update_mol_positions
 from ProteinReDiff.protein import (
     RESIDUE_TYPES,
+    RESIDUE_TYPE_INDEX,
     Protein,
     protein_from_pdb_file,
     protein_from_sequence,
@@ -25,7 +36,7 @@
 )
 from ProteinReDiff.tmalign import run_tmalign
 
-
+RESIDUE_TYPES_MASK = RESIDUE_TYPES + [""]
 def compute_residue_esm(protein: Protein) -> torch.Tensor:
     esm_model, esm_alphabet = torch.hub.load(
         "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
     )
     esm_model.cuda().eval()
     esm_batch_converter = esm_alphabet.get_batch_converter()
@@ -36,7 +47,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor:
     data = []
     for chain, _ in itertools.groupby(protein.chain_index):
         sequence = "".join(
-            [RESIDUE_TYPES[aa] for aa in protein.aatype[protein.chain_index == chain]]
+            [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]]
         )
         data.append(("", sequence))
     batch_tokens = esm_batch_converter(data)[2].cuda()
@@ -45,7 +56,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor:
     token_representations = results["representations"][esm_model.num_layers].cpu()
     residue_representations = []
     for i, (_, sequence) in enumerate(data):
-        residue_representations.append(token_representations[i, 1 : len(sequence) + 1])
+        residue_representations.append(token_representations[i, 1 : len(protein.aatype) + 1])
     residue_esm = torch.cat(residue_representations, dim=0)
     assert residue_esm.size(0) == len(protein.aatype)
     return residue_esm
@@ -62,6 +73,23 @@
     ligand = update_mol_positions(ligand, pos[: ligand.GetNumAtoms()])
     return protein, ligand
 
+def predict_seq(
+    proba: torch.Tensor
+) -> list :
+    tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1)
+    RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES
+    return list(map(lambda i : RESIDUE_TYPES_NEW[i], tokens))
+
+def update_seq(
+    protein: Protein, proba: torch.Tensor
+) -> Protein:
+    tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1)
+    RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES
+    sequence = "".join(map(lambda i : RESIDUE_TYPES_NEW[i], tokens)).lstrip("X").rstrip("X")
+    aatype = np.array([RESIDUE_TYPES.index(s) for s in sequence], dtype=np.int64)
+    protein = dataclasses.replace(protein, aatype = aatype)
+    return protein
+
 
 def main(args):
     pl.seed_everything(args.seed, workers=True)
@@ -75,24 +103,30 @@
     model = ProteinReDiffModel.load_from_checkpoint(
         args.ckpt_path, num_steps=args.num_steps
     )
+
+    model.training_mode = False
+    args.num_gpus = 1
+    model.mask_prob = args.mask_prob
+
     # Inputs
     if args.protein.endswith(".pdb"):
         protein = protein_from_pdb_file(args.protein)
     else:
+        protein = protein_from_sequence(args.protein)
     if args.ligand.endswith(".sdf") or args.ligand.endswith(".mol2"):
         ligand = mol_from_file(args.ligand)
     else:
         ligand = Chem.MolFromSmiles(args.ligand)
-        ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3)))
+    ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3)))
     total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms()
     print(f"Total number of atoms: {total_num_atoms}")
     if total_num_atoms > 384:
         warnings.warn("Too many atoms. May take a long time for sample generation.")
-
+
     data = {
         **ligand_to_data(ligand),
         **protein_to_data(protein, residue_esm=compute_residue_esm(protein)),
@@ -104,10 +138,11 @@
     trainer = pl.Trainer.from_argparse_args(
         args,
         accelerator="auto",
+        gpus = args.num_gpus,
         default_root_dir=args.output_dir,
         max_epochs=-1,
     )
-    positions = trainer.predict(
+    results = trainer.predict( ## (NN)
         model,
         dataloaders=DataLoader(
             RepeatDataset(data, args.num_samples),
@@ -116,13 +151,20 @@
             collate_fn=collate_fn,
         ),
     )
+
+    positions = [p[0] for p in results]
+    sequences = [s[1] for s in results]
+
     positions = torch.cat(positions, dim=0).detach().cpu().numpy()
+    probabilities = torch.cat(sequences, dim=0).detach().cpu().numpy()
+    #torch.save(probabilities, "sampled_seq_gvp.pt") # can save embedding
 
     # Save samples
     sample_proteins, sample_ligands = [], []
     tmscores = []
-    for pos in positions:
+    for pos, seq_prob in zip(positions, probabilities):
         sample_protein, sample_ligand = update_pos(protein, ligand, pos)
+        sample_protein = update_seq(sample_protein, seq_prob)
         if ref_protein is None:
             warnings.warn(
                 "Using the first sample as a reference. The resulting structures may be mirror images."
@@ -155,10 +197,13 @@
 
 if __name__ == "__main__":
     parser = ArgumentParser()
+
     parser.add_argument("--seed", type=int, default=1234)
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument("--num_steps", type=int, default=64)
+    parser.add_argument("--mask_prob", type=float, default=0.3)
+    parser.add_argument("--training_mode", action="store_true")
     parser.add_argument("-c", "--ckpt_path", type=Path, required=True)
     parser.add_argument("-o", "--output_dir", type=Path, required=True)
     parser.add_argument("-p", "--protein", type=str, required=True)
diff --git a/scripts/__pycache__/__init__.cpython-39.pyc b/scripts/__pycache__/__init__.cpython-39.pyc
index 6564106..16ea93f 100644
Binary files a/scripts/__pycache__/__init__.cpython-39.pyc and b/scripts/__pycache__/__init__.cpython-39.pyc differ
diff --git a/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc b/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc
new file mode 100644
index 0000000..1589563
Binary files /dev/null and b/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc differ
diff --git a/scripts/predict_batch_seq_msk_inp.py b/scripts/predict_batch_seq_msk_inp.py
index f443d06..e0a0551 100644
--- a/scripts/predict_batch_seq_msk_inp.py
+++ b/scripts/predict_batch_seq_msk_inp.py
@@ -5,16 +5,17 @@
 from argparse import ArgumentParser
 from operator import itemgetter
 from pathlib import Path
-from typing import Iterable, List, Union, Tuple
+from typing import Iterable, List, Union, Tuple, Any
 
 import numpy as np
+import random
 import pytorch_lightning as pl
 import torch
 from rdkit import Chem
 from torch.utils.data import DataLoader
 
 from ProteinReDiff.data import InferenceDataset, collate_fn, ligand_to_data, protein_to_data
-from ProteinReDiff.model import ProteinReDiffModel ## (NN)
+from ProteinReDiff.model import ProteinReDiffModel
 from ProteinReDiff.mol import get_mol_positions, mol_from_file, update_mol_positions
 from ProteinReDiff.protein import (
     RESIDUE_TYPES,
@@ -25,14 +26,36 @@
     proteins_to_pdb_file,
 )
 from ProteinReDiff.tmalign import run_tmalign
+
+torch.multiprocessing.set_start_method('fork')
 
 RESIDUE_TYPES_MASK = RESIDUE_TYPES + [""]
 
-def compute_residue_esm(protein: Protein) -> torch.Tensor:
-    esm_model, esm_alphabet = torch.hub.load(
-        "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
-    )
-    esm_model.cuda().eval()
-    esm_batch_converter = esm_alphabet.get_batch_converter()
+
+
+esm_model = None
+esm_batch_converter = None
+
+def load_esm_model(accelerator):
+    global esm_model, esm_batch_converter
+    if esm_model is None or esm_batch_converter is None:
+        esm_model, esm_alphabet = torch.hub.load(
+            "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
+        )
+
+        # esm_model.cuda().eval()
+        if accelerator == "gpu":
+            esm_model.cuda().eval()
+        else:
+            esm_model.eval()
+        esm_batch_converter = esm_alphabet.get_batch_converter()
+
+
+
+def compute_residue_esm(protein: Protein, accelerator: str) -> torch.Tensor:
+
+    global esm_model, esm_batch_converter
+    load_esm_model(accelerator)
 
     data = []
     for chain, _ in itertools.groupby(protein.chain_index):
@@ -40,7 +63,11 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor:
         sequence = "".join(
             [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]]
         )
         data.append(("", sequence))
-    batch_tokens = esm_batch_converter(data)[2].cuda()
+    # batch_tokens = esm_batch_converter(data)[2].cuda()
+    if accelerator == "gpu":
+        batch_tokens = esm_batch_converter(data)[2].cuda()
+    else:
+        batch_tokens = esm_batch_converter(data)[2]
     with torch.inference_mode():
         results = esm_model(batch_tokens, repr_layers=[esm_model.num_layers])
     token_representations = results["representations"][esm_model.num_layers].cpu()
@@ -66,6 +93,24 @@
     return proteins, names
 
+def proteins_from_fasta_with_mask(fasta_file: Union[str, Path], mask_percent: float = 0.0):
+    names = []
+    proteins = []
+    sequences = []
+    with open(fasta_file, "r") as f:
+        for line in f:
+            if line.startswith(">"):
+                name = line.lstrip(">").rstrip("\n").replace(" ","_")
+                names.append(name)
+            elif not line in ['\n', '\r\n']:
+                sequence = line.rstrip("\n")
+                sequence = mask_sequence_by_percent(sequence, mask_percent)
+                protein = protein_from_sequence(sequence)
+                proteins.append(protein)
+                sequences.append(sequence)
+
+    return proteins, names, sequences
+
 
 def parse_ligands(ligand_input: Union[str, Path, list]):
     ligands = []
     if isinstance(ligand_input, list):
@@ -110,9 +155,17 @@
     protein = dataclasses.replace(protein, aatype = aatype)
     return protein
 
+def mask_sequence_by_percent(seq, percentage=0.2):
+    aa_to_replace = random.sample(range(len(seq)), int(len(seq)*percentage))
+
+    output_aa = [char if idx not in aa_to_replace else 'X' for idx, char in enumerate(seq)]
+    masked_seq = ''.join(output_aa)
+
+    return masked_seq
 
 def main(args):
-    pl.seed_everything(args.seed, workers=True)
+    pl.seed_everything(np.random.randint(999999), workers=True)
+
     # Check if the directory exists
     if os.path.exists(args.output_dir):
         # Remove the existing directory
@@ -123,45 +176,45 @@
     model = ProteinReDiffModel.load_from_checkpoint(
         args.ckpt_path, num_steps=args.num_steps
     )
-    ## (NN)
     model.training_mode = False
-    args.num_gpus = 1
     model.mask_prob = args.mask_prob ## (NN)
 
     # Inputs
-    proteins, names = proteins_from_fasta(args.fasta)
-
+    proteins, names, masked_sequences = proteins_from_fasta_with_mask(args.fasta, args.mask_prob)
+
+    with open(args.output_dir / "masked_sequences.fasta", "w") as f:
+        for i, (name, seq) in enumerate(zip(names, masked_sequences)):
+            f.write(">{}_sample_{}\n".format(name,i%args.num_samples))
+            f.write("{}\n".format(seq))
+
     if args.ligand_file is None:
-        ligand_input = ["*"]*args.num_samples*len(names)
+        ligand_input = ["*"]*len(names)
         ligands = parse_ligands(ligand_input)
     else:
         ligands = parse_ligands(args.ligand_file)
-    # total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms()
-    # print(f"Total number of atoms: {total_num_atoms}")
-    # if total_num_atoms > 400:
-    #     warnings.warn("Too many atoms (> 400). May take a long time for sample generation.")
 
     datas = []
-    for protein, ligand in zip(proteins, ligands):
+    for name, protein, ligand in zip(names,proteins, ligands):
         data = {
             **ligand_to_data(ligand),
-            **protein_to_data(protein, residue_esm=compute_residue_esm(protein)),
+            **protein_to_data(protein, residue_esm=compute_residue_esm(protein, args.accelerator)),
         }
         datas.extend([data]*args.num_samples)
 
-    # Generate samples
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        accelerator="auto",
-        gpus = args.num_gpus,
-        default_root_dir=args.output_dir,
-        max_epochs=-1,
-    )
+
+    trainer = pl.Trainer(
+        accelerator=args.accelerator,
+        devices=args.num_gpus,
+        default_root_dir=args.output_dir,
+        max_epochs=-1,
+        strategy='ddp'
+
+    )
     results = trainer.predict(
         model,
         dataloaders=DataLoader(
@@ -169,15 +222,14 @@
             batch_size=args.batch_size,
             num_workers=args.num_workers,
             collate_fn=collate_fn,
-        ),
+        )
     )
-    torch.save(results,"results.pt")
-    positions = [p[0] for p in results] ## (NN)
-    probabilities = [s[1] for s in results] ## (NN)
+
+
+
+    probabilities = [s[1] for s in results]
 
-    # positions = torch.cat(positions, dim=0).detach().cpu().numpy()
-    # probabilities = torch.cat(probabilities, dim=0).detach().cpu().numpy() #(NN)
 
     names = [n for n in names for _ in range(args.num_samples)]
 
@@ -187,47 +239,16 @@
             sequence = predict_seq(seq_prob.squeeze())
             f.write("{}\n".format(sequence))
 
-    # Save samples
-    # sample_proteins, sample_ligands = [], []
-    # tmscores = []
-    # for pos, seq_prob in zip(positions, probabilities):
-    #     sample_protein, sample_ligand = update_pos(protein, ligand, pos)
-    #     sample_protein = update_seq(sample_protein, seq_prob)
-    #     if ref_protein is None:
-    #         warnings.warn(
-    #             "Using the first sample as a reference. The resulting structures may be mirror images."
-    #         )
-    #         ref_protein = sample_protein
-    #     tmscore, t, R = max(
-    #         run_tmalign(sample_protein, ref_protein),
-    #         run_tmalign(sample_protein, ref_protein, mirror=True),
-    #         key=itemgetter(0),
-    #     )
-    #     sample_proteins.append(
-    #         dataclasses.replace(
-    #             sample_protein, atom_pos=t + sample_protein.atom_pos @ R
-    #         )
-    #     )
-    #     sample_ligands.append(
-    #         update_mol_positions(
-    #             sample_ligand, t + get_mol_positions(sample_ligand) @ R
-    #         )
-    #     )
-    #     tmscores.append(tmscore)
-    #     proteins_to_pdb_file(sample_proteins, args.output_dir / "sample_protein.pdb")
-    #     with Chem.SDWriter(str(args.output_dir / "sample_ligand.sdf")) as w:
-    #         for sample_ligand in sample_ligands:
-    #             w.write(sample_ligand)
-    #     with open(args.output_dir / "sample_tmscores.txt", "w") as f:
-    #         for tmscore in tmscores:
-    #             f.write(str(tmscore) + "\n")
+
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--seed", type=int, default=1234)
+    # parser.add_argument("--seed", type=int, default=1234)
+    parser.add_argument("--accelerator", type=str, default="gpu")
     parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_gpus", type=int, default=1)
     parser.add_argument("--num_workers", type=int, default=torch.get_num_threads())
     parser.add_argument("--num_steps", type=int, default=64)
     parser.add_argument("--mask_prob", type=float, default=0.3)
diff --git a/scripts/predict_batch_strc_msk_inp.py b/scripts/predict_batch_strc_msk_inp.py
index 3321530..f948320 100644
--- a/scripts/predict_batch_strc_msk_inp.py
+++ b/scripts/predict_batch_strc_msk_inp.py
@@ -51,25 +51,10 @@ def load_esm_model(accelerator):
             esm_model.eval()
         esm_batch_converter = esm_alphabet.get_batch_converter()
 
-# esm_model, esm_alphabet = torch.hub.load(
-#     "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
-# )
-# esm_model.cuda().eval()
-# esm_batch_converter = esm_alphabet.get_batch_converter()
 
 def compute_residue_esm(protein: Protein, accelerator: str) -> torch.Tensor:
-    # esm_model, esm_alphabet = torch.hub.load(
-    #     "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
-    # )
-    # esm_model.cuda().eval()
-    # esm_batch_converter = esm_alphabet.get_batch_converter()
+
     global esm_model, esm_batch_converter
-    # if esm_model is None or esm_batch_converter is None:
-    #     esm_model, esm_alphabet = torch.hub.load(
-    #         "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
-    #     )
-    #     esm_model.cuda().eval()
-    #     esm_batch_converter = esm_alphabet.get_batch_converter()
     load_esm_model(accelerator)
 
     data = []
@@ -179,7 +164,7 @@ def mask_sequence_by_percent(seq, percentage=0.2):
     return masked_seq
 
 def main(args):
-    pl.seed_everything(np.random.randint(999999999), workers=True)
+    pl.seed_everything(np.random.randint(99999), workers=True)
 
     # Check if the directory exists
     if os.path.exists(args.output_dir):
@@ -210,11 +195,6 @@
     else:
         ligands = parse_ligands(args.ligand_file)
 
-    # total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms()
-    # print(f"Total number of atoms: {total_num_atoms}")
-    # if total_num_atoms > 400:
-    #     warnings.warn("Too many atoms (> 400). May take a long time for sample generation.")
-
     datas = []
     for name, protein, ligand in zip(names,proteins, ligands):
         data = {
@@ -245,24 +225,8 @@
     )
 
-
-    #torch.save(results,"results.pt")
-    positions = [p[0] for p in results] ## (NN)
-    probabilities = [s[1] for s in results] ## (NN)
-
-    # positions = torch.cat([p[0] for p in results], dim=-1).detach().cpu().numpy()
-    # probabilities = torch.cat(probabilities, dim=0).detach().cpu().numpy() #(NN)
-
-
-
-    # with open(args.output_dir / "sample_sequences.fasta", "w") as f:
-    #     for i, (name, seq_prob) in enumerate(zip(names, probabilities)):
-    #         f.write(">{}_sample_{}\n".format(name,i%args.num_samples))
-    #         sequence = predict_seq(seq_prob.squeeze())
-    #         f.write("{}\n".format(sequence))
-
-    # Save samples
-    # Repeat samples by num_samples
+    positions = [p[0] for p in results]
+    probabilities = [s[1] for s in results]
     proteins, ligands, names = [protein for protein in proteins for _ in range(args.num_samples)],\
         [ligand for ligand in ligands for _ in range(args.num_samples)], \
         [name for name in names for _ in range(args.num_samples)]
diff --git a/train.py b/train.py
index 9cae8aa..fbdef2e 100644
--- a/train.py
+++ b/train.py
@@ -17,8 +17,8 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 
-from ProteinReDiff.data import PDBbindDataModule
-from ProteinReDiff.model import DiffusionModel
+from ProteinReDiff.data import PDBDataModule
+from ProteinReDiff.model import ProteinReDiffModel
 
 
@@ -29,14 +29,13 @@ def main(args):
         rmtree(args.save_dir)
     args.save_dir.mkdir(parents=True)
 
-    datamodule = PDBbindDataModule.from_argparse_args(args)
-    model = DiffusionModel(args)
+    datamodule = PDBDataModule.from_argparse_args(args)
+    model = ProteinReDiffModel(args)
     trainer = pl.Trainer.from_argparse_args(
         args,
         accelerator="auto",
         precision=16,
         strategy="ddp_find_unused_parameters_false",
-        #logger=WandbLogger(save_dir=args.save_dir, project="DiffusionProteinLigand"),
         callbacks=[
             ModelCheckpoint(
                 filename="{epoch:03d}-{val_loss:.2f}",
@@ -53,10 +52,11 @@
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser = PDBbindDataModule.add_argparse_args(parser)
-    parser = DiffusionModel.add_argparse_args(parser)
+    parser = PDBDataModule.add_argparse_args(parser)
+    parser = ProteinReDiffModel.add_argparse_args(parser)
     parser = pl.Trainer.add_argparse_args(parser)
     parser.add_argument("--seed", type=int, default=1234)
+    parser.add_argument("--num_gpus", type = int, default = 1)
     parser.add_argument("--save_dir", type=Path, required=True)
     args = parser.parse_args()
diff --git a/train_from_ckpt.py b/train_from_ckpt.py
index 871f8c5..6cce808 100644
--- a/train_from_ckpt.py
+++ b/train_from_ckpt.py
@@ -15,8 +15,8 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 
-from ProteinReDiff.data import PDBbindDataModule
-from ProteinReDiff.model import DiffusionModel
+from ProteinReDiff.data import PDBDataModule
+from ProteinReDiff.model import ProteinReDiffModel
 
 
@@ -24,15 +24,14 @@ def main(args):
     pl.seed_everything(args.seed, workers=True)
     args.save_dir.mkdir(parents=True)
 
-    datamodule = PDBbindDataModule.from_argparse_args(args)
-    model = DiffusionModel(args)
+    datamodule = PDBDataModule.from_argparse_args(args)
+    model = ProteinReDiffModel(args)
     trainer = pl.Trainer.from_argparse_args(
         args,
         accelerator="auto",
         precision=16,
         strategy="ddp_find_unused_parameters_false",
         resume_from_checkpoint=args.trained_ckpt,
-        #logger=WandbLogger(save_dir=args.save_dir, project="DiffusionProteinLigand"),
         callbacks=[
             ModelCheckpoint(
                 filename="{epoch:03d}-{val_loss:.2f}",
@@ -49,8 +48,8 @@
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser = PDBbindDataModule.add_argparse_args(parser)
-    parser = DiffusionModel.add_argparse_args(parser)
+    parser = PDBDataModule.add_argparse_args(parser)
+    parser = ProteinReDiffModel.add_argparse_args(parser)
     parser = pl.Trainer.add_argparse_args(parser)
     parser.add_argument("--seed", type=int, default=1234)
     parser.add_argument("--save_dir", type=Path, required=True)
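Both prediction scripts decode the model outputs with the predict_seq convention shown above: token 0 is the unknown residue "X" and tokens 1..20 index into RESIDUE_TYPES. A minimal sketch of that decoding, with an assumed residue ordering for illustration (the real ordering is defined by ProteinReDiff.protein.RESIDUE_TYPES):

```python
import numpy as np
import torch

# Assumed 20-residue ordering, for illustration only; the actual list
# comes from ProteinReDiff.protein.RESIDUE_TYPES.
RESIDUE_TYPES = list("ARNDCQEGHILKMFPSTWYV")
RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES  # index 0 = unknown, as in predict_seq

def decode_sequence(probs: np.ndarray) -> str:
    # softmax is monotonic, so argmax over the raw per-residue scores
    # yields the same tokens as argmax over softmax(probs)
    tokens = torch.argmax(torch.as_tensor(probs), dim=-1)
    return "".join(RESIDUE_TYPES_NEW[int(i)] for i in tokens)

# One-hot rows selecting tokens 1, 2, 0, 3 -> "ARXN" under the assumed ordering.
print(decode_sequence(np.eye(21)[[1, 2, 0, 3]]))
```

This also shows why the softmax inside predict_seq is redundant, though harmless: the argmax of a softmax equals the argmax of its input.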