
Commit

fix module issues
ndnng committed May 14, 2024
1 parent 67fe73c commit ae4d0d2
Showing 20 changed files with 168 additions and 139 deletions.
Binary file modified ProteinReDiff/__pycache__/data.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/difffusion.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/features.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/model.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/modules.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/mol.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/protein.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/tmalign.cpython-39.pyc
Binary file modified ProteinReDiff/__pycache__/utils.cpython-39.pyc
12 changes: 6 additions & 6 deletions ProteinReDiff/data.py
@@ -167,7 +167,7 @@ def __getitem__(self, index: int) -> Mapping[str, Any]:
         return self.data[index]
 
 
-class PDBbindDataset(Dataset):
+class PDBDataset(Dataset):
     def __init__(self, root_dir: Union[str, Path], pdb_ids: Sequence[str]):
         super().__init__()
         if isinstance(root_dir, str):
@@ -203,7 +203,7 @@ def get_stream(self):
     def __iter__(self):
         return self.get_stream()
 
-class PDBbindDataModule(pl.LightningDataModule):
+class PDBDataModule(pl.LightningDataModule):
     def __init__(
         self,
         data_dir: Union[str, Path] = "data",
@@ -214,7 +214,7 @@ def __init__(
         if isinstance(data_dir, str):
             data_dir = Path(data_dir)
         self.data_dir = data_dir
-        self.cache_dir = data_dir / "PDBBind_processed_cache"
+        self.cache_dir = data_dir / "PDB_processed_cache"
         self.batch_size = batch_size
         self.num_workers = num_workers
 
@@ -231,7 +231,7 @@ def setup(self, stage: Optional[str] = None) -> None:
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.train_pdb_ids),
+            PDBDataset(self.cache_dir, self.train_pdb_ids),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
@@ -242,7 +242,7 @@ def train_dataloader(self) -> DataLoader:
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.val_pdb_ids),
+            PDBDataset(self.cache_dir, self.val_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
@@ -251,7 +251,7 @@ def val_dataloader(self) -> DataLoader:
 
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.test_pdb_ids),
+            PDBDataset(self.cache_dir, self.test_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
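Since the rename is API-visible, here is a minimal usage sketch of the renamed classes. This is not from the repository's docs; it assumes the constructor keywords match the attributes shown in the hunks above and that preprocessing has already populated data/PDB_processed_cache.

```python
# Minimal sketch, assuming PDBDataModule's constructor takes data_dir,
# batch_size, and num_workers (the attributes assigned in __init__ above)
# and that setup() resolves the train/val/test PDB id splits.
from ProteinReDiff.data import PDBDataModule

datamodule = PDBDataModule(data_dir="data", batch_size=2, num_workers=4)
datamodule.setup()
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))  # a collated dict of protein/ligand features
```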
Binary file modified ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc
4 changes: 2 additions & 2 deletions README.md
@@ -121,7 +121,7 @@ python -m scripts.predict_batch_seq_msk_inp \
 
 Download the PDBbind dataset from https://zenodo.org/record/6408497 and unzip it.
 
-Move the resulting PDBBind_processed directory to data/.
+Move the resulting PDBBind_processed directory to ./data/.
 
 Preprocess the dataset:
 ```bash
@@ -140,7 +140,7 @@ python train.py \
     --num_blocks 4
 ```
 
-Please modify the batch_size and accumulate_grad_batches arguments according to your machine(s).
+Please modify the batch_size, gpus, and accumulate_grad_batches arguments according to your machine(s). Also, use the data_dir flag for the directory containing your weights (e.g. "./data") and save_dir for the directory where training log files are saved.
 
 Default values can be used to reproduce the settings used in our paper:

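As a concrete starting point, a hedged example command combining the flags the README change mentions; the values here are illustrative, not the paper's defaults (which the truncated README passage covers).

```bash
# Illustrative values only; adjust batch_size, gpus, and
# accumulate_grad_batches to your hardware, per the README text above.
python train.py \
    --batch_size 2 \
    --gpus 4 \
    --accumulate_grad_batches 8 \
    --data_dir ./data \
    --save_dir ./logs \
    --num_blocks 4
```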
59 changes: 52 additions & 7 deletions generate.py
@@ -1,3 +1,13 @@
+"""
+Adapted from Nakata, S., Mori, Y. & Tanaka, S.
+End-to-end protein–ligand complex structure generation with diffusion-based generative models.
+BMC Bioinformatics 24, 233 (2023).
+https://doi.org/10.1186/s12859-023-05354-5
+Repository: https://github.com/shuyana/DiffusionProteinLigand
+"""
+
 import dataclasses
 import itertools
 import warnings
@@ -18,14 +28,15 @@
 from ProteinReDiff.mol import get_mol_positions, mol_from_file, update_mol_positions
 from ProteinReDiff.protein import (
     RESIDUE_TYPES,
+    RESIDUE_TYPE_INDEX,
     Protein,
     protein_from_pdb_file,
     protein_from_sequence,
     proteins_to_pdb_file,
 )
 from ProteinReDiff.tmalign import run_tmalign
 
 
+RESIDUE_TYPES_MASK = RESIDUE_TYPES + ["<mask>"]
 def compute_residue_esm(protein: Protein) -> torch.Tensor:
     esm_model, esm_alphabet = torch.hub.load(
         "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
@@ -36,7 +47,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor:
     data = []
     for chain, _ in itertools.groupby(protein.chain_index):
         sequence = "".join(
-            [RESIDUE_TYPES[aa] for aa in protein.aatype[protein.chain_index == chain]]
+            [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]]
         )
         data.append(("", sequence))
     batch_tokens = esm_batch_converter(data)[2].cuda()
@@ -45,7 +56,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor:
     token_representations = results["representations"][esm_model.num_layers].cpu()
     residue_representations = []
     for i, (_, sequence) in enumerate(data):
-        residue_representations.append(token_representations[i, 1 : len(sequence) + 1])
+        residue_representations.append(token_representations[i, 1 : len(protein.aatype) + 1])
     residue_esm = torch.cat(residue_representations, dim=0)
     assert residue_esm.size(0) == len(protein.aatype)
     return residue_esm
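A note on the second change in this hunk: "<mask>" is a single ESM token but a six-character string, so slicing token representations with len(sequence) would over-count whenever masked residues are present; that is presumably why the slice now uses len(protein.aatype), which counts residues (one token each). A toy illustration, with a made-up three-letter residue table rather than the repository's real RESIDUE_TYPES:

```python
# Toy illustration; this residue table is made up, not the repository's
# RESIDUE_TYPES (which covers the 20 standard amino acids).
RESIDUE_TYPES = ["A", "C", "D"]
RESIDUE_TYPES_MASK = RESIDUE_TYPES + ["<mask>"]

aatype = [0, 3, 2]  # the middle residue is masked
sequence = "".join(RESIDUE_TYPES_MASK[i] for i in aatype)

print(sequence)       # A<mask>D
print(len(sequence))  # 8 characters -- over-counts ESM tokens
print(len(aatype))    # 3 residues -- one ESM token each
```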
@@ -62,6 +73,23 @@ def update_pos(
     ligand = update_mol_positions(ligand, pos[: ligand.GetNumAtoms()])
     return protein, ligand
 
+def predict_seq(
+    proba: torch.Tensor
+) -> list :
+    tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1)
+    RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES
+    return list(map(lambda i : RESIDUE_TYPES_NEW[i], tokens))
+
+def update_seq(
+    protein: Protein, proba: torch.Tensor
+) -> Protein:
+    tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1)
+    RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES
+    sequence = "".join(map(lambda i : RESIDUE_TYPES_NEW[i], tokens)).lstrip("X").rstrip("X")
+    aatype = np.array([RESIDUE_TYPES.index(s) for s in sequence], dtype=np.int64)
+    protein = dataclasses.replace(protein, aatype = aatype)
+    return protein
+
 
 def main(args):
     pl.seed_everything(args.seed, workers=True)
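A quick sanity check of the new decoding helpers. The shapes are assumptions: one row per residue, and 21 classes if RESIDUE_TYPES lists the 20 standard amino acids, ordered ["X"] + RESIDUE_TYPES as in the functions above.

```python
import torch

# Assumes predict_seq/update_seq from generate.py above are in scope,
# e.g. via `from generate import predict_seq`.
# Random logits stand in for the model's per-residue sequence output;
# the 5 x 21 shape is an assumption, not taken from the model code.
proba = torch.randn(5, 21)

letters = predict_seq(proba)  # e.g. ["K", "X", "A", "G", "D"]
print("".join(letters))

# update_seq(protein, proba) additionally strips leading/trailing "X"
# and writes the decoded aatype indices back onto the Protein dataclass.
```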
@@ -75,24 +103,30 @@ def main(args):
     model = ProteinReDiffModel.load_from_checkpoint(
         args.ckpt_path, num_steps=args.num_steps
     )
+
+    model.training_mode = False
+    args.num_gpus = 1
+    model.mask_prob = args.mask_prob
+
 
     # Inputs
     if args.protein.endswith(".pdb"):
         protein = protein_from_pdb_file(args.protein)
     else:
+
         protein = protein_from_sequence(args.protein)
 
     if args.ligand.endswith(".sdf") or args.ligand.endswith(".mol2"):
         ligand = mol_from_file(args.ligand)
     else:
         ligand = Chem.MolFromSmiles(args.ligand)
-        ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3)))
+    ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3)))
 
     total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms()
     print(f"Total number of atoms: {total_num_atoms}")
     if total_num_atoms > 384:
         warnings.warn("Too many atoms. May take a long time for sample generation.")
+
     data = {
         **ligand_to_data(ligand),
         **protein_to_data(protein, residue_esm=compute_residue_esm(protein)),
@@ -104,10 +138,11 @@ def main(args):
     trainer = pl.Trainer.from_argparse_args(
         args,
         accelerator="auto",
+        gpus = args.num_gpus,
         default_root_dir=args.output_dir,
         max_epochs=-1,
     )
-    positions = trainer.predict(
+    results = trainer.predict( ## (NN)
         model,
         dataloaders=DataLoader(
             RepeatDataset(data, args.num_samples),
@@ -116,13 +151,20 @@ def main(args):
             collate_fn=collate_fn,
         ),
     )
+
+    positions = [p[0] for p in results]
+    sequences = [s[1] for s in results]
+
     positions = torch.cat(positions, dim=0).detach().cpu().numpy()
+    probabilities = torch.cat(sequences, dim=0).detach().cpu().numpy()
+    #torch.save(probabilities, "sampled_seq_gvp.pt") # can save embedding
 
     # Save samples
     sample_proteins, sample_ligands = [], []
     tmscores = []
-    for pos in positions:
+    for pos, seq_prob in zip(positions, probabilities):
         sample_protein, sample_ligand = update_pos(protein, ligand, pos)
+        sample_protein = update_seq(sample_protein, seq_prob)
         if ref_protein is None:
             warnings.warn(
                 "Using the first sample as a reference. The resulting structures may be mirror images."
@@ -155,10 +197,13 @@ def main(args):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
+
     parser.add_argument("--seed", type=int, default=1234)
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument("--num_steps", type=int, default=64)
+    parser.add_argument("--mask_prob", type=float, default=0.3)
+    parser.add_argument("--training_mode", action="store_true")
     parser.add_argument("-c", "--ckpt_path", type=Path, required=True)
     parser.add_argument("-o", "--output_dir", type=Path, required=True)
     parser.add_argument("-p", "--protein", type=str, required=True)
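For reference, a hedged invocation of the updated script using the flags visible in this diff. The ligand argument (args.ligand) clearly exists in main(), but its flag name is outside this excerpt, so --ligand below is a guess; check the full argument parser.

```bash
# Illustrative paths and sequence; --ligand is a hypothetical flag name,
# since the real one is truncated from this excerpt. All other flags
# appear in the parser above (--seed, --num_steps, --mask_prob, -c, -o, -p).
python generate.py \
    --mask_prob 0.3 \
    --num_steps 64 \
    -c checkpoints/last.ckpt \
    -o outputs \
    -p MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ \
    --ligand "c1ccccc1"
```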
Binary file modified scripts/__pycache__/__init__.cpython-39.pyc