Skip to content

Commit

Permalink
Fix import error
Browse files Browse the repository at this point in the history
  • Loading branch information
thanhduy1842001 committed Apr 20, 2024
2 parents 353312c + 3c3d80f commit 7679c59
Show file tree
Hide file tree
Showing 33 changed files with 29,459 additions and 1,104 deletions.
Binary file added Equivariant-Diffusion.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 15 additions & 8 deletions ProteinReDiff/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
#### CITE
"""
Adapted from Nakata, S., Mori, Y. & Tanaka, S.
End-to-end protein–ligand complex structure generation with diffusion-based generative models.
BMC Bioinformatics 24, 233 (2023).
https://doi.org/10.1186/s12859-023-05354-5
Repository: https://github.com/shuyana/DiffusionProteinLigand
"""

from pathlib import Path
from typing import Any, List, Mapping, Optional, Sequence, Union
from itertools import cycle, islice
Expand All @@ -15,7 +23,7 @@
from .mol import get_mol_positions
from .protein import Protein, protein_to_ca_mol

from .dev import spatial_mask ### (NN)
from .dev import spatial_mask


def ligand_to_data(ligand: Chem.Mol, **kwargs: Any) -> Mapping[str, Any]:
Expand Down Expand Up @@ -113,7 +121,7 @@ def collate_fn(data_list: Sequence[Mapping[str, Any]]) -> Mapping[str, Any]:
else:
batch[k] = default_collate([data[k] for data in data_list])

### (NN): Adding spatial masking
### Adding spatial masking
# if "residue_spatial_mask" not in data_list[0]:
# feat_pad = (0, 0) * (data_list[0]["residue_mask"].dim() - 1)
# batch["residue_spatial_mask"] = default_collate(
Expand Down Expand Up @@ -190,7 +198,6 @@ def load_data(self):
for pdb_id in tqdm(self.pdb_ids):
ligand_data = torch.load(self.root_dir / pdb_id / "ligand_data.pt")
protein_data = torch.load(self.root_dir / pdb_id / "protein_data.pt")
# if (protein_data['num_residues'] + ligand_data["num_atoms"]) <=400:
yield {"pdb_id": pdb_id, **ligand_data, **protein_data}
def get_stream(self):
return cycle(self.load_data())
Expand All @@ -214,20 +221,20 @@ def __init__(

def setup(self, stage: Optional[str] = None) -> None:
self.train_pdb_ids: List[str] = []
with open(self.data_dir / "nn_short_train_pdb_ids", "r") as f:
with open(self.data_dir / "PRD_train_pdb_ids", "r") as f:
self.train_pdb_ids.extend(line.strip() for line in f.readlines())
self.val_pdb_ids: List[str] = []
with open(self.data_dir / "val_pdb_ids", "r") as f:
with open(self.data_dir / "PRD_val_pdb_ids", "r") as f:
self.val_pdb_ids.extend(line.strip() for line in f.readlines())
self.test_pdb_ids: List[str] = []
with open(self.data_dir / "nn_test_pdb_ids", "r") as f:
with open(self.data_dir / "PRD_test_pdb_ids", "r") as f:
self.test_pdb_ids.extend(line.strip() for line in f.readlines())

def train_dataloader(self) -> DataLoader:
return DataLoader(
PDBbindDataset(self.cache_dir, self.train_pdb_ids),
batch_size=self.batch_size,
shuffle=True, ## (NN)
shuffle=True,
num_workers=self.num_workers,
collate_fn=collate_fn,
# prefetch_factor = 500,
Expand Down
147 changes: 0 additions & 147 deletions ProteinReDiff/dev.py

This file was deleted.

3 changes: 3 additions & 0 deletions ProteinReDiff/difffusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Beta parameter diffusion scheduler
# Adapted from https://github.com/aqlaboratory/genie

import math
import torch

Expand Down
1 change: 0 additions & 1 deletion ProteinReDiff/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(self,
inf = 1e10
):
super().__init__()
# self.max_p = max_p
self.inf = inf


Expand Down
Loading

0 comments on commit 7679c59

Please sign in to comment.