Skip to content

Commit f31594d

Browse files
committed
add Chainsaw option for embedding mode
1 parent 2f6cb65 commit f31594d

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ progres embed -l filepaths.txt -o searchdb.pt
124124
- `-l` is a text file with information on one structure per line, each of which will be one entry in the output. White space should separate the file path to the structure and the domain name; any additional text on the line is optional and is treated as a note for the notes column of the results.
125125
- `-o` is the output file path for the PyTorch file containing a dictionary with the embeddings and associated data. It can be read in with `torch.load`.
126126
- `-f` determines the file format of each structure as above (`guess`, `pdb`, `mmcif`, `mmtf` or `coords`).
127+
- `-c` splits each structure into domains with Chainsaw so that each domain can be searched against separately. If Chainsaw finds no domains for a structure, that structure is not added to the output. Only the first chain in each file is considered. Running Chainsaw may take a few seconds.
127128

128129
Again, the structures should correspond to single protein domains.
129130
The embeddings are stored as Float16, which has no noticeable effect on search performance.

bin/progres

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ parser_embed.add_argument("-o", "--outputfile", required=True,
7777
parser_embed.add_argument("-f", "--fileformat",
7878
choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess",
7979
help="file format of the structures, by default guessed from the file extension")
80+
parser_embed.add_argument("-c", "--chainsaw", default=False, action="store_true",
81+
help=("split each structure into domains with Chainsaw to allow searching "
82+
"against each domain separately"))
8083
parser_embed.add_argument("-d", "--device", default="cpu",
8184
help="device to run on, default is \"cpu\"")
8285

@@ -107,7 +110,7 @@ def main():
107110
elif args.mode == "embed":
108111
from progres import progres_embed
109112
progres_embed(structurelist=args.structurelist, outputfile=args.outputfile,
110-
fileformat=args.fileformat, device=args.device)
113+
fileformat=args.fileformat, chainsaw=args.chainsaw, device=args.device)
111114
else:
112115
print("No mode selected, run \"progres -h\" to see help", file=sys.stderr)
113116

progres/progres.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -683,20 +683,34 @@ def progres_score_print(structure1, structure2, fileformat1="guess",
683683
score = progres_score(structure1, structure2, fileformat1, fileformat2, device)
684684
print(score)
685685

686-
def progres_embed(structurelist, outputfile, fileformat="guess", device="cpu",
686+
def progres_embed(structurelist, outputfile, fileformat="guess", chainsaw=False, device="cpu",
687687
batch_size=None, float_type=torch.float16):
688688
download_data_if_required()
689689

690-
fps, domids, notes = [], [], []
690+
fps, domids_fp, notes_fp = [], [], []
691691
with open(structurelist) as f:
692692
for line in f.readlines():
693693
cols = line.strip().split(None, 2)
694694
fps.append(cols[0])
695-
domids.append(cols[1])
696-
notes.append(cols[2] if len(cols) > 2 else "-")
695+
domids_fp.append(cols[1])
696+
notes_fp.append(cols[2] if len(cols) > 2 else "-")
697697

698698
model = load_trained_model(device)
699-
data_set = StructureDataset(fps, fileformat, model, device)
699+
data_set = StructureDataset(fps, fileformat, model, device, chainsaw)
700+
if chainsaw:
701+
domids, notes = [], []
702+
i, dom_i = 0, 1
703+
for fp, domid, note in zip(fps, domids_fp, notes_fp):
704+
while data_set.file_paths[i] == fp:
705+
domids.append(f"{domid}_D{dom_i}")
706+
notes.append(f"{note} - domain {dom_i} ({data_set.res_ranges[i]})")
707+
i += 1
708+
dom_i += 1
709+
dom_i = 1
710+
else:
711+
domids, notes = domids_fp, notes_fp
712+
assert len(domids) == len(notes) == len(data_set)
713+
700714
if batch_size is None:
701715
batch_size = get_batch_size(device)
702716
data_loader = DataLoader(

0 commit comments

Comments
 (0)