Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tmol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from tmol.io.pose_stack_from_rosettafold2 import ( # noqa: F401
pose_stack_from_rosettafold2,
pose_stack_to_rosettafold2,
canonical_form_from_rosettafold2,
canonical_ordering_for_rosettafold2,
packed_block_types_for_rosettafold2,
Expand All @@ -32,6 +33,11 @@
from tmol.score import beta2016_score_function # noqa: F401
from tmol.score.score_function import ScoreFunction # noqa: F401

from tmol.optimization.sfxn_modules import (
CartesianSfxnNetwork as cart_sfxn_network,
) # noqa: F401

from tmol.optimization.lbfgs_armijo import LBFGS_Armijo as lbfgs_armijo # noqa: F401

try:
__version__ = version("tmol")
Expand Down
144 changes: 144 additions & 0 deletions tmol/io/pose_stack_from_rosettafold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,88 @@ def canonical_form_from_rosettafold2(
)


def pose_stack_to_rosettafold2_with_suppressed(pose_stack, chainlens):
    """Convert a tmol PoseStack into RosettaFold2-ordered coordinates.

    The pose stack is first converted to tmol's canonical form using the
    RF2 canonical ordering, and the canonical-form coordinates are then
    scattered into RF2's per-residue atom slots.

    Args:
        pose_stack: the PoseStack to convert; only ``pose_stack.device`` is
            read directly, the rest is handled by
            ``canonical_form_from_pose_stack``.
        chainlens: iterable of chain lengths; their sum is used as the
            number of residues in the RF2 tensor.

    Returns:
        A 2-tuple ``(rf2_coords, rf2_atom_mask)``:

        * ``rf2_coords``: float32 tensor of shape
          ``(1, sum(chainlens), 27, 3)``; slots that are never written
          stay NaN.
        * ``rf2_atom_mask``: bool tensor that is True for RF2 atom slots
          that are real for the residue type and whose tmol atom is not
          flagged in the suppression table.

    NOTE(review): the suppression table is looked up by residue type for
    every residue, so the flagged atom is masked out on all residues, not
    only on chain starts -- confirm this is what callers expect.
    """
    from tmol.io.pose_stack_deconstruction import canonical_form_from_pose_stack

    device = pose_stack.device
    n_poses = 1  # RF2 does not presently do batch processing
    max_n_res = sum(chainlens)
    max_n_ats = 27

    # Per-atom pose / residue indices, broadcast to the full RF2 shape so
    # that boolean masks over atoms can select matching index triples.
    rf2_pose_ind_for_atom = (
        torch.arange(n_poses, dtype=torch.int64, device=device)
        .reshape(-1, 1, 1)
        .expand(-1, max_n_res, max_n_ats)
    )
    rf2_res_ind_for_atom = (
        torch.arange(max_n_res, dtype=torch.int64, device=device)
        .reshape(1, -1, 1)
        .expand(n_poses, -1, max_n_ats)
    )

    co = canonical_ordering_for_rosettafold2()
    (
        _,  # rf22t_rtmap,
        rf22t_atmap,
        rf2_at_is_real_map,
        supress_atom_for_nterm,
        hydrogens,
        h_to_h1,
    ) = _get_tmol_2_rf2_mappings(device)

    canonical_form = canonical_form_from_pose_stack(co, pose_stack)

    # canonical_form[1] is used as the per-residue type index and
    # canonical_form[2] as the coordinates -- presumably
    # [n_poses, n_res] and [n_poses, n_res, n_canonical_atoms, 3].
    seq = canonical_form[1]
    atom_mapping = rf22t_atmap[seq]
    hydro = hydrogens[seq]
    h1s = h_to_h1[seq]
    supressed = supress_atom_for_nterm[seq]
    rf2_at_is_real = rf2_at_is_real_map[seq]

    # numpy.NaN was removed in NumPy 2.0; numpy.nan works on all versions.
    rf2_coords = torch.full(
        (n_poses, max_n_res, max_n_ats, 3),
        numpy.nan,
        dtype=torch.float32,
        device=device,
    )

    rf2_coords[rf2_at_is_real] = canonical_form[2][
        rf2_pose_ind_for_atom[rf2_at_is_real],
        rf2_res_ind_for_atom[rf2_at_is_real],
        atom_mapping[rf2_at_is_real],
    ]

    # Find any NaN hydrogens and copy from the H1 instead
    nans = torch.isnan(rf2_coords).any(-1)
    terminal_hs = torch.logical_and(nans, hydro)
    # Broadcast the per-residue 1H index over the atom axis so that it is
    # selected with the very same per-atom mask as the pose / residue
    # indices; the three index tensors are then guaranteed equal length.
    h1_ind_for_atom = h1s.unsqueeze(-1).expand_as(terminal_hs)
    rf2_coords[terminal_hs] = canonical_form[2][
        rf2_pose_ind_for_atom[terminal_hs],
        rf2_res_ind_for_atom[terminal_hs],
        h1_ind_for_atom[terminal_hs],
    ]

    # Translate the suppression flags from tmol atom order to RF2 order.
    suppressed_mapped = torch.full(
        (n_poses, max_n_res, max_n_ats),
        False,
        dtype=torch.bool,
        device=device,
    )
    suppressed_mapped[rf2_at_is_real] = supressed[
        rf2_pose_ind_for_atom[rf2_at_is_real],
        rf2_res_ind_for_atom[rf2_at_is_real],
        atom_mapping[rf2_at_is_real],
    ]
    return rf2_coords, torch.logical_and(
        rf2_at_is_real, torch.logical_not(suppressed_mapped)
    )


def pose_stack_to_rosettafold2(pose_stack, chainlens):
    """Return only the RF2-ordered coordinates for ``pose_stack``.

    Thin wrapper around ``pose_stack_to_rosettafold2_with_suppressed``
    that discards the accompanying atom mask.
    """
    rf2_coords, _ = pose_stack_to_rosettafold2_with_suppressed(
        pose_stack, chainlens
    )
    return rf2_coords


@toolz.functoolz.memoize
def _paramdb_for_rosettafold2() -> ParameterDatabase:
"""Construct the paramdb representing the subset of residues that
Expand Down Expand Up @@ -245,3 +327,65 @@ def _get_rf2_2_tmol_mappings(device: torch.device):
tmol_ind = co.restypes_atom_index_mapping[i_3lc][atname.strip()]
supress_atom_at_nterm[i, tmol_ind] = True
return rt_map, atname_map, at_is_real, supress_atom_at_nterm


@toolz.functoolz.memoize
def _get_tmol_2_rf2_mappings(device: torch.device):
    """Build the lookup tensors used to lay tmol data out in RF2 order.

    Same logic as the RF2->tmol function, but additionally provides a
    tensor marking the "H" atom in the RF2 index space, as well as a
    tensor giving the tmol "1H" index for a residue type when indexed by
    the RF2 residue type index.

    Returns:
        ``(rt_map, atname_map, at_is_real, supress_atom_at_nterm,
        hydrogens, h_to_h1)`` where the first three come straight from
        ``co.create_src_2_tmol_mappings`` and the remaining three are
        tensors on ``device``:

        * ``supress_atom_at_nterm``: bool
          ``[n_rf2_restypes, co.max_n_canonical_atoms]``, True at the tmol
          index of the atom named "H".
        * ``hydrogens``: bool ``[n_rf2_restypes, n_rf2_atoms]``, True at
          the RF2 slot of the atom named "H".
        * ``h_to_h1``: int64 ``[n_rf2_restypes]``, tmol index of "1H" for
          residue types that have an "H" atom, otherwise -1.
    """

    co = canonical_ordering_for_rosettafold2()
    from tmol.extern.rosettafold2.chemical import (
        num2aa,
        aa2long,
    )

    n_restypes = len(num2aa)

    # RF2 atom names per 3-letter code, whitespace-stripped; absent slots
    # (None) become the empty string.
    rf2_atom_names_for_name3s = {}
    for name3, raw_names in zip(num2aa, aa2long):
        rf2_atom_names_for_name3s[name3] = [
            "" if raw is None else raw.strip() for raw in raw_names
        ]

    rt_map, atname_map, at_is_real = co.create_src_2_tmol_mappings(
        num2aa, rf2_atom_names_for_name3s, device
    )

    n_rf2_atoms = len(rf2_atom_names_for_name3s[num2aa[0]])
    hydrogens = torch.zeros(
        (n_restypes, n_rf2_atoms), dtype=torch.bool, device=device
    )
    h_to_h1 = torch.full((n_restypes,), -1, dtype=torch.int64, device=device)

    # also want to turn off n-term "H" atoms
    supress_atom_at_nterm = torch.zeros(
        (n_restypes, co.max_n_canonical_atoms), dtype=torch.bool, device=device
    )

    for rt_ind, name3 in enumerate(num2aa):
        if name3 not in co.restype_io_equiv_classes:
            continue
        for rf2_at_ind, atname in enumerate(rf2_atom_names_for_name3s[name3]):
            # Names were stripped when the dict was built, so a direct
            # comparison is equivalent to stripping again here.
            if atname != "H":
                continue
            tmol_at_inds = co.restypes_atom_index_mapping[name3]
            supress_atom_at_nterm[rt_ind, tmol_at_inds[atname]] = True

            hydrogens[rt_ind, rf2_at_ind] = True
            h_to_h1[rt_ind] = tmol_at_inds["1H"]

    return (
        rt_map,
        atname_map,
        at_is_real,
        supress_atom_at_nterm,
        hydrogens,
        h_to_h1,
    )
21 changes: 21 additions & 0 deletions tmol/tests/io/test_pose_stack_from_rosettafold2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import torch

from tmol.tests.score.common.test_energy_term import assert_allclose

from tmol.io.pose_stack_from_rosettafold2 import (
pose_stack_from_rosettafold2,
pose_stack_to_rosettafold2_with_suppressed,
canonical_form_from_rosettafold2,
_paramdb_for_rosettafold2,
canonical_ordering_for_rosettafold2,
Expand Down Expand Up @@ -48,6 +50,25 @@ def test_multi_chain_rosettafold2_pose_stack_construction(
assert ps.packed_block_types is pbt


def test_from_to_rosettafold2(rosettafold2_ubq_pred, torch_device):
    """Round-trip RF2 -> tmol -> RF2 and compare the unmasked atoms."""
    chainlens = [76]
    rosettafold2_ubq_pred["chainlens"] = chainlens

    # RF2->tmol
    pose_stack = pose_stack_from_rosettafold2(**rosettafold2_ubq_pred)

    # tmol->RF2
    roundtrip_xyz, comparable_atoms = pose_stack_to_rosettafold2_with_suppressed(
        pose_stack, chainlens
    )

    # Only atoms flagged by the returned mask are compared; the input is
    # given a leading pose dimension to line up with the round-trip tensor.
    expected = rosettafold2_ubq_pred["xyz"].unsqueeze(0)[comparable_atoms].cpu()
    actual = roundtrip_xyz[comparable_atoms].cpu()
    assert_allclose(expected, actual, 1e-5, 1e-3)


def test_create_canonical_form_from_rosettafold2_ubq_stability(
rosettafold2_ubq_pred, torch_device
):
Expand Down