diff --git a/.buildkite/bin/linting b/.buildkite/bin/linting index e8d0e00b3..204463f4b 100755 --- a/.buildkite/bin/linting +++ b/.buildkite/bin/linting @@ -1,14 +1,13 @@ #!/bin/bash #http://redsymbol.net/articles/unofficial-bash-strict-mode/ +source ~/.bashrc +conda activate $BUILDKITE_BUILD_ID + set -euo pipefail IFS=$'\n\t' set +e - -source ~/.bashrc -conda activate $BUILDKITE_BUILD_ID - set -x error=0 diff --git a/environments/linux-cuda/env.yml b/environments/linux-cuda/env.yml index 7f659f05a..7f8395a00 100644 --- a/environments/linux-cuda/env.yml +++ b/environments/linux-cuda/env.yml @@ -1,10 +1,11 @@ name: tmol channels: - - nvidia/label/cuda-12.1.1 + - nvidia/label/cuda-12.8.0 - conda-forge dependencies: - python=3.11 - cuda + - nvtx - pip - pip: - -r requirements-linux-cuda.txt diff --git a/environments/macos-cpu/env.yml b/environments/macos-cpu/env.yml new file mode 100644 index 000000000..ab356dc27 --- /dev/null +++ b/environments/macos-cpu/env.yml @@ -0,0 +1,8 @@ +name: tmol +channels: + - conda-forge +dependencies: + - python=3.11 + - pip + - pip: + - -r requirements-macos-cpu.txt diff --git a/environments/macos-cpu/requirements-dev-macos-cpu.txt b/environments/macos-cpu/requirements-dev-macos-cpu.txt new file mode 100644 index 000000000..ee23f1274 --- /dev/null +++ b/environments/macos-cpu/requirements-dev-macos-cpu.txt @@ -0,0 +1,181 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --constraint=requirements-macos-cpu.txt --no-emit-index-url --output-file=requirements-dev-macos-cpu.txt ../../requirements-dev.in +# +black==24.1.1 + # via -r ../../requirements-dev.in +build==1.0.3 + # via pip-tools +certifi==2024.2.2 + # via + # -c requirements-macos-cpu.txt + # requests +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.3.2 + # via + # -c requirements-macos-cpu.txt + # requests +clang-format==17.0.6 + # via -r ../../requirements-dev.in +click==8.1.7 + # via + # black + # pip-tools 
+codecov==2.1.13 + # via -r ../../requirements-dev.in +colorama==0.4.6 + # via pytest-watch +contourpy==1.2.0 + # via matplotlib +coverage[toml]==7.4.1 + # via + # codecov + # pytest-cov +cycler==0.12.1 + # via matplotlib +distlib==0.3.8 + # via virtualenv +docopt==0.6.2 + # via pytest-watch +filelock==3.13.1 + # via + # -c requirements-macos-cpu.txt + # virtualenv +flake8==7.0.0 + # via -r ../../requirements-dev.in +fonttools==4.48.1 + # via matplotlib +identify==2.5.35 + # via pre-commit +idna==3.6 + # via + # -c requirements-macos-cpu.txt + # requests +iniconfig==2.0.0 + # via pytest +itermplot==0.5 + # via -r ../../requirements-dev.in +kiwisolver==1.4.5 + # via matplotlib +matplotlib==3.8.2 + # via + # itermplot + # seaborn +mccabe==0.7.0 + # via flake8 +mypy-extensions==1.0.0 + # via + # -c requirements-macos-cpu.txt + # black +nodeenv==1.8.0 + # via pre-commit +numpy==1.26.4 + # via + # -c requirements-macos-cpu.txt + # contourpy + # itermplot + # matplotlib + # pandas + # seaborn +packaging==23.2 + # via + # black + # build + # matplotlib + # pytest +pandas==2.2.0 + # via + # -c requirements-macos-cpu.txt + # seaborn +pathspec==0.12.1 + # via black +pillow==10.2.0 + # via matplotlib +pip-tools==7.3.0 + # via -r ../../requirements-dev.in +platformdirs==4.2.0 + # via + # black + # virtualenv +pluggy==1.4.0 + # via pytest +pre-commit==3.6.2 + # via -r ../../requirements-dev.in +py==1.11.0 + # via pytest-forked +py-cpuinfo==9.0.0 + # via pytest-benchmark +pycodestyle==2.11.1 + # via flake8 +pyflakes==3.2.0 + # via flake8 +pyparsing==3.1.1 + # via matplotlib +pyproject-hooks==1.0.0 + # via build +pytest==8.0.0 + # via + # -r ../../requirements-dev.in + # pytest-benchmark + # pytest-cov + # pytest-forked + # pytest-instafail + # pytest-repeat + # pytest-watch +pytest-benchmark==4.0.0 + # via -r ../../requirements-dev.in +pytest-cov==4.1.0 + # via -r ../../requirements-dev.in +pytest-forked==1.6.0 + # via -r ../../requirements-dev.in +pytest-instafail==0.5.0 + # 
via -r ../../requirements-dev.in +pytest-repeat==0.9.3 + # via -r ../../requirements-dev.in +pytest-watch==4.2.0 + # via -r ../../requirements-dev.in +python-dateutil==2.8.2 + # via + # -c requirements-macos-cpu.txt + # matplotlib + # pandas +pytz==2024.1 + # via + # -c requirements-macos-cpu.txt + # pandas +pyyaml==6.0.1 + # via + # -c requirements-macos-cpu.txt + # pre-commit +requests==2.31.0 + # via + # -c requirements-macos-cpu.txt + # codecov +seaborn==0.13.2 + # via -r ../../requirements-dev.in +six==1.16.0 + # via + # -c requirements-macos-cpu.txt + # itermplot + # python-dateutil +tzdata==2023.4 + # via + # -c requirements-macos-cpu.txt + # pandas +urllib3==2.2.0 + # via + # -c requirements-macos-cpu.txt + # requests +virtualenv==20.25.1 + # via pre-commit +watchdog==4.0.0 + # via pytest-watch +wheel==0.42.0 + # via pip-tools + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/environments/macos-cpu/requirements-macos-cpu.txt b/environments/macos-cpu/requirements-macos-cpu.txt new file mode 100644 index 000000000..28f86c1d2 --- /dev/null +++ b/environments/macos-cpu/requirements-macos-cpu.txt @@ -0,0 +1,148 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --no-emit-index-url --output-file=requirements-macos-cpu.txt ../../requirements.in +# +asciitree==0.3.3 + # via zarr +astor==0.8.1 + # via -r ../../requirements.in +asttokens==2.4.1 + # via stack-data +attrs==23.2.0 + # via + # -r ../../requirements.in + # attrs-strict + # cattrs + # hypothesis +attrs-strict==1.0.1 + # via -r ../../requirements.in +cattrs==23.2.3 + # via -r ../../requirements.in +certifi==2024.2.2 + # via requests +charset-normalizer==3.3.2 + # via requests +decorator==5.1.1 + # via + # -r ../../requirements.in + # ipython +executing==2.0.1 + # via stack-data +fasteners==0.19 + # via zarr +filelock==3.13.1 + # via torch +frozendict==2.4.0 + # via -r 
../../requirements.in +fsspec==2024.2.0 + # via torch +hypothesis==6.98.3 + # via -r ../../requirements.in +idna==3.6 + # via requests +ipython==8.21.0 + # via -r ../../requirements.in +jedi==0.19.1 + # via ipython +jinja2==3.1.3 + # via torch +llvmlite==0.42.0 + # via numba +markupsafe==2.1.5 + # via jinja2 +matplotlib-inline==0.1.6 + # via ipython +mpmath==1.3.0 + # via sympy +mypy-extensions==1.0.0 + # via typing-inspect +networkx==3.2.1 + # via + # -r ../../requirements.in + # torch +ninja==1.11.1.1 + # via -r ../../requirements.in +numba==0.59.0 + # via sparse +numcodecs==0.12.1 + # via zarr +numpy==1.26.4 + # via + # -r ../../requirements.in + # numba + # numcodecs + # pandas + # pyarrow + # scipy + # sparse + # zarr +pandas==2.2.0 + # via -r ../../requirements.in +parso==0.8.3 + # via jedi +pexpect==4.9.0 + # via ipython +pint==0.23 + # via -r ../../requirements.in +prompt-toolkit==3.0.43 + # via ipython +psutil==5.9.8 + # via -r ../../requirements.in +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pyarrow==15.0.0 + # via -r ../../requirements.in +pygments==2.17.2 + # via ipython +python-dateutil==2.8.2 + # via pandas +pytz==2024.1 + # via pandas +pyyaml==6.0.1 + # via -r ../../requirements.in +requests==2.31.0 + # via -r ../../requirements.in +scipy==1.12.0 + # via + # -r ../../requirements.in + # sparse +six==1.16.0 + # via + # asttokens + # python-dateutil +sortedcontainers==2.4.0 + # via hypothesis +sparse==0.15.1 + # via -r ../../requirements.in +stack-data==0.6.3 + # via ipython +sympy==1.12 + # via torch +toolz==0.12.1 + # via -r ../../requirements.in +torch==2.2.0 + # via -r ../../requirements.in +traitlets==5.14.1 + # via + # ipython + # matplotlib-inline +typing-extensions==4.9.0 + # via + # -r ../../requirements.in + # pint + # torch + # typing-inspect +typing-inspect==0.9.0 + # via -r ../../requirements.in +tzdata==2023.4 + # via pandas +urllib3==2.2.0 + # via requests +wcwidth==0.2.13 + # via prompt-toolkit +zarr==2.16.1 
+ # via -r ../../requirements.in diff --git a/requirements.in b/requirements.in index d06cf8a62..e1237082d 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,4 @@ -torch == 2.6 +torch >= 2.5 numpy # NOTE we get numpy from conda to ensure compat with numba astor attrs != 22.2.0 # problem unpickling tmol/tests/data/rosetta_baseline/1ubq.scores.pickle in 22.2.0 (see https://github.com/python-attrs/attrs/pull/1085) @@ -18,6 +18,7 @@ requests scipy sparse # NOTE sparse>0.3.1 requires env var SPARSE_AUTO_DENSIFY=1 to be set toolz +torchshow typing_extensions typing_inspect ninja diff --git a/tmol/chemical/restypes.py b/tmol/chemical/restypes.py index 42f293be6..941b47081 100644 --- a/tmol/chemical/restypes.py +++ b/tmol/chemical/restypes.py @@ -23,6 +23,12 @@ ConnectionIndex = NewType("ConnectionIndex", int) BondCount = NewType("BondCount", int) +# As of cattr 24.1.0, more types must be explicitly registered in order to +# use cattr.structure. We use that here +cattr.register_structure_hook(numpy.dtype, lambda d, _: numpy.dtype(d)) +cattr.register_structure_hook(numpy.ndarray, lambda d, _: numpy.array(d)) + + # perhaps deserving of its own file UnresolvedAtomID = Tuple[AtomIndex, ConnectionIndex, BondCount] uaid_t = numpy.dtype( @@ -149,7 +155,6 @@ def _setup_bond_indices(self): [(ai, bi), (bi, ai)] for ai, bi in map(map(self.atom_to_idx.get), self.bonds) ) - bond_array = numpy.array(bondi, dtype=numpy.int32) bond_array.flags.writeable = False return bond_array @@ -472,7 +477,7 @@ def n_icoors(self): def _setup_icoors_index(self): return {icoor.name: i for i, icoor in enumerate(self.icoors)} - at_to_icoor_ind: numpy.array = attr.ib() + at_to_icoor_ind: numpy.ndarray = attr.ib() @at_to_icoor_ind.default def _setup_at_to_icoor_ind(self): diff --git a/tmol/database/scoring/omega_bbdep.py b/tmol/database/scoring/omega_bbdep.py index af5308c86..184a678d2 100644 --- a/tmol/database/scoring/omega_bbdep.py +++ b/tmol/database/scoring/omega_bbdep.py @@ -39,5 +39,5 @@ 
def from_file(cls, fname: str): OmegaBBDepTables, ] ): - print("safe globals: ", torch.serialization.get_safe_globals()) + # print("safe globals: ", torch.serialization.get_safe_globals()) return torch.load(fname) diff --git a/tmol/extern/moderngpu/context.hxx b/tmol/extern/moderngpu/context.hxx index 1de4d17c6..ea140e2da 100644 --- a/tmol/extern/moderngpu/context.hxx +++ b/tmol/extern/moderngpu/context.hxx @@ -23,8 +23,12 @@ inline std::string device_prop_string(cudaDeviceProp prop) { cudaError_t result = cudaMemGetInfo(&freeMem, &totalMem); if(cudaSuccess != result) throw cuda_exception_t(result); - double memBandwidth = (prop.memoryClockRate * 1000.0) * - (prop.memoryBusWidth / 8 * 2) / 1.0e9; + int memoryClockRate; + cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, ordinal); + double memBandwidth = (memoryClockRate * 1000.0) * (prop.memoryBusWidth / 8 * 2) / 1.0e9; + + int clockRate; + cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, ordinal); std::string s = detail::stringprintf( "%s : %8.3lf Mhz (Ordinal %d)\n" @@ -32,10 +36,10 @@ inline std::string device_prop_string(cudaDeviceProp prop) { "FreeMem: %6dMB TotalMem: %6dMB %2d-bit pointers.\n" "Mem Clock: %8.3lf Mhz x %d bits (%5.1lf GB/s)\n" "ECC %s\n\n", - prop.name, prop.clockRate / 1000.0, ordinal, + prop.name, clockRate / 1000.0, ordinal, prop.multiProcessorCount, prop.major, prop.minor, (int)(freeMem / (1<< 20)), (int)(totalMem / (1<< 20)), 8 * sizeof(int*), - prop.memoryClockRate / 1000.0, prop.memoryBusWidth, memBandwidth, + memoryClockRate / 1000.0, prop.memoryBusWidth, memBandwidth, prop.ECCEnabled ? 
"Enabled" : "Disabled"); return s; } diff --git a/tmol/io/chain_deduction.py b/tmol/io/chain_deduction.py index de777570e..dd8563307 100644 --- a/tmol/io/chain_deduction.py +++ b/tmol/io/chain_deduction.py @@ -97,11 +97,16 @@ def chain_inds_for_pose_stack( n_components, labels = scipy.sparse.csgraph.connected_components( csr_bond_pairs, directed=False, return_labels=True ) + # print("n_components", n_components) + # print("labels", labels) labels = labels.reshape(n_poses, max_n_blocks) - n_ccs = numpy.amax(labels, axis=1) + 1 + n_ccs = numpy.amax(labels, axis=1) - numpy.amin(labels, axis=1) + 1 + # print("n_ccs", n_ccs) cc_offsets = exclusive_cumsum1d(n_ccs) + # print("cc_offsets", cc_offsets) labels = labels - cc_offsets.reshape(n_poses, 1) + # print("labels", labels) # now re-label the gap residues with a chain ind of -1 labels[unreal_blocks] = -1 diff --git a/tmol/kinematics/compiled/compiled.cuda.cu b/tmol/kinematics/compiled/compiled.cuda.cu index d7fae6b3f..3fffa89ed 100644 --- a/tmol/kinematics/compiled/compiled.cuda.cu +++ b/tmol/kinematics/compiled/compiled.cuda.cu @@ -224,6 +224,27 @@ struct ForwardKinDispatch { // reindexing function nvtx_range_push("dispatch::segscan"); auto k_reindex = [=] MGPU_DEVICE(int index, int seg, int rank) { + if (nodestart + index >= nodes.size(0) || nodestart + index < 0) { + printf( + "oops! nodestart %d + index %d vs nodes.size(0) %d\n", + nodestart, + index, + nodes.size(0)); + return *((HTRawBuffer*)HTs[0].data()); + } + if (nodes[nodestart + index] >= HTs.size(0) + || nodes[nodestart + index] < 0) { + printf( + "oops2! 
nodestart %d + index %d gives nodes[%d] = %d and " + "HTs.size() = %d\n", + nodestart, + index, + nodestart + index, + nodes[nodestart + index], + HTs.size(0)); + return *((HTRawBuffer*)HTs[0].data()); + } + assert(nodestart + index < nodes.size(0) && nodestart + index >= 0); assert( nodes[nodestart + index] < HTs.size(0) @@ -255,7 +276,28 @@ struct ForwardKinDispatch { // is) nvtx_range_push("dispatch::unindex"); auto k_unindex = [=] MGPU_DEVICE(int index) { + if (nodestart + index >= nodes.size(0) || nodestart + index < 0) { + printf( + "oops3! nodestart %d + index %d vs nodes.size(0) %d\n", + nodestart, + index, + nodes.size(0)); + return; // *((HTRawBuffer*)HTs[0].data()); + } assert(nodestart + index < nodes.size(0) && nodestart + index >= 0); + if (nodes[nodestart + index] >= HTs.size(0) + || nodes[nodestart + index] < 0) { + printf( + "oops4! nodestart %d + index %d gives nodes[%d] = %d and " + "HTs.size() = %d\n", + nodestart, + index, + nodestart + index, + nodes[nodestart + index], + HTs.size(0)); + return; // *((HTRawBuffer*)HTs[0].data()); + } + assert( nodes[nodestart + index] < HTs.size(0) && nodes[nodestart + index] >= 0); diff --git a/tmol/numeric/dihedrals.py b/tmol/numeric/dihedrals.py index 712d5b39e..dd3d69fb9 100644 --- a/tmol/numeric/dihedrals.py +++ b/tmol/numeric/dihedrals.py @@ -38,6 +38,6 @@ def coord_dihedrals( # angle between v and w in a plane is the torsion angle # v and w may not be normalized but that's fine since tan is y/x x = torch.einsum("ij,ij->i", (v, w)) - y = torch.einsum("ij,ij->i", (torch.cross(ubc, v), w)) + y = torch.einsum("ij,ij->i", (torch.linalg.cross(ubc, v), w)) return torch.atan2(y, x).type(torch.float) diff --git a/tmol/pack/compiled/annealer.hh b/tmol/pack/compiled/annealer.hh index 033d84c90..cc3431a8b 100644 --- a/tmol/pack/compiled/annealer.hh +++ b/tmol/pack/compiled/annealer.hh @@ -9,20 +9,46 @@ namespace tmol { namespace pack { namespace compiled { +template < + template class DeviceDispatch, + tmol::Device 
D, + typename Real, + typename Int> +struct InteractionGraphBuilder { + static auto f( + int const chunk_size, + TView n_rots_for_pose, + TView rot_offset_for_pose, + TView n_rots_for_block, + TView rot_offset_for_block, + TView pose_for_rot, + TView block_type_ind_for_rot, + TView block_ind_for_rot, + TView sparse_inds, + TView sparse_energies) + -> std::tuple< + TPack, + TPack, + TPack, + TPack >; +}; + template struct AnnealerDispatch { static auto forward( - TView nrotamers_for_res, - TView oneb_offsets, + int max_n_rotamers_per_pose, + TView pose_n_res, + TView pose_n_rotamers, + TView pose_rotamer_offset, + TView n_rotamers_for_res, + TView oneb_offsets, TView res_for_rot, - TView respair_nenergies, - TView chunk_size, - TView chunk_offset_offsets, - TView twob_offsets, - TView fine_chunk_offsets, + int32_t chunk_size, + TView chunk_offset_offsets, + TView chunk_offsets, TView energy1b, - TView energy2b, - int64_t seed) -> std::tuple, TPack >; + TView energy2b) + -> std::tuple, TPack >; }; } // namespace compiled diff --git a/tmol/pack/compiled/compiled.cpu.cpp b/tmol/pack/compiled/compiled.cpu.cpp index 3bb8196ae..702cdf29a 100644 --- a/tmol/pack/compiled/compiled.cpu.cpp +++ b/tmol/pack/compiled/compiled.cpu.cpp @@ -1,8 +1,11 @@ #include #include +#include + // ??? 
#include "annealer.hh" #include "simulated_annealing.hh" +#include "compiled.impl.hh" #include @@ -11,16 +14,18 @@ namespace pack { namespace compiled { template -void set_quench_order(TView quench_order) { +void set_quench_order( + TView quench_order, + int const n_rots, + int const pose_rotamer_offset) { // Create a random permutation of all the rotamers // and visit them in this order to ensure all of them // are seen during the quench step - int const nrots = quench_order.size(0); - for (int i = 0; i < nrots; ++i) { - quench_order[i] = i; + for (int i = 0; i < n_rots; ++i) { + quench_order[i] = i + pose_rotamer_offset; } - for (int i = 0; i <= nrots - 2; ++i) { - int j = i + rand() % (nrots - i); + for (int i = 0; i <= n_rots - 2; ++i) { + int j = i + rand() % (n_rots - i); // swap i and j; int jval = quench_order[j]; quench_order[j] = quench_order[i]; @@ -29,70 +34,85 @@ void set_quench_order(TView quench_order) { } template -struct AnnealerDispatch { - static auto forward( - TView nrotamers_for_res, - TView oneb_offsets, - TView res_for_rot, - TView respair_nenergies, - TView chunk_size_t, - TView chunk_offset_offsets, - TView twob_offsets, - TView fine_chunk_offsets, - TView energy1b, - TView energy2b, - int64_t seed) -> std::tuple, TPack > { - clock_t start = clock(); - - // No Frills Simulated Annealing! 
- int const nres = nrotamers_for_res.size(0); - int const nrotamers = res_for_rot.size(0); - int const chunk_size = chunk_size_t[0]; - - int ntraj = 1; - int const n_outer_iterations = 20; - int const n_inner_iterations_factor = 20; - int const n_inner_iterations = n_inner_iterations_factor * nrotamers; - - auto scores_t = TPack::zeros({1, ntraj}); - auto rotamer_assignments_t = TPack::zeros({ntraj, nres}); - auto best_rotamer_assignments_t = TPack::zeros({ntraj, nres}); - auto quench_order_t = TPack::zeros({nrotamers}); - // auto rotamer_attempts_t = TPack::zeros({nrotamers}); - - auto scores = scores_t.view; - auto rotamer_assignments = rotamer_assignments_t.view; - auto best_rotamer_assignments = rotamer_assignments_t.view; - auto quench_order = quench_order_t.view; - // auto rotamer_attempts = rotamer_attempts_t.view; - - float const high_temp = 100; - float const low_temp = 0.2; - - for (int traj = 0; traj < ntraj; ++traj) { +auto AnnealerDispatch::forward( + int max_n_rotamers_per_pose, + TView pose_n_res, // n-poses + TView n_rotamers_for_pose, // n-poses + TView rotamer_offset_for_pose, // n-poses + TView n_rotamers_for_res, // n-poses x max-n-res + TView oneb_offsets, // n-poses x max-n-res + TView res_for_rot, // n-rots + int32_t chunk_size, + TView + chunk_offset_offsets, // n-poses x max-n-res x max-n-res + TView chunk_offsets, // n-chunks-on-interacting-res + TView energy1b, + TView energy2b) + -> std::tuple, TPack > { + clock_t start = clock(); + + // No Frills Simulated Annealing! 
+ int const n_poses = pose_n_res.size(0); + int const max_n_res = n_rotamers_for_res.size(1); + int const n_rotamers = res_for_rot.size(0); + + int n_traj = 1; + int const n_outer_iterations = 20; + int const n_inner_iterations_factor = 20; + + // move to inside for(n_poses) loop: + // int const n_inner_iterations = n_inner_iterations_factor * n_rotamers; + + auto scores_t = TPack::zeros({n_poses, n_traj}); + auto current_rotamer_assignments_t = + TPack::zeros({n_poses, n_traj, max_n_res}); + auto best_rotamer_assignments_t = + TPack::zeros({n_poses, n_traj, max_n_res}); + auto quench_order_t = TPack::zeros({n_rotamers}); + // auto rotamer_attempts_t = TPack::zeros({n_rotamers}); + + auto scores = scores_t.view; + auto current_rotamer_assignments = current_rotamer_assignments_t.view; + auto best_rotamer_assignments = best_rotamer_assignments_t.view; + auto quench_order = quench_order_t.view; + // auto rotamer_attempts = rotamer_attempts_t.view; + + float const high_temp = 100; + float const low_temp = 0.2; + + // std::cout << "Startng simulated annealing, first rng:" << rand() << + // std::endl; + + for (int pose = 0; pose < n_poses; ++pose) { + int const n_res = pose_n_res[pose]; + int const pose_n_rotamers = n_rotamers_for_pose[pose]; + int const pose_rotamer_offset = rotamer_offset_for_pose[pose]; + int const n_inner_iterations = n_inner_iterations_factor * pose_n_rotamers; + // std::cout << "Starting pose " << pose << " with " << pose_n_rotamers << " + // rotamers and offset " << pose_rotamer_offset << std::endl; + + for (int traj = 0; traj < n_traj; ++traj) { // std::cout << "Starting trajectory " << traj+1 << std::endl; - for (int i = 0; i < nres; ++i) { - int const i_nrots = nrotamers_for_res[i]; - rotamer_assignments[traj][i] = rand() % i_nrots; - best_rotamer_assignments[traj][i] = rand() % i_nrots; + // Initial assignment: assign a rotamer to every residue + for (int i = 0; i < max_n_res; ++i) { + int const i_n_rots = n_rotamers_for_res[pose][i]; + int 
rand_rot = rand() % i_n_rots; + current_rotamer_assignments[pose][traj][i] = rand_rot; + best_rotamer_assignments[pose][traj][i] = rand_rot; } // std::cout << "Assigned random rotamers to all residues" << std::endl; float temperature = high_temp; double best_energy = total_energy_for_assignment( - nrotamers_for_res, - oneb_offsets, - res_for_rot, - respair_nenergies, - chunk_size_t, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, + n_rotamers_for_res[pose], + oneb_offsets[pose], + chunk_size, + chunk_offset_offsets[pose], + chunk_offsets, energy1b, energy2b, - rotamer_assignments, - traj); + current_rotamer_assignments[pose][traj]); double current_total_energy = best_energy; int naccepts = 0; for (int i = 0; i < n_outer_iterations; ++i) { @@ -101,41 +121,42 @@ struct AnnealerDispatch { if (i == n_outer_iterations - 1) { quench = true; temperature = 0; - for (int j = 0; j < nres; ++j) { - rotamer_assignments[traj][j] = best_rotamer_assignments[traj][j]; + for (int j = 0; j < n_res; ++j) { + current_rotamer_assignments[pose][traj][j] = + best_rotamer_assignments[pose][traj][j]; } current_total_energy = total_energy_for_assignment( - nrotamers_for_res, - oneb_offsets, - res_for_rot, - respair_nenergies, - chunk_size_t, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, + n_rotamers_for_res[pose], + oneb_offsets[pose], + chunk_size, + chunk_offset_offsets[pose], + chunk_offsets, energy1b, energy2b, - rotamer_assignments, - traj); + current_rotamer_assignments[pose][traj]); } for (int j = 0; j < n_inner_iterations; ++j) { int global_ran_rot; if (quench) { - if (j % nrotamers == 0) { - set_quench_order(quench_order); + if (j % pose_n_rotamers == 0) { + // std::cout << "setting quench order..." 
<< std::flush; + set_quench_order( + quench_order, pose_n_rotamers, pose_rotamer_offset); + // std::cout << "...done" << std::endl; } - global_ran_rot = quench_order[j % nrotamers]; + global_ran_rot = quench_order[j % pose_n_rotamers]; } else { - global_ran_rot = rand() % nrotamers; + global_ran_rot = rand() % pose_n_rotamers + pose_rotamer_offset; } // ++rotamer_attempts[ran_rot]; int const ran_res = res_for_rot[global_ran_rot]; - int const local_prev_rot = rotamer_assignments[traj][ran_res]; - int const ran_res_nrots = nrotamers_for_res[ran_res]; - int const ran_res_nchunks = (ran_res_nrots - 1) / chunk_size + 1; - int const ran_res_offset = oneb_offsets[ran_res]; + int const local_prev_rot = + current_rotamer_assignments[pose][traj][ran_res]; + int const ran_res_n_rots = n_rotamers_for_res[pose][ran_res]; + int const ran_res_n_chunks = (ran_res_n_rots - 1) / chunk_size + 1; + int const ran_res_offset = oneb_offsets[pose][ran_res]; int const local_ran_rot = global_ran_rot - ran_res_offset; int const ran_rot_chunk = local_ran_rot / chunk_size; int const prev_rot_chunk = local_prev_rot / chunk_size; @@ -144,9 +165,9 @@ struct AnnealerDispatch { int const prev_rot_in_chunk = local_prev_rot - chunk_size * prev_rot_chunk; int const ran_rot_chunk_size = - std::min(chunk_size, ran_res_nrots - chunk_size * ran_rot_chunk); - int const prev_rot_chunk_size = - std::min(chunk_size, ran_res_nrots - chunk_size * prev_rot_chunk); + std::min(chunk_size, ran_res_n_rots - chunk_size * ran_rot_chunk); + int const prev_rot_chunk_size = std::min( + chunk_size, ran_res_n_rots - chunk_size * prev_rot_chunk); int const global_prev_rot = local_prev_rot + ran_res_offset; double new_e = energy1b[global_ran_rot]; @@ -155,38 +176,38 @@ struct AnnealerDispatch { // Temp: iterate across all residues instead of just the // neighbors of ran_rot_res - for (int k = 0; k < nres; ++k) { + for (int k = 0; k < n_res; ++k) { if (k == ran_res) continue; - if (respair_nenergies[ran_res][k] == 0) 
continue; - int const local_k_rot = rotamer_assignments[traj][k]; - int const k_nrots = nrotamers_for_res[k]; - int const kres_nchunks = (k_nrots - 1) / chunk_size + 1; + int64_t const k_ran_chunk_offset_offset = + chunk_offset_offsets[pose][k][ran_res]; + if (k_ran_chunk_offset_offset == -1) { + // then neither prev_rot nor ran_rot interact with the rotamer at + // k + continue; + } + int const local_k_rot = current_rotamer_assignments[pose][traj][k]; + int const k_n_rots = n_rotamers_for_res[pose][k]; + int const kres_n_chunks = (k_n_rots - 1) / chunk_size + 1; int const krot_chunk = local_k_rot / chunk_size; int const krot_in_chunk = local_k_rot - krot_chunk * chunk_size; - int const k_ran_chunk_offset_offset = - chunk_offset_offsets[k][ran_res]; - int const krot_ranrot_chunk_offset = fine_chunk_offsets - [k_ran_chunk_offset_offset + krot_chunk * ran_res_nchunks + int64_t const krot_ranrot_chunk_offset = chunk_offsets + [k_ran_chunk_offset_offset + krot_chunk * ran_res_n_chunks + ran_rot_chunk]; - int const krot_prevrot_chunk_offset = fine_chunk_offsets - [k_ran_chunk_offset_offset + krot_chunk * ran_res_nchunks + int64_t const krot_prevrot_chunk_offset = chunk_offsets + [k_ran_chunk_offset_offset + krot_chunk * ran_res_n_chunks + prev_rot_chunk]; - int64_t const k_ran_offset = twob_offsets[k][ran_res]; - - // new_e += energy2b[ran_k_offset + kres_nrots * local_ran_rot + - // local_k_rot]; double k_new_e = 0; double k_prev_e = 0; if (krot_ranrot_chunk_offset >= 0) { k_new_e = energy2b - [k_ran_offset + krot_ranrot_chunk_offset - + krot_in_chunk * ran_rot_chunk_size + ran_rot_in_chunk]; + [krot_ranrot_chunk_offset + krot_in_chunk * ran_rot_chunk_size + + ran_rot_in_chunk]; } if (krot_prevrot_chunk_offset >= 0) { k_prev_e = energy2b - [k_ran_offset + krot_prevrot_chunk_offset + [krot_prevrot_chunk_offset + krot_in_chunk * prev_rot_chunk_size + prev_rot_in_chunk]; } deltaE += k_new_e - k_prev_e; @@ -198,30 +219,26 @@ struct AnnealerDispatch { if (pass_metropolis( 
temperature, uniform_random, deltaE, prev_e, quench)) { - rotamer_assignments[traj][ran_res] = local_ran_rot; + current_rotamer_assignments[pose][traj][ran_res] = local_ran_rot; current_total_energy += deltaE; ++naccepts; if (naccepts > 1000) { naccepts = 0; float new_current_total_energy = total_energy_for_assignment( - nrotamers_for_res, - oneb_offsets, - res_for_rot, - respair_nenergies, - chunk_size_t, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, + n_rotamers_for_res[pose], + oneb_offsets[pose], + chunk_size, + chunk_offset_offsets[pose], + chunk_offsets, energy1b, energy2b, - rotamer_assignments, - traj); + current_rotamer_assignments[pose][traj]); current_total_energy = new_current_total_energy; } if (current_total_energy < best_energy) { - for (int k = 0; k < nres; ++k) { - best_rotamer_assignments[traj][k] = - rotamer_assignments[traj][k]; + for (int k = 0; k < n_res; ++k) { + best_rotamer_assignments[pose][traj][k] = + current_rotamer_assignments[pose][traj][k]; } best_energy = current_total_energy; } @@ -230,7 +247,7 @@ struct AnnealerDispatch { } // end inner loop // std::cout << "temperature " << temperature << " energy " << - // total_energy_for_assignment(nrotamers_for_res, oneb_offsets, + // total_energy_for_assignment(n_rotamers_for_res, oneb_offsets, // res_for_rot, nenergies, twob_offsets, energy1b, energy2b, // rotamer_assignments, traj) << // std::endl; @@ -242,51 +259,59 @@ struct AnnealerDispatch { } // end outer loop - scores[0][traj] = total_energy_for_assignment( - nrotamers_for_res, - oneb_offsets, - res_for_rot, - respair_nenergies, - chunk_size_t, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, + // std::cout << "calc total energy" << std::flush; + scores[pose][traj] = total_energy_for_assignment( + n_rotamers_for_res[pose], + oneb_offsets[pose], + chunk_size, + chunk_offset_offsets[pose], + chunk_offsets, energy1b, energy2b, - rotamer_assignments, - traj); - // std::cout << "Traj " << traj << " with score 
" << scores[traj] << - // std::endl; + best_rotamer_assignments[pose][traj]); + // std::cout << "Traj " << traj << " for pose " << pose << " with score " + // << scores[pose][traj] << std::endl; } // end trajectory loop + } // end pose loop - // find the stdev of rotamer attempts - // float variance = 0; - // float mean = n_outer_iterations * n_inner_iterations_factor; - // for (int i = 0; i < nrotamers; ++i) { - // int iattempts = rotamer_attempts[i]; - // variance += (iattempts - mean)*(iattempts - mean); - // } - // variance /= nrotamers; - // float sdev = std::sqrt(variance); - // std::cout << "attempts variance" << variance << std::endl; - // for (int i = 0; i < nrotamers; ++i) { - // int iattempts = rotamer_attempts[i]; - // if (std::abs(iattempts - mean) > 2*sdev) { - // std::cout << "Rotamer " << i << " on res " << res_for_rot[i] << " - // attempted " << iattempts << " times." << std::endl; - // } - // } - clock_t stop = clock(); - std::cout << "CPU simulated annealing in " - << ((double)stop - start) / CLOCKS_PER_SEC << " seconds" - << std::endl; - - return {scores_t, rotamer_assignments_t}; - } -}; + // find the stdev of rotamer attempts + // float variance = 0; + // float mean = n_outer_iterations * n_inner_iterations_factor; + // for (int i = 0; i < n_rotamers; ++i) { + // int iattempts = rotamer_attempts[i]; + // variance += (iattempts - mean)*(iattempts - mean); + // } + // variance /= n_rotamers; + // float sdev = std::sqrt(variance); + // std::cout << "attempts variance" << variance << std::endl; + // for (int i = 0; i < n_rotamers; ++i) { + // int iattempts = rotamer_attempts[i]; + // if (std::abs(iattempts - mean) > 2*sdev) { + // std::cout << "Rotamer " << i << " on res " << res_for_rot[i] << " + // attempted " << iattempts << " times." 
<< std::endl; + // } + // } + clock_t stop = clock(); + std::cout << "CPU simulated annealing in " + << ((double)stop - start) / CLOCKS_PER_SEC << " seconds" + << std::endl; + + return {scores_t, best_rotamer_assignments_t}; +} template struct AnnealerDispatch; +template struct InteractionGraphBuilder< + score::common::DeviceOperations, + tmol::Device::CPU, + float, + int64_t>; +template struct InteractionGraphBuilder< + score::common::DeviceOperations, + tmol::Device::CPU, + double, + int64_t>; + } // namespace compiled } // namespace pack } // namespace tmol diff --git a/tmol/pack/compiled/compiled.cuda.cu b/tmol/pack/compiled/compiled.cuda.cu index be00f77da..2ce55c487 100644 --- a/tmol/pack/compiled/compiled.cuda.cu +++ b/tmol/pack/compiled/compiled.cuda.cu @@ -1,9 +1,12 @@ #include #include -/*#include -#include +#include +#include + +/*#include #include */ +#include #include #include @@ -12,7 +15,8 @@ #include #include -#include +// #include +#include #include #include @@ -22,6 +26,8 @@ #include +#include "compiled.impl.hh" + // Stolen from torch, v1.0.0 // Expose part of the torch library that otherwise is // not part of the API. @@ -43,16 +49,16 @@ namespace tmol { namespace pack { namespace compiled { -template +template MGPU_DEVICE __inline__ T reduce_shfl_and_broadcast( - cooperative_groups::thread_block_tile g, T val, op_t op) { + cooperative_groups::thread_block_tile g, T val, op_t op) { // T val_orig(val); - // mgpu::shfl_reduce_t reducer; + // mgpu::shfl_reduce_t reducer; // val = reducer. 
template reduce( - // g.thread_rank(), val, nthreads, op); + // g.thread_rank(), val, n_threads, op); // // T hand_rolled_val(val_orig); - for (unsigned int i = nthreads / 2; i > 0; i /= 2) { + for (unsigned int i = n_threads / 2; i > 0; i /= 2) { T const shfl_val = g.shfl_down(val, i); if (g.thread_rank() < 32 - i) { val = op(val, shfl_val); @@ -74,33 +80,40 @@ MGPU_DEVICE __inline__ T reduce_shfl_and_broadcast( template struct InteractionGraph { public: - TView nrotamers_for_res_; - TView oneb_offsets_; + int max_n_rotamers_per_pose_; + TView pose_n_res_; + TView pose_n_rotamers_; + TView pose_rotamer_offset_; + TView n_rotamers_for_res_; + TView oneb_offsets_; TView res_for_rot_; - TView respair_nenergies_; - TView chunk_size_; - TView chunk_offset_offsets_; - TView twob_offsets_; - TView fine_chunk_offsets_; + int32_t chunk_size_; + TView chunk_offset_offsets_; + TView chunk_offsets_; TView energy1b_; TView energy2b_; - int nres_cpu() const { return nrotamers_for_res_.size(0); } - int nrotamers_cpu() const { return res_for_rot_.size(0); } + int n_poses_cpu() const { return pose_n_res_.size(0); } + int max_n_res_cpu() const { return n_rotamers_for_res_.size(1); } + int n_rotamers_total_cpu() const { return res_for_rot_.size(0); } + int max_n_rotamers_per_pose_cpu() const { return max_n_rotamers_per_pose_; } + + MGPU_DEVICE + int n_poses() const { return pose_n_res_.size(0); } MGPU_DEVICE - int nres() const { return nrotamers_for_res_.size(0); } + int n_res(int pose) const { return pose_n_res_[pose]; } MGPU_DEVICE - int nrotamers() const { return res_for_rot_.size(0); } + int n_rotamers(int pose) const { return pose_n_rotamers_[pose]; } MGPU_DEVICE - TView const& nrotamers_for_res() const { - return nrotamers_for_res_; + TView const& n_rotamers_for_res() const { + return n_rotamers_for_res_; } MGPU_DEVICE - TView const& oneb_offsets() const { return oneb_offsets_; } + TView const& oneb_offsets() const { return oneb_offsets_; } MGPU_DEVICE TView const& 
res_for_rot() const { return res_for_rot_; } @@ -111,101 +124,113 @@ struct InteractionGraph { // Return the 1b + 2b energy for a substited rotamer at a residue MGPU_DEVICE Real rotamer_energy_against_background( + int pose, int sub_res, - int sub_res_nrots, + int sub_res_n_rots, int local_sub_rot, int global_sub_rot, - TensorAccessor rotamer_assignments, + TensorAccessor rotamer_assignment, bool this_thread_active) const { float new_e = 1e30; if (this_thread_active) { new_e = energy1b_[global_sub_rot]; } - int sub_rot_chunk = local_sub_rot / chunk_size_[0]; - int sub_rot_in_chunk = local_sub_rot - sub_rot_chunk * chunk_size_[0]; - int sub_res_nchunks = (sub_res_nrots - 1) / chunk_size_[0] + 1; + int sub_rot_chunk = local_sub_rot / chunk_size_; + int sub_rot_in_chunk = local_sub_rot - sub_rot_chunk * chunk_size_; + int sub_res_n_chunks = (sub_res_n_rots - 1) / chunk_size_ + 1; int sub_rot_chunk_size = - min(chunk_size_[0], sub_res_nrots - chunk_size_[0] * sub_rot_chunk); + min(chunk_size_, sub_res_n_rots - chunk_size_ * sub_rot_chunk); // Temp: iterate across all residues instead of just the // neighbors of ran_rot_res if (this_thread_active) { - for (int k = 0; k < nres(); ++k) { - if (k == sub_res || respair_nenergies_[sub_res][k] == 0) { + for (int k = 0; k < n_res(pose); ++k) { + if (k == sub_res) { + continue; + } + int const local_k_rot = rotamer_assignment[k]; + int const k_chunk = local_k_rot / chunk_size_; + int64_t const k_sub_chunk_offset_offset = + chunk_offset_offsets_[pose][k][sub_res]; + if (k_sub_chunk_offset_offset == -1) { continue; } - int const local_k_rot = rotamer_assignments[k]; - int const k_chunk = local_k_rot / chunk_size_[0]; - int const k_sub_chunk_offset_offset = chunk_offset_offsets_[k][sub_res]; - int const k_in_chunk = local_k_rot - k_chunk * chunk_size_[0]; - int const k_res_nrots = nrotamers_for_res_[k]; + int const k_in_chunk = local_k_rot - k_chunk * chunk_size_; + int const k_res_n_rots = n_rotamers_for_res_[pose][k]; int const 
k_chunk_size = - min(chunk_size_[0], k_res_nrots - chunk_size_[0] * k_chunk); - int const k_sub_chunk_start = fine_chunk_offsets_ - [k_sub_chunk_offset_offset + k_chunk * sub_res_nchunks + min(chunk_size_, k_res_n_rots - chunk_size_ * k_chunk); + int64_t const k_sub_chunk_start = chunk_offsets_ + [k_sub_chunk_offset_offset + k_chunk * sub_res_n_chunks + sub_rot_chunk]; - if (k_sub_chunk_start < 0) { + if (k_sub_chunk_start == -1) { continue; } - // printf("%d inds %d %d, %d, %d, %d * %d, %d\n", threadIdx.x, sub_res, - // k, - // twob_offsets_[k][sub_res], k_sub_chunk_start, sub_rot_chunk_size, - // k_in_chunk, sub_rot_in_chunk); new_e += energy2b_ - [twob_offsets_[k][sub_res] + k_sub_chunk_start - + sub_rot_chunk_size * k_in_chunk + sub_rot_in_chunk]; + [k_sub_chunk_start + sub_rot_chunk_size * k_in_chunk + + sub_rot_in_chunk]; + // printf("%d inds %d %d, %lld, %lld, %d * %d + %d = %d; e2b[%lld] = %f + // \n", + // threadIdx.x, sub_res, k, + // chunk_offset_offsets_[pose][k][sub_res], + // k_sub_chunk_start, sub_rot_chunk_size, k_in_chunk, + // sub_rot_in_chunk, sub_rot_chunk_size * k_in_chunk + + // sub_rot_in_chunk, k_sub_chunk_start + sub_rot_chunk_size * + // k_in_chunk + sub_rot_in_chunk, new_e + // ); } } return new_e; } - template + template MGPU_DEVICE Real total_energy_for_assignment_parallel( - cooperative_groups::thread_block_tile g, + int pose, + cooperative_groups::thread_block_tile g, TensorAccessor rotamer_assignment) const { Real totalE = 0; - int const nres = nrotamers_for_res_.size(0); - for (int i = g.thread_rank(); i < nres; i += nthreads) { + int const n_res = pose_n_res_[pose]; + for (int i = g.thread_rank(); i < n_res; i += n_threads) { int const irot_local = rotamer_assignment[i]; - int const irot_global = irot_local + oneb_offsets_[i]; + int const irot_global = irot_local + oneb_offsets_[pose][i]; totalE += energy1b_[irot_global]; } - for (int i = g.thread_rank(); i < nres; i += nthreads) { + // TO DO: iterate across upper-triangle indices 
only + for (int i = g.thread_rank(); i < n_res; i += n_threads) { int const irot_local = rotamer_assignment[i]; - int const irot_chunk = irot_local / chunk_size_[0]; - int const irot_in_chunk = irot_local - chunk_size_[0] * irot_chunk; - int const ires_nrots = nrotamers_for_res_[i]; - int const ires_nchunks = (ires_nrots - 1) / chunk_size_[0] + 1; + int const irot_chunk = irot_local / chunk_size_; + int const irot_in_chunk = irot_local - chunk_size_ * irot_chunk; + int const ires_n_rots = n_rotamers_for_res_[pose][i]; + int const ires_n_chunks = (ires_n_rots - 1) / chunk_size_ + 1; int const irot_chunk_size = - min(chunk_size_[0], ires_nrots - chunk_size_[0] * irot_chunk); + min(chunk_size_, ires_n_rots - chunk_size_ * irot_chunk); - for (int j = i + 1; j < nres; ++j) { + for (int j = i + 1; j < n_res; ++j) { int const jrot_local = rotamer_assignment[j]; - if (respair_nenergies_[i][j] == 0) { + int64_t const ij_chunk_offset_offset = + chunk_offset_offsets_[pose][i][j]; + if (ij_chunk_offset_offset == -1) { continue; } - int const jrot_chunk = jrot_local / chunk_size_[0]; - int const jrot_in_chunk = jrot_local - chunk_size_[0] * jrot_chunk; - int const ij_chunk_offset_offset = chunk_offset_offsets_[i][j]; + int const jrot_chunk = jrot_local / chunk_size_; + int const jrot_in_chunk = jrot_local - chunk_size_ * jrot_chunk; - int const jres_nrots = nrotamers_for_res_[j]; - int const jres_nchunks = (jres_nrots - 1) / chunk_size_[0] + 1; + int const jres_n_rots = n_rotamers_for_res_[pose][j]; + int const jres_n_chunks = (jres_n_rots - 1) / chunk_size_ + 1; int const jrot_chunk_size = - min(chunk_size_[0], jres_nrots - chunk_size_[0] * jrot_chunk); - int const ij_chunk_offset = fine_chunk_offsets_ - [ij_chunk_offset_offset + irot_chunk * jres_nchunks + jrot_chunk]; - if (ij_chunk_offset < 0) { + min(chunk_size_, jres_n_rots - chunk_size_ * jrot_chunk); + int64_t const ij_chunk_offset = chunk_offsets_ + [ij_chunk_offset_offset + irot_chunk * jres_n_chunks + jrot_chunk]; + 
if (ij_chunk_offset == -1) { continue; } float ij_energy = energy2b_ - [twob_offsets_[i][j] + ij_chunk_offset - + jrot_chunk_size * irot_in_chunk + jrot_in_chunk]; + [ij_chunk_offset + jrot_chunk_size * irot_in_chunk + jrot_in_chunk]; totalE += ij_energy; } } @@ -224,10 +249,10 @@ int curand_in_range(curandStatePhilox4_32_10_t* state, int n) { return int(curand_uniform(state) * n) % n; } -template +template MGPU_DEVICE __inline__ T exclusive_scan_shfl( - cooperative_groups::thread_block_tile g, T val, F f) { - for (unsigned int i = 1; i <= nthreads; i *= 2) { + cooperative_groups::thread_block_tile g, T val, F f) { + for (unsigned int i = 1; i <= n_threads; i *= 2) { T const shfl_val = g.shfl_up(val, i); if (i < g.thread_rank()) { val = f(shfl_val, val); @@ -240,10 +265,10 @@ MGPU_DEVICE __inline__ T exclusive_scan_shfl( return val; } -template +template MGPU_DEVICE __inline__ T inclusive_scan_shfl( - cooperative_groups::thread_block_tile g, T val, F f) { - for (unsigned int i = 1; i <= nthreads; i *= 2) { + cooperative_groups::thread_block_tile g, T val, F f) { + for (unsigned int i = 1; i <= n_threads; i *= 2) { T const shfl_val = g.shfl_up(val, i); if (g.thread_rank() >= i) { val = f(shfl_val, val); @@ -254,16 +279,19 @@ MGPU_DEVICE __inline__ T inclusive_scan_shfl( template MGPU_DEVICE void set_quench_order( - TensorAccessor quench_order, curandStatePhilox4_32_10_t* state) { + TensorAccessor quench_order, + int n_rotamers, + int rotamer_offset, + curandStatePhilox4_32_10_t* state) { // Create a random permutation of all the rotamers // and visit them in this order to ensure all of them // are seen during the quench step - int const nrots = quench_order.size(0); - for (int i = 0; i < nrots; ++i) { - quench_order[i] = i; + + for (int i = 0; i < n_rotamers; ++i) { + quench_order[i] = i + rotamer_offset; } - for (int i = 0; i <= nrots - 2; ++i) { - int rand_offset = curand_in_range(state, nrots - i); + for (int i = 0; i <= n_rotamers - 2; ++i) { + int rand_offset = 
curand_in_range(state, n_rotamers - i); int j = i + rand_offset; // swap i and j; int jval = quench_order[j]; @@ -274,24 +302,25 @@ MGPU_DEVICE void set_quench_order( template MGPU_DEVICE int set_quench_32_order( - TView nrotamers_for_res, - TView oneb_offsets, + int n_residues, + TensorAccessor n_rotamers_for_res, + TensorAccessor oneb_offsets, TensorAccessor quench_order, curandStatePhilox4_32_10_t* state) { // Create a random permutation of all the rotamers // and visit them in this order to ensure all of them // are seen during the quench step - int const nresidues = nrotamers_for_res.size(0); - int const nrots = quench_order.size(0); + // int const n_residues = n_rotamers_for_res.size(0); + // int const n_rots = quench_order.size(0); int count_n_quench_rots = 0; - for (int i = 0; i < nresidues; ++i) { - int const i_nrots = nrotamers_for_res[i]; + for (int i = 0; i < n_residues; ++i) { + int const i_n_rots = n_rotamers_for_res[i]; int const i_offset = oneb_offsets[i]; - int const i_nquench_rots = (i_nrots - 1) / 31 + 1; - for (int j = 0; j < i_nquench_rots; j++) { + int const i_n_quench_rots = (i_n_rots - 1) / 31 + 1; + for (int j = 0; j < i_n_quench_rots; j++) { quench_order[count_n_quench_rots + j] = i_offset + 31 * j; } - count_n_quench_rots += i_nquench_rots; + count_n_quench_rots += i_n_quench_rots; } for (int i = 0; i <= count_n_quench_rots - 2; ++i) { int rand_offset = curand_in_range(state, count_n_quench_rots - i); @@ -304,14 +333,15 @@ MGPU_DEVICE int set_quench_32_order( return count_n_quench_rots; } -template +template MGPU_DEVICE float warp_wide_sim_annealing( + int pose, + int traj_id, // debugging purposes only curandStatePhilox4_32_10_t* state, - cooperative_groups::thread_block_tile g, + cooperative_groups::thread_block_tile g, InteractionGraph ig, - int warp_id, - TensorAccessor rotamer_assignments, - TensorAccessor best_rotamer_assignments, + TensorAccessor current_rotamer_assignment, + TensorAccessor best_rotamer_assignment, TensorAccessor 
quench_order, float hi_temp, float lo_temp, @@ -320,21 +350,22 @@ MGPU_DEVICE float warp_wide_sim_annealing( int n_quench_iterations, bool quench_on_last_iteration, bool quench_lite) { - int const nres = ig.nres(); - int const nrotamers = ig.nrotamers(); + int const n_res = ig.n_res(pose); + int const n_rotamers = ig.n_rotamers(pose); + int const pose_rotamer_offset = ig.pose_rotamer_offset_[pose]; float temperature = hi_temp; - float best_energy = - ig.total_energy_for_assignment_parallel(g, rotamer_assignments); + float best_energy = ig.total_energy_for_assignment_parallel( + pose, g, current_rotamer_assignment); float current_total_energy = best_energy; - int ntrials = 0; + int n_trials = 0; for (int i = 0; i < n_outer_iterations; ++i) { - // if (g.thread_rank() == 0) { - // printf("top of outer loop %d currentE %f bestE %f temp %f\n", i, - // current_total_energy, best_energy, temperature); + // if (g.thread_rank() == 0 && traj_id == 0) { + // printf("p %d t %d top of outer loop %d currentE %f bestE %f temp %f\n", + // pose, traj_id, i, current_total_energy, best_energy, temperature); // } bool quench = false; - int quench_period = nrotamers; + int quench_period = n_rotamers; int i_n_inner_iterations = n_inner_iterations; if (i == n_outer_iterations - 1 && quench_on_last_iteration) { @@ -343,11 +374,11 @@ MGPU_DEVICE float warp_wide_sim_annealing( temperature = 1e-20; // recover the lowest energy rotamer assignment encountered // and begin quench from there - for (int j = g.thread_rank(); j < nres; j += 32) { - rotamer_assignments[j] = best_rotamer_assignments[j]; + for (int j = g.thread_rank(); j < n_res; j += 32) { + current_rotamer_assignment[j] = best_rotamer_assignment[j]; } - current_total_energy = - ig.total_energy_for_assignment_parallel(g, rotamer_assignments); + current_total_energy = ig.total_energy_for_assignment_parallel( + pose, g, current_rotamer_assignment); } for (int j = 0; j < i_n_inner_iterations; ++j) { @@ -358,16 +389,18 @@ MGPU_DEVICE 
float warp_wide_sim_annealing( if (j % quench_period == 0) { if (quench_lite) { quench_period = set_quench_32_order( - ig.nrotamers_for_res(), - ig.oneb_offsets(), + n_res, + ig.n_rotamers_for_res_[pose], + ig.oneb_offsets_[pose], quench_order, state); i_n_inner_iterations = quench_period; } else { - set_quench_order(quench_order, state); + set_quench_order( + quench_order, n_rotamers, pose_rotamer_offset, state); } } - ran_rot = quench_order[j % nrotamers]; + ran_rot = quench_order[j % n_rotamers]; } ran_rot = g.shfl(ran_rot, 0); if (j % quench_period == 0 && quench_lite) { @@ -376,25 +409,30 @@ MGPU_DEVICE float warp_wide_sim_annealing( accept_prob = .5; } else { if (g.thread_rank() == 0) { + // TO DO: Make more efficient by having each thread call curand and + // then broadcast to other threads their rngs % 32; also, use all 4 + // rands and not just the first two. float4 four_rands = curand_uniform4(state); - ran_rot = int(four_rands.x * nrotamers) % nrotamers; + ran_rot = + int(four_rands.x * n_rotamers) % n_rotamers + pose_rotamer_offset; accept_prob = four_rands.y; } ran_rot = g.shfl(ran_rot, 0); accept_prob = g.shfl(accept_prob, 0); } int const ran_res = ig.res_for_rot()[ran_rot]; - int const local_prev_rot = rotamer_assignments[ran_res]; - int const ran_res_nrots = ig.nrotamers_for_res()[ran_res]; - int const ran_res_rotamer_offset = ig.oneb_offsets()[ran_res]; + int const local_prev_rot = current_rotamer_assignment[ran_res]; + int const ran_res_n_rots = ig.n_rotamers_for_res()[pose][ran_res]; + int const ran_res_rotamer_offset = ig.oneb_offsets()[pose][ran_res]; bool prev_rot_in_range = false; int thread_w_prev_rot = 0; { // scope int const local_ran_rot_orig = ran_rot - ran_res_rotamer_offset; - int const local_prev_rot_wrapped = local_ran_rot_orig < local_prev_rot - ? local_prev_rot - : local_prev_rot + ran_res_nrots; + int const local_prev_rot_wrapped = + local_ran_rot_orig < local_prev_rot + ? 
local_prev_rot + : local_prev_rot + ran_res_n_rots; prev_rot_in_range = local_ran_rot_orig + 32 > local_prev_rot_wrapped; thread_w_prev_rot = prev_rot_in_range ? local_prev_rot_wrapped - local_ran_rot_orig : 0; @@ -402,25 +440,26 @@ MGPU_DEVICE float warp_wide_sim_annealing( int const local_ran_rot = prev_rot_in_range ? ((ran_rot - ran_res_rotamer_offset + g.thread_rank()) - % ran_res_nrots) + % ran_res_n_rots) : (g.thread_rank() == 0 ? local_prev_rot : (ran_rot - ran_res_rotamer_offset + g.thread_rank() - 1) - % ran_res_nrots); + % ran_res_n_rots); ran_rot = local_ran_rot + ran_res_rotamer_offset; // If there are fewer rotamers on this residue than there are threads // active in the warp, do not wrap and consider a rotamer more than once - bool const this_thread_active = ran_res_nrots > g.thread_rank(); + bool const this_thread_active = ran_res_n_rots > g.thread_rank(); bool const this_thread_last_active = - ran_res_nrots == g.thread_rank() || g.thread_rank() == 32 - 1; + ran_res_n_rots == g.thread_rank() || g.thread_rank() == 32 - 1; float new_e = ig.rotamer_energy_against_background( + pose, ran_res, - ran_res_nrots, + ran_res_n_rots, local_ran_rot, ran_rot, - rotamer_assignments, + current_rotamer_assignment, this_thread_active); // if (g.thread_rank() == 0) { @@ -428,7 +467,10 @@ MGPU_DEVICE float warp_wide_sim_annealing( // } float const min_e = reduce_shfl_and_broadcast(g, new_e, mgpu::minimum_t()); - // printf("thread %d min_e %f\n", thread_id, min_e); + // if (traj_id == 0) { + // printf("thread %d new_e %f min_e %f\n", g.thread_rank(), new_e, + // min_e); + // } float myexp = expf(-1 * (new_e - min_e) / temperature); // printf("thread %d myexp %f\n", thread_id, myexp); // if (g.thread_rank() == 0) { @@ -481,8 +523,10 @@ MGPU_DEVICE float warp_wide_sim_annealing( bool new_best = false; if (accept) { float deltaE = new_e - prev_e; - // printf("deltaE: %f (%f - %f)\n", deltaE, new_e, prev_e); - rotamer_assignments[ran_res] = local_ran_rot; + // if (traj_id 
== 0 ) { + // printf("deltaE: %f (%f - %f)\n", deltaE, new_e, prev_e); + // } + current_rotamer_assignment[ran_res] = local_ran_rot; current_total_energy = current_total_energy + deltaE; // for (int k=0; k < nres; ++k) { // float k_energy = alt_energies[k][thread_id]; @@ -497,18 +541,18 @@ MGPU_DEVICE float warp_wide_sim_annealing( current_total_energy = g.shfl(current_total_energy, accept_thread); new_best = g.shfl(new_best, accept_thread); if (new_best) { - for (int k = g.thread_rank(); k < nres; k += 32) { - best_rotamer_assignments[k] = rotamer_assignments[k]; + for (int k = g.thread_rank(); k < n_res; k += 32) { + best_rotamer_assignment[k] = current_rotamer_assignment[k]; } best_energy = current_total_energy; // g.shfl(best_energy, accept_thread); } - ++ntrials; - if (ntrials > 1000) { - ntrials = 0; - current_total_energy = - ig.total_energy_for_assignment_parallel(g, rotamer_assignments); + ++n_trials; + if (n_trials > 1000) { + n_trials = 0; + current_total_energy = ig.total_energy_for_assignment_parallel( + pose, g, current_rotamer_assignment); // if (g.thread_rank() == 0) { // printf("refresh total energy currentE %f\n", current_total_energy); // } @@ -518,116 +562,118 @@ MGPU_DEVICE float warp_wide_sim_annealing( // geometric cooling toward 0.3 // std::cout << "temperature " << temperature << " energy " << - // total_energy_for_assignment(nrotamers_for_res, oneb_offsets, + // total_energy_for_assignment(n_rotamers_for_res, oneb_offsets, // res_for_rot, nenergies, twob_offsets, energy1b, energy2b, // my_rotamer_assignment) << std::endl; temperature = 0.35 * (temperature - lo_temp) + lo_temp; } // end outer loop - float totalE = - ig.total_energy_for_assignment_parallel(g, rotamer_assignments); + float totalE = ig.total_energy_for_assignment_parallel( + pose, g, current_rotamer_assignment); return totalE; } -template -MGPU_DEVICE float spbr( - curandStatePhilox4_32_10_t* state, - cooperative_groups::thread_block_tile g, - InteractionGraph ig, - int 
warp_id, - int n_spbr, - TView spbr_rotamer_assignments, - TView spbr_perturbed_assignments) { - int const nres = ig.nrotamers_for_res.size(0); - int const nrotamers = ig.res_for_rot.size(0); - - float energy = ig.total_energy_for_assignment_parallel( - g, spbr_rotamer_assignments[warp_id]); - - for (int spbr_iteration = 0; spbr_iteration < n_spbr; ++spbr_iteration) { - // 1. pick a rotamer - int ran_rot; - if (g.thread_rank() == 0) { - float rand_num = curand_uniform(state); - ran_rot = int(rand_num * nrotamers) % nrotamers; - } - ran_rot = g.shfl(ran_rot, 0); - int const ran_res = ig.res_for_rot[ran_rot]; - int const ran_res_nrots = ig.nrotamers_for_res[ran_res]; - int const ran_rot_local = ran_rot - ig.oneb_offsets[ran_res]; - - // initialize the perturbed assignments array for this iteration. - // many of these will be overwritten, but for memory access efficiency - // copy everything over now. - for (int i = g.thread_rank(); i < nres; i += 32) { - int irot = - i == ran_res ? ran_rot_local : spbr_rotamer_assignments[warp_id][i]; - spbr_perturbed_assignments[warp_id][i] = irot; - } - - // 2. relax the neighbors of this residue - for (int i = 0; i < nres; ++i) { - // 4a. Find the lowest energy rotamer for residue i - if (ran_res == i || ig.nenergies[ran_res][i] == 0) { - continue; - } - - int my_best_rot = 0; - float my_best_rot_E = 9999; - int i_nrots = ig.nrotamers_for_res[i]; - for (int j = g.thread_rank(); j < i_nrots; j += 32) { - int const j_global = j + ig.oneb_offsets[i]; - float jE = ig.energy1b[j_global]; - for (int k = 0; k < nres; ++k) { - if (k == i || ig.nenergies[k][i] == 0) continue; - - int const k_rotamer = k == ran_res - ? 
ran_rot_local - : spbr_rotamer_assignments[warp_id][k]; - jE += ig.energy2b[ig.twob_offsets[k][i] + i_nrots * k_rotamer + j]; - } - - if (j == g.thread_rank() || jE < my_best_rot_E) { - my_best_rot = j; - my_best_rot_E = jE; - } - } - // now all threads compare: who has the lowest energy - // if (g.thread_rank() == 0) { - // printf("minimum\n"); - // } - float best_rot_E = - reduce_shfl_and_broadcast(g, my_best_rot_E, mgpu::minimum_t()); - int mine_is_best = best_rot_E == my_best_rot_E; - int scan_val = inclusive_scan_shfl(g, mine_is_best, mgpu::plus_t()); - if (mine_is_best && scan_val == 1) { - // exactly one thread saves the assigned rotamer to the - // spbr_perturbed_assignemnt array - spbr_perturbed_assignments[warp_id][i] = my_best_rot; - } - } - - // 5. compute the new total energy after relaxation - float alt_energy = ig.total_energy_for_assignment_parallel( - g, spbr_perturbed_assignments[warp_id]); - - // 6. if the energy decreases, accept the perturbed conformation - if (alt_energy < energy) { - // if (g.thread_rank() == 0) { - // printf("%d prevE %f newE %f\n", warp_id, energy, alt_energy); - // } - - energy = alt_energy; - for (int i = g.thread_rank(); i < nres; i += 32) { - spbr_rotamer_assignments[warp_id][i] = - spbr_perturbed_assignments[warp_id][i]; - } - } - } - return energy; -} +// template +// MGPU_DEVICE float spbr( +// curandStatePhilox4_32_10_t* state, +// cooperative_groups::thread_block_tile g, +// InteractionGraph ig, +// int warp_id, +// int n_spbr, +// TView spbr_rotamer_assignments, +// TView spbr_perturbed_assignments) { +// int const nres = ig.n_rotamers_for_res.size(0); +// int const nrotamers = ig.res_for_rot.size(0); +// +// float energy = ig.total_energy_for_assignment_parallel( +// g, spbr_rotamer_assignments[warp_id]); +// +// for (int spbr_iteration = 0; spbr_iteration < n_spbr; ++spbr_iteration) { +// // 1. 
pick a rotamer +// int ran_rot; +// if (g.thread_rank() == 0) { +// float rand_num = curand_uniform(state); +// ran_rot = int(rand_num * nrotamers) % nrotamers; +// } +// ran_rot = g.shfl(ran_rot, 0); +// int const ran_res = ig.res_for_rot[ran_rot]; +// int const ran_res_nrots = ig.nrotamers_for_res[ran_res]; +// int const ran_rot_local = ran_rot - ig.oneb_offsets[ran_res]; +// +// // initialize the perturbed assignments array for this iteration. +// // many of these will be overwritten, but for memory access efficiency +// // copy everything over now. +// for (int i = g.thread_rank(); i < nres; i += 32) { +// int irot = +// i == ran_res ? ran_rot_local : +// spbr_rotamer_assignments[warp_id][i]; +// spbr_perturbed_assignments[warp_id][i] = irot; +// } +// +// // 2. relax the neighbors of this residue +// for (int i = 0; i < nres; ++i) { +// // 4a. Find the lowest energy rotamer for residue i +// if (ran_res == i || ig.nenergies[ran_res][i] == 0) { +// continue; +// } +// +// int my_best_rot = 0; +// float my_best_rot_E = 9999; +// int i_nrots = ig.nrotamers_for_res[i]; +// for (int j = g.thread_rank(); j < i_nrots; j += 32) { +// int const j_global = j + ig.oneb_offsets[i]; +// float jE = ig.energy1b[j_global]; +// for (int k = 0; k < nres; ++k) { +// if (k == i || ig.nenergies[k][i] == 0) continue; +// +// int const k_rotamer = k == ran_res +// ? 
ran_rot_local +// : spbr_rotamer_assignments[warp_id][k]; +// jE += ig.energy2b[ig.twob_offsets[k][i] + i_nrots * k_rotamer + j]; +// } +// +// if (j == g.thread_rank() || jE < my_best_rot_E) { +// my_best_rot = j; +// my_best_rot_E = jE; +// } +// } +// // now all threads compare: who has the lowest energy +// // if (g.thread_rank() == 0) { +// // printf("minimum\n"); +// // } +// float best_rot_E = +// reduce_shfl_and_broadcast(g, my_best_rot_E, +// mgpu::minimum_t()); +// int mine_is_best = best_rot_E == my_best_rot_E; +// int scan_val = inclusive_scan_shfl(g, mine_is_best, +// mgpu::plus_t()); if (mine_is_best && scan_val == 1) { +// // exactly one thread saves the assigned rotamer to the +// // spbr_perturbed_assignemnt array +// spbr_perturbed_assignments[warp_id][i] = my_best_rot; +// } +// } +// +// // 5. compute the new total energy after relaxation +// float alt_energy = ig.total_energy_for_assignment_parallel( +// g, spbr_perturbed_assignments[warp_id]); +// +// // 6. if the energy decreases, accept the perturbed conformation +// if (alt_energy < energy) { +// // if (g.thread_rank() == 0) { +// // printf("%d prevE %f newE %f\n", warp_id, energy, alt_energy); +// // } +// +// energy = alt_energy; +// for (int i = g.thread_rank(); i < nres; i += 32) { +// spbr_rotamer_assignments[warp_id][i] = +// spbr_perturbed_assignments[warp_id][i]; +// } +// } +// } +// return energy; +// } // IG must respond to // - nres() @@ -640,25 +686,31 @@ MGPU_DEVICE float spbr( template struct Annealer { - static auto run_simulated_annealing(IG ig, int64_t seed) - -> std::tuple, TPack > { - int const nres = ig.nres_cpu(); // nrotamers_for_res.size(0); - int const nrotamers = ig.nrotamers_cpu(); // res_for_rot.size(0); + static auto run_simulated_annealing(IG ig, at::CUDAGeneratorImpl* gen) + -> std::tuple, TPack > { + int const n_poses = ig.n_poses_cpu(); + int const max_n_res = ig.max_n_res_cpu(); // nrotamers_for_res.size(0); + int const n_rotamers_total = + 
ig.n_rotamers_total_cpu(); // res_for_rot.size(0); + int const max_n_rotamers = ig.max_n_rotamers_per_pose_cpu(); + // printf( + // "n poses: %d, max_n_res %d, n_rotamers_total %d, max_n_rotamers + // %d\n", n_poses, max_n_res, n_rotamers_total, max_n_rotamers); int const n_hitemp_simA_traj = 2000; - int const n_hitemp_simA_threads = 32 * n_hitemp_simA_traj; + int const n_hitemp_simA_threads = 32 * n_poses * n_hitemp_simA_traj; float const round1_cut = 0.25; int const n_lotemp_expansions = 10; int const n_lotemp_simA_traj = int(n_hitemp_simA_traj * n_lotemp_expansions * round1_cut); - int const n_lotemp_simA_threads = 32 * n_lotemp_simA_traj; + int const n_lotemp_simA_threads = 32 * n_poses * n_lotemp_simA_traj; float const round2_cut = 0.25; int const n_fullquench_traj = int(n_lotemp_simA_traj * round2_cut); - int const n_fullquench_threads = 32 * n_fullquench_traj; + int const n_fullquench_threads = 32 * n_poses * n_fullquench_traj; int const n_outer_iterations_hitemp = 10; - int const n_inner_iterations_hitemp = nrotamers / 8; + int const n_inner_iterations_hitemp = max_n_rotamers / 8; int const n_outer_iterations_lotemp = 10; - int const n_inner_iterations_lotemp = nrotamers / 16; + int const n_inner_iterations_lotemp = max_n_rotamers / 16; float const high_temp_initial = 30; float const low_temp_initial = 0.3; float const high_temp_later = 0.2; @@ -667,43 +719,64 @@ struct Annealer { int const max_traj = std::max( std::max(n_hitemp_simA_traj, n_lotemp_simA_traj), n_fullquench_traj); - auto scores_hitemp_t = TPack::zeros(n_hitemp_simA_traj); - auto rotamer_assignments_hitemp_t = - TPack::zeros({n_hitemp_simA_traj, nres}); + auto scores_hitemp_t = + TPack::zeros({n_poses, n_hitemp_simA_traj}); + auto current_rotamer_assignments_hitemp_t = + TPack::zeros({n_poses, n_hitemp_simA_traj, max_n_res}); auto best_rotamer_assignments_hitemp_t = - TPack::zeros({n_hitemp_simA_traj, nres}); - auto rotamer_assignments_hitemp_quenchlite_t = - TPack::zeros({n_hitemp_simA_traj, 
nres}); - auto sorted_hitemp_traj_t = TPack::zeros(n_hitemp_simA_traj); - - auto scores_lotemp_t = TPack::zeros(n_lotemp_simA_traj); - auto rotamer_assignments_lotemp_t = - TPack::zeros({n_lotemp_simA_traj, nres}); + TPack::zeros({n_poses, n_hitemp_simA_traj, max_n_res}); + auto current_rotamer_assignments_hitemp_quenchlite_t = + TPack::zeros({n_poses, n_hitemp_simA_traj, max_n_res}); + auto sorted_hitemp_traj_t = + TPack::zeros({n_poses, n_hitemp_simA_traj}); + auto segment_heads_hitemp_t = TPack::zeros( + {n_poses}); // arange(n_poses) * n_hitemp_simA_traj + auto segment_heads_lotemp_t = TPack::zeros( + {n_poses}); // arange(n_poses) * n_lotemp_simA_traj + auto segment_heads_fullquench_t = TPack::zeros( + {n_poses}); // arange(n_poses) * n_fulquench_traj + + auto scores_lotemp_t = + TPack::zeros({n_poses, n_lotemp_simA_traj}); + auto current_rotamer_assignments_lotemp_t = + TPack::zeros({n_poses, n_lotemp_simA_traj, max_n_res}); auto best_rotamer_assignments_lotemp_t = - TPack::zeros({n_lotemp_simA_traj, nres}); - // auto rotamer_assignments_lotemp_quenchlite_t = TPack::zeros({n_lotemp_simA_traj, nres}); - auto sorted_lotemp_traj_t = TPack::zeros(n_lotemp_simA_traj); + TPack::zeros({n_poses, n_lotemp_simA_traj, max_n_res}); + auto sorted_lotemp_traj_t = + TPack::zeros({n_poses, n_lotemp_simA_traj}); auto scores_fullquench_t = - TPack::zeros({1, n_fullquench_traj}); - auto rotamer_assignments_fullquench_t = - TPack::zeros({n_fullquench_traj, nres}); + TPack::zeros({n_poses, n_fullquench_traj}); + auto current_rotamer_assignments_fullquench_t = + TPack::zeros({n_poses, n_fullquench_traj, max_n_res}); auto best_rotamer_assignments_fullquench_t = - TPack::zeros({n_fullquench_traj, nres}); + TPack::zeros({n_poses, n_fullquench_traj, max_n_res}); + auto sorted_fullquench_traj_t = + TPack::zeros({n_poses, n_hitemp_simA_traj}); + + auto scores_final_t = + TPack::zeros({n_poses, n_fullquench_traj}); + auto rotamer_assignments_final_t = + TPack::zeros({n_poses, 
n_fullquench_traj, max_n_res}); - auto quench_order_t = TPack::zeros({max_traj, nrotamers}); + auto quench_order_t = + TPack::zeros({n_poses, max_traj, max_n_rotamers}); auto scores_hitemp = scores_hitemp_t.view; - auto rotamer_assignments_hitemp = rotamer_assignments_hitemp_t.view; + auto current_rotamer_assignments_hitemp = + current_rotamer_assignments_hitemp_t.view; auto best_rotamer_assignments_hitemp = best_rotamer_assignments_hitemp_t.view; - auto rotamer_assignments_hitemp_quenchlite = - rotamer_assignments_hitemp_quenchlite_t.view; + auto current_rotamer_assignments_hitemp_quenchlite = + current_rotamer_assignments_hitemp_quenchlite_t.view; auto sorted_hitemp_traj = sorted_hitemp_traj_t.view; + auto segment_heads_hitemp = segment_heads_hitemp_t.view; + auto segment_heads_lotemp = segment_heads_lotemp_t.view; + auto segment_heads_fullquench = segment_heads_fullquench_t.view; auto scores_lotemp = scores_lotemp_t.view; - auto rotamer_assignments_lotemp = rotamer_assignments_lotemp_t.view; + auto current_rotamer_assignments_lotemp = + current_rotamer_assignments_lotemp_t.view; auto best_rotamer_assignments_lotemp = best_rotamer_assignments_lotemp_t.view; // auto rotamer_assignments_lotemp_quenchlite = @@ -711,203 +784,266 @@ struct Annealer { auto sorted_lotemp_traj = sorted_lotemp_traj_t.view; auto scores_fullquench = scores_fullquench_t.view; - auto rotamer_assignments_fullquench = rotamer_assignments_fullquench_t.view; + auto current_rotamer_assignments_fullquench = + current_rotamer_assignments_fullquench_t.view; auto best_rotamer_assignments_fullquench = best_rotamer_assignments_fullquench_t.view; + auto sorted_fullquench_traj = sorted_fullquench_traj_t.view; // auto sorted_fullquench_traj = sorted_lotem_traj_t.view; - auto quench_order = quench_order_t.view; + auto scores_final = scores_final_t.view; + auto rotamer_assignments_final = rotamer_assignments_final_t.view; - // This code will work for future versions of the torch/aten libraries, but - // not 
this one. - // // Increment the cuda generator - // // I know I need to increment this, but I am unsure by how much! - // std::pair rng_engine_inputs; - // at::CUDAGenerator * gen = at::cuda::detail::getDefaultCUDAGenerator(); - // { - // std::lock_guard lock(gen->mutex_); - // rng_engine_inputs = gen->philox_engine_inputs(nrotamers * 400 + nres); - // } + auto quench_order = quench_order_t.view; // Increment the seed (and capture the current seed) for the - // cuda generator. The number of calls to curand must be known - // by this statement. - // 1: nrotmaers*400 = 20 outer loop * nrotamers * 20 inner loop - // calls to either curand_uniform or curand_uniform4 in either - // the quench / non-quench cycles + - // 2: nres = the initial seed state of the system is created by - // picking a single random rotamer per residue. - /*auto philox_seed_hitemp = next_philox_seed( - nrotamers + // initial random rotamer assignment - n_outer_iterations_hitemp * n_inner_iterations_hitemp - + // hitemp annealing - (nrotamers / 31 - + nres) // hitemp random permutation of quenchlite rotamers - );*/ - - int hitemp_cnt = - nrotamers + // initial random rotamer assignment - n_outer_iterations_hitemp * n_inner_iterations_hitemp - + // hitemp annealing - (nrotamers / 31 - + nres); // hitemp random permutation of quenchlite rotamers - - /*auto philox_seed_lotemp = next_philox_seed( - n_outer_iterations_lotemp * n_outer_iterations_lotemp - + // lotemp annealing - (nrotamers / 31 - + nres) // lowtemp random permuation of quenchlite rotamers - );*/ - - int lotemp_cnt = - n_outer_iterations_lotemp * n_outer_iterations_lotemp + // cuda generator. The number of calls to curand per thread + // must be known. Most curand calls are handled by thread 0, + // so that's the one we'll count. + // + // We will overestimate the number of curand calls because + // each pose might have a different number of rotamers and will + // thus have a different number of curand calls. 
+ // + // 1: initial random rotamer assignment: + // (nres-1)/32 + 1 curand calls per thread + // + // Warp wide simulated annealing: + // + // 2: the outeriterations * inner-iterations + // -- random rotamer picking + // -- MC accept-reject calls: + // n_poses * n_traj * n_outer * n_inner * 4, all performed by + // thread 0 + // + // 3: quench ordering during last stage: + // 4 calls to curand per n-quench-iterations + possibly one + // extra call to curand per max_n_rotamers if full-quench + // or + possibly one extra call to curand per max_n_rotamers / 31 + // if quench-lite, all performed by thread 0.. + + int const hitemp_cnt = + (max_n_res - 1) / 32 + 1 + // initial random rotamer assignment + n_outer_iterations_hitemp * n_inner_iterations_hitemp * 4 + + // hitemp annealing; curand4 + (max_n_rotamers * 4 + + max_n_rotamers + / 31); // hitemp random permutation of quenchlite rotamers + + int const lotemp_cnt = + n_outer_iterations_lotemp * n_outer_iterations_lotemp * 4 + // lotemp annealing - (nrotamers / 31 - + nres); // lowtemp random permuation of quenchlite rotamers - - // auto philox_seed_quench = next_philox_seed(nrotamers); - - int quench_cnt = nrotamers; + (max_n_rotamers * 4 + + max_n_rotamers + / 31); // lotemp random permuation of quenchlite rotamers + + int const fullquench_cnt = + max_n_rotamers * 5; // random permutation + 4 curands per iteration + + // Increment the cuda generator + at::PhiloxCudaState hitemp_philox_state; + at::PhiloxCudaState lotemp_philox_state; + at::PhiloxCudaState quench_philox_state; + { + std::lock_guard lock(gen->mutex_); + hitemp_philox_state = gen->philox_cuda_state(hitemp_cnt); + lotemp_philox_state = gen->philox_cuda_state(lotemp_cnt); + quench_philox_state = gen->philox_cuda_state(fullquench_cnt); + } auto hitemp_simulated_annealing = [=] MGPU_DEVICE(int thread_id) { + auto seeds = at::cuda::philox::unpack(hitemp_philox_state); curandStatePhilox4_32_10_t state; - curand_init(seed, thread_id, 0, &state); + 
curand_init(std::get<0>(seeds), thread_id, std::get<1>(seeds), &state); cooperative_groups::thread_block_tile<32> g = cooperative_groups::tiled_partition<32>( cooperative_groups::this_thread_block()); - int const warp_id = thread_id / 32; + int const cta_id = thread_id / 32; + int const pose = cta_id / n_hitemp_simA_traj; + int const traj_id = cta_id % n_hitemp_simA_traj; + // printf("hitemp thread %d cta %d pose %d traj_id %d\n", thread_id, + // cta_id, + // pose, traj_id); + + int const n_res = ig.n_res(pose); + int const n_rotamers = ig.n_rotamers(pose); if (g.thread_rank() == 0) { - sorted_hitemp_traj[warp_id] = warp_id; + sorted_hitemp_traj[pose][traj_id] = traj_id; + } + if (g.thread_rank() == 0 && traj_id == 0) { + // later we will run segmented sort for the trajectories + // for each Pose, so we need tensors of "segment heads" + // to state the indices at which the trajectory lists + // begin. + segment_heads_hitemp[pose] = pose * n_hitemp_simA_traj; + segment_heads_lotemp[pose] = pose * n_lotemp_simA_traj; + segment_heads_fullquench[pose] = pose * n_fullquench_traj; } - for (int i = g.thread_rank(); i < nres; i += 32) { - int const i_nrots = ig.nrotamers_for_res()[i]; - int chosen = int(curand_uniform(&state) * i_nrots) % i_nrots; - rotamer_assignments_hitemp[warp_id][i] = chosen; - best_rotamer_assignments_hitemp[warp_id][i] = chosen; + for (int i = g.thread_rank(); i < n_res; i += 32) { + int const i_n_rots = ig.n_rotamers_for_res()[pose][i]; + int chosen = int(curand_uniform(&state) * i_n_rots) % i_n_rots; + current_rotamer_assignments_hitemp[pose][traj_id][i] = chosen; + best_rotamer_assignments_hitemp[pose][traj_id][i] = chosen; } float rotstate_energy_after_high_temp = warp_wide_sim_annealing( + pose, + traj_id, &state, g, ig, - warp_id, - rotamer_assignments_hitemp[warp_id], - best_rotamer_assignments_hitemp[warp_id], - quench_order[warp_id], + current_rotamer_assignments_hitemp[pose][traj_id], + best_rotamer_assignments_hitemp[pose][traj_id], + 
quench_order[pose][traj_id], high_temp_initial, low_temp_initial, n_outer_iterations_hitemp, n_inner_iterations_hitemp, - nrotamers, + n_rotamers, // irrelevant; no quench here false, false); + // if (g.thread_rank() == 0) { + // printf( + // "hitemp wwsa done: pose %d traj %d E = %f\n", + // pose, + // traj_id, + // rotstate_energy_after_high_temp); + // } // Save the state before moving into quench - for (int i = g.thread_rank(); i < nres; i += 32) { - int i_assignment = best_rotamer_assignments_hitemp[warp_id][i]; - rotamer_assignments_hitemp[warp_id][i] = i_assignment; - rotamer_assignments_hitemp_quenchlite[warp_id][i] = i_assignment; + for (int i = g.thread_rank(); i < n_res; i += 32) { + int i_assignment = best_rotamer_assignments_hitemp[pose][traj_id][i]; + current_rotamer_assignments_hitemp[pose][traj_id][i] = i_assignment; + current_rotamer_assignments_hitemp_quenchlite[pose][traj_id][i] = + i_assignment; } float best_energy_after_high_temp = ig.total_energy_for_assignment_parallel( - g, best_rotamer_assignments_hitemp[warp_id]); + pose, g, best_rotamer_assignments_hitemp[pose][traj_id]); // ok, run quench lite as a way to predict where this rotamer assignment // will end up after low-temperature annealing float after_first_quench_lite_totalE = warp_wide_sim_annealing( + pose, + traj_id, &state, g, ig, - warp_id, - rotamer_assignments_hitemp_quenchlite[warp_id], - best_rotamer_assignments_hitemp[warp_id], - quench_order[warp_id], + current_rotamer_assignments_hitemp_quenchlite[pose][traj_id], + best_rotamer_assignments_hitemp[pose][traj_id], + quench_order[pose][traj_id], high_temp_initial, low_temp_initial, 1, // perform quench in first (ie last) iteration n_inner_iterations_hitemp, // irrelevant - nrotamers, + n_rotamers, true, true); if (g.thread_rank() == 0) { - scores_hitemp[warp_id] = after_first_quench_lite_totalE; + scores_hitemp[pose][traj_id] = after_first_quench_lite_totalE; } }; auto lotemp_simulated_annealing = [=] MGPU_DEVICE(int thread_id) 
{ + auto seeds = at::cuda::philox::unpack(lotemp_philox_state); curandStatePhilox4_32_10_t state; - curand_init(seed, thread_id, hitemp_cnt, &state); + curand_init(std::get<0>(seeds), thread_id, std::get<1>(seeds), &state); cooperative_groups::thread_block_tile<32> g = cooperative_groups::tiled_partition<32>( cooperative_groups::this_thread_block()); - int const warp_id = thread_id / 32; - int const source_traj = sorted_hitemp_traj[warp_id / n_lotemp_expansions]; + int const cta_id = thread_id / 32; + int const pose = cta_id / n_lotemp_simA_traj; + int const traj_id = cta_id % n_lotemp_simA_traj; + int const source_traj = + sorted_hitemp_traj[pose][traj_id / n_lotemp_expansions]; + + int const n_res = ig.n_res(pose); + int const n_rotamers = ig.n_rotamers(pose); + + // printf("lotemp thread %d cta %d pose %d traj_id %d source_traj %d\n", + // thread_id, cta_id, + // pose, traj_id, source_traj); if (g.thread_rank() == 0) { - sorted_lotemp_traj[warp_id] = warp_id; + sorted_lotemp_traj[pose][traj_id] = traj_id; } // initialize the rotamer assignment from one of the top trajectories // of the high-temperature annealing trajectory - for (int i = g.thread_rank(); i < nres; i += 32) { - int i_rot = rotamer_assignments_hitemp[source_traj][i]; - rotamer_assignments_lotemp[warp_id][i] = i_rot; - best_rotamer_assignments_lotemp[warp_id][i] = i_rot; + for (int i = g.thread_rank(); i < n_res; i += 32) { + int i_rot = current_rotamer_assignments_hitemp[pose][source_traj][i]; + current_rotamer_assignments_lotemp[pose][traj_id][i] = i_rot; + best_rotamer_assignments_lotemp[pose][traj_id][i] = i_rot; } // Now run a low-temperature cooling trajectory float low_temp_totalE = warp_wide_sim_annealing( + pose, + traj_id, &state, g, ig, - warp_id, - rotamer_assignments_lotemp[warp_id], - best_rotamer_assignments_lotemp[warp_id], - quench_order[warp_id], + current_rotamer_assignments_lotemp[pose][traj_id], + best_rotamer_assignments_lotemp[pose][traj_id], + quench_order[pose][traj_id], 
high_temp_later, low_temp_later, n_outer_iterations_lotemp, n_inner_iterations_lotemp, - nrotamers, + n_rotamers, false, false); + // if (g.thread_rank() == 0) { + // printf( + // "lotemp wwsa done: pose %d traj %d E = %f\n", + // pose, + // traj_id, + // low_temp_totalE); + // } // now we'll run a quench-lite // ok, we will run quench lite on first state float after_lotemp_quench_lite_totalE = warp_wide_sim_annealing( + pose, + traj_id, &state, g, ig, - warp_id, - rotamer_assignments_lotemp[warp_id], - best_rotamer_assignments_lotemp[warp_id], - quench_order[warp_id], + current_rotamer_assignments_lotemp[pose][traj_id], + best_rotamer_assignments_lotemp[pose][traj_id], + quench_order[pose][traj_id], high_temp_later, low_temp_later, 1, // run quench on first (i.e. last) iteration n_inner_iterations_lotemp, // irrelevant - nrotamers, + n_rotamers, true, true); if (g.thread_rank() == 0) { - scores_lotemp[warp_id] = after_lotemp_quench_lite_totalE; + scores_lotemp[pose][traj_id] = after_lotemp_quench_lite_totalE; } }; - auto fullquench = [=] MGPU_DEVICE(int thread_id) { + auto fullquench = ([=] MGPU_DEVICE(int thread_id) { + auto seeds = at::cuda::philox::unpack(quench_philox_state); curandStatePhilox4_32_10_t state; - curand_init(seed, thread_id, hitemp_cnt + lotemp_cnt, &state); + curand_init(std::get<0>(seeds), thread_id, std::get<1>(seeds), &state); cooperative_groups::thread_block_tile<32> g = cooperative_groups::tiled_partition<32>( cooperative_groups::this_thread_block()); - int const warp_id = thread_id / 32; - int const source_traj = sorted_lotemp_traj[warp_id]; + int const cta_id = thread_id / 32; + int const pose = cta_id / n_fullquench_traj; + int const traj_id = cta_id % n_fullquench_traj; + int const source_traj = sorted_lotemp_traj[pose][traj_id]; + + int const n_res = ig.n_res(pose); + int const n_rotamers = ig.n_rotamers(pose); // if (g.thread_rank() == 0) { // printf("warp %d fullquench source_traj %d (%d) %f\n", warp_id, // source_traj, @@ -916,105 
+1052,165 @@ struct Annealer { // initialize the rotamer assignment from one of the top trajectories // of the high-temperature annealing trajectory - for (int i = g.thread_rank(); i < nres; i += 32) { - int i_rot = rotamer_assignments_lotemp[source_traj][i]; - rotamer_assignments_fullquench[warp_id][i] = i_rot; - best_rotamer_assignments_fullquench[warp_id][i] = i_rot; + for (int i = g.thread_rank(); i < n_res; i += 32) { + int i_rot = current_rotamer_assignments_lotemp[pose][source_traj][i]; + current_rotamer_assignments_fullquench[pose][traj_id][i] = i_rot; + best_rotamer_assignments_fullquench[pose][traj_id][i] = i_rot; } float after_full_quench_totalE = 0; for (int i = 0; i < 1; ++i) { after_full_quench_totalE = warp_wide_sim_annealing( + pose, + traj_id, &state, g, ig, - warp_id, - rotamer_assignments_fullquench[warp_id], - best_rotamer_assignments_fullquench[warp_id], - quench_order[warp_id], + current_rotamer_assignments_fullquench[pose][traj_id], + best_rotamer_assignments_fullquench[pose][traj_id], + quench_order[pose][traj_id], high_temp_later, low_temp_later, 1, // run quench on first (ie last) iteration n_inner_iterations_lotemp, - nrotamers, + n_rotamers, true, false); } if (g.thread_rank() == 0) { - scores_fullquench[0][warp_id] = after_full_quench_totalE; + scores_fullquench[pose][traj_id] = after_full_quench_totalE; } - }; + }); + + auto final_reindexing = ([=] MGPU_DEVICE(int thread_id) { + cooperative_groups::thread_block_tile<32> g = + cooperative_groups::tiled_partition<32>( + cooperative_groups::this_thread_block()); + int const cta_id = thread_id / 32; + int const pose = cta_id / n_fullquench_traj; + int const traj_id = cta_id % n_fullquench_traj; + int const source_traj = sorted_fullquench_traj[pose][traj_id]; + int const n_res = ig.n_res(pose); + if (g.thread_rank() == 0) { + scores_final[pose][traj_id] = scores_fullquench[pose][source_traj]; + } + for (int i = g.thread_rank(); i < n_res; i += 32) { + 
rotamer_assignments_final[pose][traj_id][i] = + best_rotamer_assignments_fullquench[pose][source_traj][i]; + } + }); mgpu::standard_context_t context; - mgpu::transform<128, 1>( + // printf("launch hitemp\n"); + mgpu::transform<32, 1>( hitemp_simulated_annealing, n_hitemp_simA_threads, context); - mgpu::mergesort( + + // now let's rank the trajectories for each pose + // printf("launch segsort\n"); + mgpu::segmented_sort( scores_hitemp.data(), sorted_hitemp_traj.data(), n_hitemp_simA_traj, + segment_heads_hitemp.data(), + n_poses, mgpu::less_t(), context); - mgpu::transform<128, 1>( + // printf("temp no launch lotemp\n"); + mgpu::transform<32, 1>( lotemp_simulated_annealing, n_lotemp_simA_threads, context); - mgpu::mergesort( + + // printf("temp no launch segsort2\n"); + mgpu::segmented_sort( scores_lotemp.data(), sorted_lotemp_traj.data(), n_lotemp_simA_traj, + segment_heads_lotemp.data(), + n_poses, + mgpu::less_t(), + context); + + // printf("temp no launch fullquench\n"); + mgpu::transform<32, 1>(fullquench, n_fullquench_threads, context); + + mgpu::segmented_sort( + scores_fullquench.data(), + sorted_fullquench_traj.data(), + n_fullquench_traj, + segment_heads_fullquench.data(), + n_poses, mgpu::less_t(), context); - mgpu::transform<128, 1>(fullquench, n_fullquench_threads, context); + // printf("temp no launch fullquench\n"); + mgpu::transform<32, 1>(final_reindexing, n_fullquench_threads, context); - return {scores_fullquench_t, rotamer_assignments_fullquench_t}; + printf("done!\n"); + return {scores_final_t, rotamer_assignments_final_t}; } }; template -struct AnnealerDispatch { - static auto forward( - TView nrotamers_for_res, - TView oneb_offsets, - TView res_for_rot, - TView respair_nenergies, - TView chunk_size, - TView chunk_offset_offsets, - TView twob_offsets, - TView fine_chunk_offsets, - TView energy1b, - TView energy2b, - int64_t seed) -> std::tuple, TPack > { - clock_t start = clock(); - - InteractionGraph ig( - {nrotamers_for_res, - oneb_offsets, 
- res_for_rot, - respair_nenergies, - chunk_size, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, - energy1b, - energy2b}); - - auto result = - Annealer >::run_simulated_annealing( - ig, seed); - - cudaDeviceSynchronize(); - clock_t stop = clock(); - std::cout << "GPU simulated annealing in " - << ((double)stop - start) / CLOCKS_PER_SEC << " seconds" - << std::endl; - - return result; - } -}; +auto AnnealerDispatch::forward( + int max_n_rotamers_per_pose, + TView pose_n_res, + TView pose_n_rotamers, + TView pose_rotamer_offset, + TView n_rotamers_for_res, + TView oneb_offsets, + TView res_for_rot, + int32_t chunk_size, + TView chunk_offset_offsets, + TView chunk_offsets, + TView energy1b, + TView energy2b) + -> std::tuple, TPack > { + clock_t start = clock(); + + InteractionGraph ig( + {max_n_rotamers_per_pose, + pose_n_res, + pose_n_rotamers, + pose_rotamer_offset, + n_rotamers_for_res, + oneb_offsets, + res_for_rot, + chunk_size, + chunk_offset_offsets, + chunk_offsets, + energy1b, + energy2b}); + + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + auto result = + Annealer >::run_simulated_annealing( + ig, gen); + + cudaDeviceSynchronize(); + clock_t stop = clock(); + std::cout << "GPU simulated annealing in " + << ((double)stop - start) / CLOCKS_PER_SEC << " seconds" + << std::endl; + + return result; +} template struct AnnealerDispatch; +template struct InteractionGraphBuilder< + score::common::DeviceOperations, + tmol::Device::CUDA, + float, + int64_t>; +template struct InteractionGraphBuilder< + score::common::DeviceOperations, + tmol::Device::CUDA, + double, + int64_t>; + } // namespace compiled } // namespace pack } // namespace tmol diff --git a/tmol/pack/compiled/compiled.impl.hh b/tmol/pack/compiled/compiled.impl.hh new file mode 100644 index 000000000..bb10712aa --- /dev/null +++ b/tmol/pack/compiled/compiled.impl.hh @@ -0,0 +1,320 @@ +#pragma once + +#include +#include + +#include 
+#include +#include + +#include + +#include "annealer.hh" + +namespace tmol { +namespace pack { +namespace compiled { + +template < + template class DeviceDispatch, + tmol::Device D, + typename Real, + typename Int> +auto InteractionGraphBuilder::f( + int const chunk_size, + TView n_rots_for_pose, + TView rot_offset_for_pose, + TView n_rots_for_block, + TView rot_offset_for_block, + TView pose_for_rot, + TView block_type_ind_for_rot, + TView block_ind_for_rot, + TView sparse_inds, // if we are ever dealing w > 4B + // rotamers, we are in trouble + TView sparse_energies) + -> std::tuple< + TPack, + TPack, + TPack, + TPack > { + int const n_poses = n_rots_for_pose.size(0); + int const n_rotamers = pose_for_rot.size(0); + int const max_n_blocks = n_rots_for_block.size(1); + int const n_sparse_entries = sparse_inds.size(1); + + assert(rot_offset_for_pose.size(0) == n_poses); + assert(n_rots_for_block.size(0) == n_poses); + assert(block_type_ind_for_rot.size(0) == n_rotamers); + assert(block_ind_for_rot.size(0) == n_rotamers); + assert(sparse_inds.size(0) == 3); + assert(sparse_energies.size(0) == n_sparse_entries); + + auto energy1b_tp = TPack::zeros({n_rotamers}); + auto energy1b = energy1b_tp.view; + auto n_chunks_for_block_tp = + TPack::zeros({n_poses, max_n_blocks}); + auto n_chunks_for_block = n_chunks_for_block_tp.view; + + LAUNCH_BOX_32; + auto count_n_chunks_for_block = ([=] TMOL_DEVICE_FUNC(int index) { + int const pose = index / max_n_blocks; + int const block = index % max_n_blocks; + int const n_rots = n_rots_for_block[pose][block]; + if (n_rots != 0) { + int const n_chunks = (n_rots - 1) / chunk_size + 1; + n_chunks_for_block[pose][block] = n_chunks; + // printf("n_chunks_for_block[%d][%d] == %d\n", pose, block, n_rots); + } + }); + DeviceDispatch::template forall( + n_poses * max_n_blocks, count_n_chunks_for_block); + + auto respair_is_adjacent_tp = + TPack::zeros({n_poses, max_n_blocks, max_n_blocks}); + auto respair_is_adjacent = 
respair_is_adjacent_tp.view; + + auto note_adjacent_respairs = ([=] TMOL_DEVICE_FUNC(int index) { + int const pose = sparse_inds[0][index]; + int const rot1 = sparse_inds[1][index]; + int const rot2 = sparse_inds[2][index]; + int const block1 = block_ind_for_rot[rot1]; + int const block2 = block_ind_for_rot[rot2]; + if (block1 == block2) { + return; // printf("block1 %d == block2 %d for rot1 %d rot2 %d\n", block1, + // block2, rot1, rot2); + } + // Assert: block1 < block2 + respair_is_adjacent[pose][block1][block2] = 1; + // if (index < 100) { + // printf("note_adjacent_respairs %d %d %d %d %d\n", pose, rot1, rot2, + // block1, block2); + // } + }); + DeviceDispatch::template forall( + n_sparse_entries, note_adjacent_respairs); + + auto n_chunks_for_block_pair_tp = + TPack::zeros({n_poses, max_n_blocks, max_n_blocks}); + auto n_chunks_for_block_pair = n_chunks_for_block_pair_tp.view; + + auto note_n_chunks_for_block_pair = ([=] TMOL_DEVICE_FUNC(int index) { + int const pose = index / (max_n_blocks * max_n_blocks); + index = index - pose * max_n_blocks * max_n_blocks; + int const block1 = index / max_n_blocks; + int const block2 = index % max_n_blocks; + + // We don't have to worry about block1 > block2 as those will not + // have entries in the sparse_inds input tensors, but we do + // have to worry about block1 == block2 for one-body energies + if (respair_is_adjacent[pose][block1][block2]) { + int const n_chunks1 = n_chunks_for_block[pose][block1]; + int const n_chunks2 = n_chunks_for_block[pose][block2]; + int const n_chunk_pairs = n_chunks1 * n_chunks2; + // printf("respair adjacent %d %d %d; nchunk pairs %d\n", pose, block1, + // block2, n_chunk_pairs); + n_chunks_for_block_pair[pose][block1][block2] = n_chunk_pairs; + n_chunks_for_block_pair[pose][block2][block1] = n_chunk_pairs; + } + }); + DeviceDispatch::template forall( + n_poses * max_n_blocks * max_n_blocks, note_n_chunks_for_block_pair); + // printf("note_n_chunks_for_block_pair\n"); + + auto 
chunk_pair_offset_for_block_pair_tp = + TPack::zeros({n_poses, max_n_blocks, max_n_blocks}); + auto chunk_pair_offset_for_block_pair = + chunk_pair_offset_for_block_pair_tp.view; + + // Okay, now let's figure out which chunk pairs are near each other + int const n_adjacent_chunk_pairs_total = + DeviceDispatch::template scan_and_return_total( + n_chunks_for_block_pair.data(), + chunk_pair_offset_for_block_pair.data(), + n_poses * max_n_blocks * max_n_blocks, + mgpu::plus_t()); + // printf("n_adjacent_chunk_pairs_total %d\n", n_adjacent_chunk_pairs_total); + + auto chunk_pair_adjacency_tp = + TPack::zeros({n_adjacent_chunk_pairs_total}); + auto chunk_pair_adjacency = chunk_pair_adjacency_tp.view; + + auto note_adjacent_chunk_pairs = ([=] TMOL_DEVICE_FUNC(int index) { + int const pose = sparse_inds[0][index]; + int const rot1 = sparse_inds[1][index]; + int const rot2 = sparse_inds[2][index]; + int const block1 = block_ind_for_rot[rot1]; + int const block2 = block_ind_for_rot[rot2]; + if (block1 == block2) { + return; + } + + int const block1_rot_offset = rot_offset_for_block[pose][block1]; + int const block2_rot_offset = rot_offset_for_block[pose][block2]; + int const local_rot1 = rot1 - block1_rot_offset; + int const local_rot2 = rot2 - block2_rot_offset; + int const chunk1 = local_rot1 / chunk_size; + int const chunk2 = local_rot2 / chunk_size; + int const n_rots_block1 = n_rots_for_block[pose][block1]; + int const n_rots_block2 = n_rots_for_block[pose][block2]; + int const n_chunks1 = (n_rots_block1 - 1) / chunk_size + 1; + int const n_chunks2 = (n_rots_block2 - 1) / chunk_size + 1; + + int const overhang1 = n_rots_block1 - chunk1 * chunk_size; + int const overhang2 = n_rots_block2 - chunk2 * chunk_size; + int const chunk1_size = (overhang1 > chunk_size ? chunk_size : overhang1); + int const chunk2_size = (overhang2 > chunk_size ? 
chunk_size : overhang2); + + int const block_pair_chunk_offset_ij = + chunk_pair_offset_for_block_pair[pose][block1][block2]; + int const block_pair_chunk_offset_ji = + chunk_pair_offset_for_block_pair[pose][block2][block1]; + + // if (index < 100) { + // printf("pose %d rot1 %d rot2 %d block1 %d block2 %d block1_rot_offset + // %d block2_rot_offset %d local_rot1 %d local_rot2 %d chunk1 %d chunk2 %d + // n_rots_block1 %d n_rots_block2 %d n_chunks1 %d n_chunks2 %d overhang1 + // %d overhang2 %d chunk1_size %d chunk2_size %d + // block_pair_chunk_offset_ij %d block_pair_chunk_offset_ji %d\n", pose, + // rot1, rot2, block1, block2, block1_rot_offset, block2_rot_offset, + // local_rot1, local_rot2, chunk1, chunk2, n_rots_block1, n_rots_block2, + // n_chunks1, n_chunks2, overhang1, overhang2, chunk1_size, chunk2_size, + // block_pair_chunk_offset_ij, block_pair_chunk_offset_ji); + // } + + // multiple threads will write exactly these values to these entries in the + // chunk_pair_adjacency table + chunk_pair_adjacency + [block_pair_chunk_offset_ij + chunk1 * n_chunks2 + chunk2] = + chunk1_size * chunk2_size; + chunk_pair_adjacency + [block_pair_chunk_offset_ji + chunk2 * n_chunks1 + chunk1] = + chunk1_size * chunk2_size; + }); + DeviceDispatch::template forall( + n_sparse_entries, note_adjacent_chunk_pairs); + + auto chunk_pair_offsets_tp = + TPack::zeros({n_adjacent_chunk_pairs_total}); + auto chunk_pair_offsets = chunk_pair_offsets_tp.view; + + int64_t const n_two_body_energies = + DeviceDispatch::template scan_and_return_total( + chunk_pair_adjacency.data(), + chunk_pair_offsets.data(), + n_adjacent_chunk_pairs_total, + mgpu::plus_t()); + + auto energy2b_tp = TPack::zeros({n_two_body_energies}); + auto energy2b = energy2b_tp.view; + + auto record_energies_in_energy1b_and_energy2b = ([=] TMOL_DEVICE_FUNC( + int index) { + int const pose = sparse_inds[0][index]; + int const rot1 = sparse_inds[1][index]; + int const rot2 = sparse_inds[2][index]; + Real const energy = 
sparse_energies[index]; + int const block1 = block_ind_for_rot[rot1]; + int const block2 = block_ind_for_rot[rot2]; + if (block1 == block2) { + // printf("setting one body energy? %d %d %d %d of %d\n", block1, block2, + // rot1, rot2, n_rotamers); + energy1b[rot1] = energy; + } else { + int const block1_rot_offset = rot_offset_for_block[pose][block1]; + int const block2_rot_offset = rot_offset_for_block[pose][block2]; + int const local_rot1 = rot1 - block1_rot_offset; + int const local_rot2 = rot2 - block2_rot_offset; + int const chunk1 = local_rot1 / chunk_size; + int const chunk2 = local_rot2 / chunk_size; + int const n_rots_block1 = n_rots_for_block[pose][block1]; + int const n_rots_block2 = n_rots_for_block[pose][block2]; + int const n_chunks1 = (n_rots_block1 - 1) / chunk_size + 1; + int const n_chunks2 = (n_rots_block2 - 1) / chunk_size + 1; + + int const overhang1 = n_rots_block1 - chunk1 * chunk_size; + int const overhang2 = n_rots_block2 - chunk2 * chunk_size; + int const chunk1_size = (overhang1 > chunk_size ? chunk_size : overhang1); + int const chunk2_size = (overhang2 > chunk_size ? 
chunk_size : overhang2); + + int const rot_ind_wi_chunk1 = local_rot1 - chunk1 * chunk_size; + int const rot_ind_wi_chunk2 = local_rot2 - chunk2 * chunk_size; + + int const block_pair_chunk_offset_ij = + chunk_pair_offset_for_block_pair[pose][block1][block2]; + int const block_pair_chunk_offset_ji = + chunk_pair_offset_for_block_pair[pose][block2][block1]; + + int const chunk_offset_ij = chunk_pair_offsets + [block_pair_chunk_offset_ij + chunk1 * n_chunks2 + chunk2]; + int const chunk_offset_ji = chunk_pair_offsets + [block_pair_chunk_offset_ji + chunk2 * n_chunks1 + chunk1]; + + // if (index < 100) { + // printf("record: pose %d rot1 %d rot2 %d E %6.3f block1 %d block2 %d + // block1_rot_offset %d block2_rot_offset %d local_rot1 %d local_rot2 %d + // chunk1 %d chunk2 %d n_rots_block1 %d n_rots_block2 %d n_chunks1 %d + // n_chunks2 %d overhang1 %d overhang2 %d chunk1_size %d chunk2_size %d + // rot_ind_wi_chunk1 %d rot_ind_wi_chunk2 %d block_pair_chunk_offset_ij + // %d block_pair_chunk_offset_ji %d\n", pose, rot1, rot2, energy, + // block1, block2, block1_rot_offset, block2_rot_offset, local_rot1, + // local_rot2, chunk1, chunk2, n_rots_block1, n_rots_block2, n_chunks1, + // n_chunks2, overhang1, overhang2, chunk1_size, chunk2_size, + // rot_ind_wi_chunk1, rot_ind_wi_chunk2, block_pair_chunk_offset_ij, + // block_pair_chunk_offset_ji); + // } + + energy2b + [chunk_offset_ij + rot_ind_wi_chunk1 * chunk2_size + + rot_ind_wi_chunk2] = energy; + energy2b + [chunk_offset_ji + rot_ind_wi_chunk2 * chunk1_size + + rot_ind_wi_chunk1] = energy; + } + }); + DeviceDispatch::template forall( + n_sparse_entries, record_energies_in_energy1b_and_energy2b); + + // Mark the chunk_pair_offset_for_block_pair that are not adjacent w/ -1s + // Mark the chunk_pair_offsets that are not adjacent w/ -1s + + auto sentinel_out_non_adjacent_block_pairs = + ([=] TMOL_DEVICE_FUNC(int index) { + int const pose = index / (max_n_blocks * max_n_blocks); + index = index - pose * max_n_blocks * 
max_n_blocks; + int const block1 = index / max_n_blocks; + int const block2 = index % max_n_blocks; + + // We don't have to worry about block1 >= block2 as those will not + // have entries in the sparse_inds input tensors + if (block1 <= block2 && !respair_is_adjacent[pose][block1][block2]) { + chunk_pair_offset_for_block_pair[pose][block1][block2] = -1; + chunk_pair_offset_for_block_pair[pose][block2][block1] = -1; + } + }); + DeviceDispatch::template forall( + n_poses * max_n_blocks * max_n_blocks, + sentinel_out_non_adjacent_block_pairs); + + auto sentinel_out_non_adjacent_chunk_pairs = + ([=] TMOL_DEVICE_FUNC(int index) { + int const n_pairs_for_chunk = chunk_pair_adjacency.data()[index]; + if (n_pairs_for_chunk == 0) { + // if (index < 100) { + // printf("Non adjacent chunk pair %d\n", index); + // } + chunk_pair_offsets[index] = -1; + } + }); + DeviceDispatch::template forall( + n_adjacent_chunk_pairs_total, sentinel_out_non_adjacent_chunk_pairs); + + return std::make_tuple( + energy1b_tp, + chunk_pair_offset_for_block_pair_tp, + chunk_pair_offsets_tp, + energy2b_tp); +} + +} // namespace compiled +} // namespace pack +} // namespace tmol diff --git a/tmol/pack/compiled/compiled.ops.cpp b/tmol/pack/compiled/compiled.ops.cpp index aca6d09cf..1d66f5e2d 100644 --- a/tmol/pack/compiled/compiled.ops.cpp +++ b/tmol/pack/compiled/compiled.ops.cpp @@ -10,6 +10,7 @@ #include #include +#include #include "annealer.hh" #include "simulated_annealing.hh" @@ -20,18 +21,72 @@ namespace compiled { using torch::Tensor; +std::vector build_interaction_graph( + int64_t const chunk_size, + Tensor n_rots_for_pose, + Tensor rot_offset_for_pose, + Tensor n_rots_for_block, + Tensor rot_offset_for_block, + Tensor pose_for_rot, + Tensor block_type_ind_for_rot, + Tensor block_ind_for_rot, + Tensor sparse_inds, + Tensor sparse_energies) { + nvtx_range_push("pack_build_ig"); + at::Tensor energy1b; + at::Tensor chunk_pair_offset_for_block_pair; + at::Tensor chunk_pair_offset; + at::Tensor 
energy2b; + + using Int = int64_t; + + TMOL_DISPATCH_FLOATING_DEVICE( + sparse_energies.options(), "pack_build_ig", ([&] { + constexpr tmol::Device Dev = device_t; + using Real = scalar_t; + + auto result = InteractionGraphBuilder< + score::common::DeviceOperations, + Dev, + Real, + Int>:: + f(chunk_size, + TCAST(n_rots_for_pose), + TCAST(rot_offset_for_pose), + TCAST(n_rots_for_block), + TCAST(rot_offset_for_block), + TCAST(pose_for_rot), + TCAST(block_type_ind_for_rot), + TCAST(block_ind_for_rot), + TCAST(sparse_inds), + TCAST(sparse_energies)); + energy1b = std::get<0>(result).tensor; + chunk_pair_offset_for_block_pair = std::get<1>(result).tensor; + chunk_pair_offset = std::get<2>(result).tensor; + energy2b = std::get<3>(result).tensor; + })); + + std::vector result( + {energy1b, + chunk_pair_offset_for_block_pair, + chunk_pair_offset, + energy2b}); + return result; +} + std::vector anneal( - Tensor nrotamers_for_res, + int64_t max_n_rotamers_per_pose, + Tensor pose_n_res, + Tensor pose_n_rotamers, + Tensor pose_rotamer_offset, + Tensor n_rotamers_for_res, Tensor oneb_offsets, Tensor res_for_rot, - Tensor respair_nenergies, - Tensor chunk_size, + int64_t chunk_size, Tensor chunk_offset_offsets, - Tensor twob_offsets, - Tensor fine_chunk_offsets, + Tensor chunk_offsets, Tensor energy1b, - Tensor energy2b, - int64_t seed) { + Tensor energy2b) { nvtx_range_push("pack_anneal"); at::Tensor scores; at::Tensor rotamer_assignments; @@ -39,19 +94,19 @@ std::vector anneal( TMOL_DISPATCH_FLOATING_DEVICE(energy1b.options(), "pack_anneal", ([&] { constexpr tmol::Device Dev = device_t; - std::cout << "HOLA!" 
<< std::endl; auto result = AnnealerDispatch::forward( - TCAST(nrotamers_for_res), + max_n_rotamers_per_pose, + TCAST(pose_n_res), + TCAST(pose_n_rotamers), + TCAST(pose_rotamer_offset), + TCAST(n_rotamers_for_res), TCAST(oneb_offsets), TCAST(res_for_rot), - TCAST(respair_nenergies), - TCAST(chunk_size), + chunk_size, TCAST(chunk_offset_offsets), - TCAST(twob_offsets), - TCAST(fine_chunk_offsets), + TCAST(chunk_offsets), TCAST(energy1b), - TCAST(energy2b), - seed); + TCAST(energy2b)); scores = std::get<0>(result).tensor; rotamer_assignments = std::get<1>(result).tensor; @@ -61,35 +116,31 @@ std::vector anneal( return result; } -TPack compute_energies_for_assignments( - TView nrotamers_for_res, - TView oneb_offsets, - TView res_for_rot, - TView respair_nenergies, - TView chunk_size, - TView chunk_offset_offsets, - TView twob_offsets, - TView fine_chunk_offsets, +TPack compute_energies_for_assignments( + TView n_rotamers_for_res, + TView oneb_offsets, + int32_t chunk_size, + TView chunk_offset_offsets, + TView chunk_offsets, TView energy1b, TView energy2b, - TView rotamer_assignments) { - int n_assignments = rotamer_assignments.size(0); - auto scores_t = TPack::zeros({n_assignments}); + TView rotamer_assignments) { + int n_poses = rotamer_assignments.size(0); + int n_traj = rotamer_assignments.size(1); + auto scores_t = TPack::zeros({n_poses, n_traj}); auto scores = scores_t.view; - for (int i = 0; i < n_assignments; ++i) { - scores[i] = total_energy_for_assignment( - nrotamers_for_res, - oneb_offsets, - res_for_rot, - respair_nenergies, - chunk_size, - chunk_offset_offsets, - twob_offsets, - fine_chunk_offsets, - energy1b, - energy2b, - rotamer_assignments, - i); + for (int pose = 0; pose < n_poses; ++pose) { + for (int i = 0; i < n_traj; ++i) { + scores[pose][i] = total_energy_for_assignment( + n_rotamers_for_res[pose], + oneb_offsets[pose], + chunk_size, + chunk_offset_offsets[pose], + chunk_offsets, + energy1b, + energy2b, + rotamer_assignments[pose][i]); + } } 
return scores_t; } @@ -97,24 +148,18 @@ TPack compute_energies_for_assignments( torch::Tensor validate_energies( Tensor nrotamers_for_res, Tensor oneb_offsets, - Tensor res_for_rot, - Tensor respair_nenergies, - Tensor chunk_size, + int64_t chunk_size, Tensor chunk_offset_offsets, - Tensor twob_offsets, - Tensor fine_chunk_offsets, + Tensor chunk_offsets, Tensor energy1b, Tensor energy2b, Tensor rotamer_assignments) { auto result = compute_energies_for_assignments( TCAST(nrotamers_for_res), TCAST(oneb_offsets), - TCAST(res_for_rot), - TCAST(respair_nenergies), - TCAST(chunk_size), + int32_t(chunk_size), TCAST(chunk_offset_offsets), - TCAST(twob_offsets), - TCAST(fine_chunk_offsets), + TCAST(chunk_offsets), TCAST(energy1b), TCAST(energy2b), TCAST(rotamer_assignments)); @@ -132,6 +177,7 @@ static auto registry = torch::jit::RegisterOperators() TORCH_LIBRARY_(TORCH_EXTENSION_NAME, m) { m.def("pack_anneal", &anneal); m.def("validate_energies", &validate_energies); + m.def("build_interaction_graph", &build_interaction_graph); } } // namespace compiled diff --git a/tmol/pack/compiled/compiled.py b/tmol/pack/compiled/compiled.py index 87fcb44eb..911e036cd 100644 --- a/tmol/pack/compiled/compiled.py +++ b/tmol/pack/compiled/compiled.py @@ -12,3 +12,4 @@ _ops = getattr(torch.ops, modulename(__name__)) pack_anneal = _ops.pack_anneal validate_energies = _ops.validate_energies +build_interaction_graph = _ops.build_interaction_graph diff --git a/tmol/pack/compiled/simulated_annealing.hh b/tmol/pack/compiled/simulated_annealing.hh index 34a1bb382..77a744616 100644 --- a/tmol/pack/compiled/simulated_annealing.hh +++ b/tmol/pack/compiled/simulated_annealing.hh @@ -1,5 +1,7 @@ #pragma once +#include // std::min + namespace tmol { namespace pack { namespace compiled { @@ -63,62 +65,88 @@ inline #endif float total_energy_for_assignment( - TView nrotamers_for_res, - TView oneb_offsets, - TView res_for_rot, - TView respair_nenergies, - TView chunk_size_t, - TView chunk_offset_offsets, 
- TView twob_offsets, - TView fine_chunk_offsets, - TView energy1b, - TView energy2b, - TView rotamer_assignment, - TView pair_energies, - int rotassign_dim0 // i.e. thread_id + TensorAccessor n_rotamers_for_res, // max-n-res + TensorAccessor oneb_offsets, // max-n-res + int32_t const chunk_size, + TensorAccessor + chunk_offset_offsets, // max-n-res x max-n-res + TView chunk_offsets, // n-interacting-chunk-pairs + TView energy1b, // n-rotamers-total + TView energy2b, // n-interacting-rotamer-pairs + TensorAccessor + rotamer_assignment, // local rotamer indices; max-n-res + TensorAccessor + current_pair_energies // max-n-res x max-n-res ) { +#ifndef __CUDACC__ + using std::min; +#endif + + // Read the energies from energy1b and energy2b for the given + // rotamer_assignment (represented as the local indices for + // each rotamer on each block) and record each energy in the + // current_pair_energies table. float totalE = 0; - int const nres = nrotamers_for_res.size(0); - int const chunk_size = chunk_size_t[0]; - for (int i = 0; i < nres; ++i) { - int const irot_local = rotamer_assignment[rotassign_dim0][i]; + int const n_res = n_rotamers_for_res.size(0); + for (int i = 0; i < n_res; ++i) { + int const irot_local = rotamer_assignment[i]; + + if (irot_local == -1) { + // unassigned rotamer or residue off the end for the Pose + for (int j = 0; j < n_res; ++j) { + current_pair_energies[i][j] = 0; + current_pair_energies[j][i] = 0; + } + continue; + } int const irot_global = irot_local + oneb_offsets[i]; - int const ires_nrots = nrotamers_for_res[i]; - int const ires_nchunks = (ires_nrots - 1) / chunk_size + 1; + int const ires_n_rots = n_rotamers_for_res[i]; + int const ires_n_chunks = (ires_n_rots - 1) / chunk_size + 1; int const irot_chunk = irot_local / chunk_size; int const irot_in_chunk = irot_local - chunk_size * irot_chunk; int const irot_chunk_size = - std::min(chunk_size, ires_nrots - chunk_size * irot_chunk); + min(chunk_size, ires_n_rots - chunk_size * 
irot_chunk); totalE += energy1b[irot_global]; - for (int j = i + 1; j < nres; ++j) { - int const jrot_local = rotamer_assignment[rotassign_dim0][j]; - if (respair_nenergies[i][j] == 0) { - pair_energies[rotassign_dim0][i][j] = 0; - pair_energies[rotassign_dim0][j][i] = 0; + for (int j = i + 1; j < n_res; ++j) { + int const jrot_local = rotamer_assignment[j]; + if (jrot_local == -1) { + // no need to zero out current_pair_energies here; that will occur on + // the i == this-j iteration + continue; + } + int64_t const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; + if (ij_chunk_offset_offset == -1) { + // Then this pair of residues do not interact + current_pair_energies[i][j] = 0; + current_pair_energies[j][i] = 0; continue; } - int const jres_nrots = nrotamers_for_res[j]; - int const jres_nchunks = (jres_nrots - 1) / chunk_size + 1; + int const jres_n_rots = n_rotamers_for_res[j]; + int const jres_n_chunks = (jres_n_rots - 1) / chunk_size + 1; int const jrot_chunk = jrot_local / chunk_size; int const jrot_in_chunk = jrot_local - chunk_size * jrot_chunk; int const jrot_chunk_size = - std::min(chunk_size, jres_nrots - chunk_size * jrot_chunk); - - int const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; - int const ij_chunk_offset = fine_chunk_offsets - [ij_chunk_offset_offset + irot_chunk * jres_nchunks + jrot_chunk]; - if (ij_chunk_offset < 0) { - pair_energies[rotassign_dim0][i][j] = 0; - pair_energies[rotassign_dim0][j][i] = 0; + min(chunk_size, jres_n_rots - chunk_size * jrot_chunk); + + // int const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; + int64_t const ij_chunk_offset = + (chunk_offsets + [ij_chunk_offset_offset + irot_chunk * jres_n_chunks + + jrot_chunk]); + if (ij_chunk_offset == -1) { + current_pair_energies[i][j] = 0; + current_pair_energies[j][i] = 0; + continue; } - float ij_energy = energy2b - [twob_offsets[i][j] + ij_chunk_offset - + irot_in_chunk * jrot_chunk_size + jrot_in_chunk]; + float const ij_energy = + (energy2b + 
[ij_chunk_offset + irot_in_chunk * jrot_chunk_size + + jrot_in_chunk]); totalE += ij_energy; - pair_energies[rotassign_dim0][i][j] = ij_energy; - pair_energies[rotassign_dim0][j][i] = ij_energy; + current_pair_energies[i][j] = ij_energy; + current_pair_energies[j][i] = ij_energy; } } return totalE; @@ -131,60 +159,93 @@ inline #endif float total_energy_for_assignment( - TView nrotamers_for_res, - TView oneb_offsets, - TView res_for_rot, - TView respair_nenergies, - TView chunk_size_t, - TView chunk_offset_offsets, - TView twob_offsets, - TView fine_chunk_offsets, - TView energy1b, - TView energy2b, - TView rotamer_assignment, - int rotassign_dim0 // i.e. thread_id + TensorAccessor n_rotamers_for_res, // max-n-res + TensorAccessor oneb_offsets, // max-n-res + int32_t const chunk_size, + TensorAccessor + chunk_offset_offsets, // max-n-res x max-n-res + TView chunk_offsets, // n-interacting-chunk-pairs + TView energy1b, // n-rotamers-total + TView energy2b, // n-interacting-rotamer-pairs + TensorAccessor + rotamer_assignment // local rotamer indices; max-n-res ) { + // Read the energies from energy1b and energy2b for the given + // rotamer_assignment (represented as the local indices for + // each rotamer on each block) + +#ifndef __CUDACC__ + using std::min; +#endif + int const n_res = n_rotamers_for_res.size(0); + + // std::cout << "total energy for assignment:" << std::endl; + // for (int i = 0; i < n_res; ++i) { + // if (i % 30 == 29) { + // std::cout << "\n"; + // } + // std::cout << std::setw(4) << rotamer_assignment[i]; + // } + // std::cout << std::endl; + + int count_out = 0; float totalE = 0; - int const nres = nrotamers_for_res.size(0); - int const chunk_size = chunk_size_t[0]; - for (int i = 0; i < nres; ++i) { - int const irot_local = rotamer_assignment[rotassign_dim0][i]; + for (int i = 0; i < n_res; ++i) { + int const irot_local = rotamer_assignment[i]; + + if (irot_local == -1) { + // unassigned rotamer or residue off the end for the Pose + continue; + 
} int const irot_global = irot_local + oneb_offsets[i]; - int const ires_nrots = nrotamers_for_res[i]; - int const ires_nchunks = (ires_nrots - 1) / chunk_size + 1; + int const ires_n_rots = n_rotamers_for_res[i]; + int const ires_n_chunks = (ires_n_rots - 1) / chunk_size + 1; int const irot_chunk = irot_local / chunk_size; int const irot_in_chunk = irot_local - chunk_size * irot_chunk; int const irot_chunk_size = - std::min(chunk_size, ires_nrots - chunk_size * irot_chunk); + min(chunk_size, ires_n_rots - chunk_size * irot_chunk); totalE += energy1b[irot_global]; - for (int j = i + 1; j < nres; ++j) { - int const jrot_local = rotamer_assignment[rotassign_dim0][j]; - if (respair_nenergies[i][j] == 0) { + for (int j = i + 1; j < n_res; ++j) { + int const jrot_local = rotamer_assignment[j]; + if (jrot_local == -1) { + continue; + } + int64_t const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; + if (ij_chunk_offset_offset == -1) { + // Then this pair of residues do not interact continue; } - int const jres_nrots = nrotamers_for_res[j]; - int const jres_nchunks = (jres_nrots - 1) / chunk_size + 1; + int const jres_n_rots = n_rotamers_for_res[j]; + int const jres_n_chunks = (jres_n_rots - 1) / chunk_size + 1; int const jrot_chunk = jrot_local / chunk_size; int const jrot_in_chunk = jrot_local - chunk_size * jrot_chunk; int const jrot_chunk_size = - std::min(chunk_size, jres_nrots - chunk_size * jrot_chunk); + min(chunk_size, jres_n_rots - chunk_size * jrot_chunk); - int const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; - int const ij_chunk_offset = fine_chunk_offsets - [ij_chunk_offset_offset + irot_chunk * jres_nchunks + jrot_chunk]; - if (ij_chunk_offset < 0) { + // int const ij_chunk_offset_offset = chunk_offset_offsets[i][j]; + int64_t const ij_chunk_offset = + (chunk_offsets + [ij_chunk_offset_offset + irot_chunk * jres_n_chunks + + jrot_chunk]); + if (ij_chunk_offset == -1) { continue; } - int64_t index = twob_offsets[i][j] + ij_chunk_offset - + 
irot_in_chunk * jrot_chunk_size + jrot_in_chunk; - - // std::cout << "twob index " << index << std::endl; - float ij_energy = energy2b[index]; + float const ij_energy = + (energy2b + [ij_chunk_offset + irot_in_chunk * jrot_chunk_size + + jrot_in_chunk]); + // ++count_out; + // if (count_out % 10 == 9) { + // std::cout << "\n"; + // } + // std::cout << std::setprecision(6) << std::setw(10) << ij_energy; totalE += ij_energy; } } + // std::cout << "\n" << totalE << std::endl; + // std::cout << std::endl; return totalE; } diff --git a/tmol/pack/datatypes.py b/tmol/pack/datatypes.py index b239c3cc7..78d08a889 100644 --- a/tmol/pack/datatypes.py +++ b/tmol/pack/datatypes.py @@ -8,16 +8,15 @@ @attr.s(auto_attribs=True, frozen=True) class PackerEnergyTables(TensorGroup, ConvertAttrs): - nrotamers_for_res: Tensor[torch.int32][:] # [nres] - # nrestype_groups_for_res: Tensor[torch.int32][:] # [nres] - oneb_offsets: Tensor[torch.int32][:] # [nres] - res_for_rot: Tensor[torch.int32][:] # [nrotamers_total] - # restype_group_for_rot: Tensor[torch.int32][:] # [nrotamers_total] - respair_nenergies: Tensor[torch.int32][:, :] # [nres x nres] - chunk_size: Tensor[torch.int32][:] # [ 1 ] - chunk_offset_offsets: Tensor[torch.int32][:, :] # [nres x nres] - twob_offsets: Tensor[torch.int64][:, :] # [nres x nres] - fine_chunk_offsets: Tensor[torch.int32][:] # - # twob_fine_offsets: Tensor[torch.int64][:] # [n_nonzero_submatrices] + max_n_rotamers_per_pose: int + pose_n_res: Tensor[torch.int32][:] # [n-poses] + pose_n_rotamers: Tensor[torch.int32][:] # [n-poses] + pose_rotamer_offset: Tensor[torch.int32][:] # [n-poses] + nrotamers_for_res: Tensor[torch.int32][:, :] # [n-poses x n-res] + oneb_offsets: Tensor[torch.int32][:, :] # [n-poses x n-res] + res_for_rot: Tensor[torch.int32][:] # [n-rotamers-total] + chunk_size: int + chunk_offset_offsets: Tensor[torch.int64][:, :, :] # [n-poses x n-res x n-res] + chunk_offsets: Tensor[torch.int64][:] # [n-interacting-chunk-pairs] energy1b: 
Tensor[torch.float32][:] # [nrotamers_total] energy2b: Tensor[torch.float32][:] # [ntwob_energies] diff --git a/tmol/pack/impose_rotamers.py b/tmol/pack/impose_rotamers.py new file mode 100644 index 000000000..2773ec262 --- /dev/null +++ b/tmol/pack/impose_rotamers.py @@ -0,0 +1,148 @@ +import torch + +from tmol.types.torch import Tensor +from tmol.pose.pose_stack import PoseStack +from tmol.pack.rotamer.build_rotamers import RotamerSet +from tmol.utility.cumsum import exclusive_cumsum2d_w_totals + + +def impose_top_rotamer_assignments( + orig_pose_stack: PoseStack, + rotamer_set: RotamerSet, + assignment: Tensor[torch.int32][:, :, :], +): + """Impose the lowest-energy rotamer assignment to each pose in the original PoseStack.""" + + # Going through PoseStack's data members; what will be new, what will be unchanged + # + # -- packed_block_types: unchanged + # + # -- coords: new -- the whole point! + # + # -- block_coord_offset: has to be updated because the number of atoms per residue may have + # changed + # + # -- inter_residue_connections: unchanged; the packer cannot change inter-block connections + # + # -- inter_block_bondsep: unchanged + # + # -- block_type_ind: new as the block types may have changed + + pbt = orig_pose_stack.packed_block_types + device = orig_pose_stack.device + n_poses = orig_pose_stack.n_poses + max_n_blocks = orig_pose_stack.max_n_blocks + max_n_atoms_per_block = orig_pose_stack.max_n_atoms + + # let's figure out how many atoms per pose + + new_block_type_ind64 = torch.full( + (n_poses, max_n_blocks), -1, dtype=torch.int64, device=device + ) + # rot_for_block = torch.zeros((n_poses, max_n_blocks), dtype=torch.int64, device=device) + new_rot_for_block64 = ( + assignment[:, 0, :].to(torch.int64) + rotamer_set.rot_offset_for_block + ) + + # print("New rot for block") + # print(new_rot_for_block64) + + is_real_block = orig_pose_stack.block_type_ind64 != -1 + + new_block_type_ind64[is_real_block] = rotamer_set.block_type_ind_for_rot[ + 
new_rot_for_block64[is_real_block] + ] + new_n_atoms_per_block32 = torch.zeros( + (n_poses, max_n_blocks), dtype=torch.int32, device=device + ) + new_n_atoms_per_block32[is_real_block] = pbt.n_atoms[ + new_block_type_ind64[is_real_block] + ] + new_n_atoms_per_block64 = new_n_atoms_per_block32.to(torch.int64) + + # get the per-pose offset for each block w/ exclusive cumsum on n-atoms-per-block + new_n_atoms_offset32, new_n_pose_atoms = exclusive_cumsum2d_w_totals( + new_n_atoms_per_block32 + ) + new_n_atoms_offset64 = new_n_atoms_offset32.to(torch.int64) + new_max_n_pose_atoms = int(torch.max(new_n_pose_atoms).item()) + + # okay, now let's prepare the indices for our copy operation + # let's think about it like this: we have a 3D tensor with i, j, k indices representing + # pose-ind, block-ind, and atom-ind. + # For the dst indices, we add a per-pose offset i * new_max_n_pose_atoms + # and we add a block-offset from new_n_atoms_offset64[j]. + # For the src indices, we take the rotamer assigned to pose-i-residue-j, + # and that rotamer gives us the offset into the rotamer_set.coords tensor. 
+ max_n_atoms_arange64 = torch.arange( + max_n_atoms_per_block, dtype=torch.int64, device=device + ) + max_n_atoms_arange64 = max_n_atoms_arange64.view(1, 1, -1).expand( + n_poses, max_n_blocks, max_n_atoms_per_block + ) + + pose_for_atom64 = torch.arange(n_poses, dtype=torch.int64, device=device) + pose_for_atom64 = pose_for_atom64.view(-1, 1, 1).expand( + n_poses, max_n_blocks, max_n_atoms_per_block + ) + + pose_offset_for_atom64 = ( + torch.arange(n_poses, dtype=torch.int64, device=device) * new_max_n_pose_atoms + ) + pose_offset_for_atom64 = pose_offset_for_atom64.view(-1, 1, 1).expand( + n_poses, max_n_blocks, max_n_atoms_per_block + ) + + block_for_atom64 = ( + torch.arange(max_n_blocks, dtype=torch.int64, device=device) + .view(1, -1, 1) + .expand(n_poses, max_n_blocks, max_n_atoms_per_block) + ) + + pose_coords1d_offset_for_atom64 = ( + new_n_atoms_offset64[pose_for_atom64, block_for_atom64] + pose_offset_for_atom64 + ) + + new_n_atoms_for_atoms_block64 = new_n_atoms_per_block64.unsqueeze(2).expand( + n_poses, max_n_blocks, max_n_atoms_per_block + ) + is_pose_atom_real = max_n_atoms_arange64 < new_n_atoms_for_atoms_block64 + + dst_inds = (pose_coords1d_offset_for_atom64 + max_n_atoms_arange64)[ + is_pose_atom_real + ] + + rot_coord_offset_for_block32 = torch.full( + (n_poses, max_n_blocks), -1, dtype=torch.int32, device=device + ) + rot_coord_offset_for_block32[is_real_block] = rotamer_set.coord_offset_for_rot[ + new_rot_for_block64[is_real_block] + ] + rot_coord_offset_for_block64 = rot_coord_offset_for_block32.to(torch.int64) + rot_coord_offset_for_atom64 = rot_coord_offset_for_block64.unsqueeze(2).expand( + n_poses, max_n_blocks, max_n_atoms_per_block + ) + src_inds = (rot_coord_offset_for_atom64 + max_n_atoms_arange64)[is_pose_atom_real] + + # now lets copy the coordinates + new_coords = torch.zeros( + (n_poses * new_max_n_pose_atoms, 3), dtype=torch.float32, device=device + ) + new_coords[dst_inds] = rotamer_set.coords[src_inds] + new_coords = 
new_coords.view(n_poses, new_max_n_pose_atoms, 3) + + # now construct the new PoseStack + new_pose_stack = PoseStack( + packed_block_types=pbt, + coords=new_coords, + block_coord_offset=new_n_atoms_offset32, + block_coord_offset64=new_n_atoms_offset64, + inter_residue_connections=orig_pose_stack.inter_residue_connections, + inter_residue_connections64=orig_pose_stack.inter_residue_connections64, + inter_block_bondsep=orig_pose_stack.inter_block_bondsep, + inter_block_bondsep64=orig_pose_stack.inter_block_bondsep64, + block_type_ind=new_block_type_ind64.to(torch.int32), + block_type_ind64=new_block_type_ind64, + device=device, + ) + return new_pose_stack diff --git a/tmol/pack/pack_rotamers.py b/tmol/pack/pack_rotamers.py new file mode 100644 index 000000000..d1ce9eb96 --- /dev/null +++ b/tmol/pack/pack_rotamers.py @@ -0,0 +1,61 @@ +import torch + +from tmol.pose.pose_stack import PoseStack +from tmol.score.score_function import ScoreFunction + +from tmol.pack.compiled.compiled import build_interaction_graph +from tmol.pack.packer_task import PackerTask +from tmol.pack.rotamer.build_rotamers import build_rotamers +from tmol.pack.datatypes import PackerEnergyTables +from tmol.pack.simulated_annealing import run_simulated_annealing +from tmol.pack.impose_rotamers import impose_top_rotamer_assignments + + +def pack_rotamers(pose_stack: PoseStack, sfxn: ScoreFunction, task: PackerTask): + pbt = pose_stack.packed_block_types + + pose_stack, rotamer_set = build_rotamers(pose_stack, task, pbt.chem_db) + + rotamer_scoring_module = sfxn.render_rotamer_scoring_module(pose_stack, rotamer_set) + + energies = rotamer_scoring_module(rotamer_set.coords) + energies = energies.coalesce() + + chunk_size = 16 + + (energy1b, chunk_pair_offset_for_block_pair, chunk_pair_offset, energy2b) = ( + build_interaction_graph( + chunk_size, + rotamer_set.n_rots_for_pose, + rotamer_set.rot_offset_for_pose, + rotamer_set.n_rots_for_block, + rotamer_set.rot_offset_for_block, + 
rotamer_set.pose_for_rot, + rotamer_set.block_type_ind_for_rot, + rotamer_set.block_ind_for_rot, + energies.indices().to(torch.int32), + energies.values(), + ) + ) + + packer_energy_tables = PackerEnergyTables( + max_n_rotamers_per_pose=rotamer_set.max_n_rots_per_pose, + pose_n_res=pose_stack.n_res_per_pose, + pose_n_rotamers=rotamer_set.n_rots_for_pose, + pose_rotamer_offset=rotamer_set.rot_offset_for_pose, + nrotamers_for_res=rotamer_set.n_rots_for_block, + oneb_offsets=rotamer_set.rot_offset_for_block, + res_for_rot=rotamer_set.block_ind_for_rot, + chunk_size=chunk_size, + chunk_offset_offsets=chunk_pair_offset_for_block_pair, + chunk_offsets=chunk_pair_offset, + energy1b=energy1b, + energy2b=energy2b, + ) + + scores, rotamer_assignments = run_simulated_annealing(packer_energy_tables) + new_pose_stack = impose_top_rotamer_assignments( + pose_stack, rotamer_set, rotamer_assignments + ) + + return new_pose_stack diff --git a/tmol/pack/packer_task.py b/tmol/pack/packer_task.py index b8b639805..ada17eefa 100644 --- a/tmol/pack/packer_task.py +++ b/tmol/pack/packer_task.py @@ -1,9 +1,12 @@ +import numpy + from tmol.chemical.restypes import RefinedResidueType, ResidueTypeSet from tmol.pose.pose_stack import PoseStack +from tmol.pack.rotamer.conformer_sampler import ConformerSampler from tmol.pack.rotamer.chi_sampler import ChiSampler -# Architecture is stolen from Rosetta3: +# Architecture is borrowed from Rosetta3: # PackerTask: a class holding data describing how the # packer should behave. Each position in the # PackerTask corresponds to a residue in the input @@ -16,6 +19,10 @@ # PackerPallete: a class that decides how to construct # a PackerTask, deciding which residue types to allow # based on the residue type of the input structure. 
+# Different PackerPalettes will construct different +# starting points which can be refined towards the +# set of design choices that make sense for your +# application def set_compare(x, y): @@ -34,96 +41,127 @@ class PackerPalette: def __init__(self, rts: ResidueTypeSet): self.rts = rts - def restypes_from_original(self, orig: RefinedResidueType): + def block_types_from_original(self, orig: RefinedResidueType): # ok, this is where we figure out what the allowed restypes # are for a residue; this might be complex logic. keepers = [] - for rt in self.rts.residue_types: + for bt in self.rts.residue_types: if ( - rt.properties.polymer.is_polymer == orig.properties.polymer.is_polymer - and rt.properties.polymer.polymer_type + bt.properties.polymer.is_polymer == orig.properties.polymer.is_polymer + and bt.properties.polymer.polymer_type == orig.properties.polymer.polymer_type - and rt.properties.polymer.backbone_type + and bt.properties.polymer.backbone_type == orig.properties.polymer.backbone_type - and rt.connections + and bt.connections == orig.connections # fd use this instead of terminal variant check and set_compare( - rt.properties.chemical_modifications, + bt.properties.chemical_modifications, orig.properties.chemical_modifications, ) and set_compare( - rt.properties.connectivity, orig.properties.connectivity + bt.properties.connectivity, orig.properties.connectivity ) - and rt.properties.protonation.protonation_state + and bt.properties.protonation.protonation_state == orig.properties.protonation.protonation_state ): if ( - rt.properties.polymer.sidechain_chirality + bt.properties.polymer.sidechain_chirality == orig.properties.polymer.sidechain_chirality ): - keepers.append(rt) + keepers.append(bt) elif orig.properties.polymer.polymer_type == "amino_acid" and ( ( orig.properties.polymer.sidechain_chirality == "l" - and rt.properties.polymer.sidechain_chirality == "achiral" + and bt.properties.polymer.sidechain_chirality == "achiral" ) or ( 
orig.properties.polymer.sidechain_chirality == "achiral" - and rt.properties.polymer.sidechain_chirality == "l" + and bt.properties.polymer.sidechain_chirality == "l" ) ): # allow glycine <--> l-caa mutations - keepers.append(rt) + keepers.append(bt) elif ( orig.properties.polymer.polymer_type == "amino_acid" and orig.properties.polymer.sidechain_chirality == "d" - and rt.properties.polymer.sidechain_chirality == "achiral" + and bt.properties.polymer.sidechain_chirality == "achiral" ): # allow d-caa --> glycine mutations; # dangerous because this packer pallete will allow # your d-caa to become glycine, and then later # to an l-caa, but not the other way around - keepers.append(rt) + keepers.append(bt) return keepers + def default_conformer_samplers(self, block_type): + """All positions must build one rotamer, even if they are not being optimized. + + Each block must have coordinates represented in the tensor with the other + rotamers, and the easiest way to do that is to create a rotamer with the + DOFs of the input conformation. The IncludeCurrentSampler copies these + DOFs from the inverse-folded coordinates of the starting Pose's blocks. + Future versions of PackerPalette have the option to override this method. 
+ """ + from tmol.pack.rotamer.include_current_sampler import ( + IncludeCurrentSampler, + ) + + return [IncludeCurrentSampler()] + -class ResidueLevelTask: +class BlockLevelTask: def __init__( - self, seqpos: int, restype: RefinedResidueType, palette: PackerPalette + self, seqpos: int, block_type: RefinedResidueType, palette: PackerPalette ): self.seqpos = seqpos - self.original_restype = restype - self.allowed_restypes = palette.restypes_from_original(restype) - self.chi_samplers = [] + self.original_block_type = block_type + self.considered_block_types = palette.block_types_from_original(block_type) + self.block_type_allowed = numpy.full( + len(self.considered_block_types), True, dtype=bool + ) + self.conformer_samplers = palette.default_conformer_samplers(block_type) + self.is_chi_sampler = [] + self.include_current = False + self.chi_expansion = numpy.zeros( + (len(self.considered_block_types), 4), dtype=numpy.int32 + ) def restrict_to_repacking(self): - orig = self.original_restype - self.allowed_restypes = [ - rt - for rt in self.allowed_restypes - if rt.name3 == orig.name3 # this isn't what we want long term - ] + orig = self.original_block_type + for i, bt in enumerate(self.considered_block_types): + if bt.name3 != orig.name3: + self.block_type_allowed[i] = False def disable_packing(self): - self.allowed_restypes = [] + # Note: we will always include at least one rotamer from every block + # in the RotamerSet, falling back on the coordinates of the starting + # block if this block-level-task is marked as kept fixed. 
+ self.block_type_allowed[:] = False - def add_chi_sampler(self, sampler: ChiSampler): - self.chi_samplers.append(sampler) + def add_conformer_sampler(self, sampler: ConformerSampler): + self.conformer_samplers.append(sampler) + self.is_chi_sampler.append(isinstance(sampler, ChiSampler)) def restrict_absent_name3s(self, name3s): - self.allowed_restypes = [ - rt for rt in self.allowed_restypes if rt.name3 in name3s - ] + for i, bt in enumerate(self.considered_block_types): + if bt.name3 not in name3s: + self.block_type_allowed[i] = False + + def or_expand_chi(self, chi_ind: int): + self.chi_expansion[:, chi_ind] = 1 + + def or_expand_chi_to(self, chi_ind: int, sample_level: int): + self.chi_expansion[:, chi_ind] = sample_level class PackerTask: def __init__(self, systems: PoseStack, palette: PackerPalette): - self.rlts = [ + self.blts = [ [ - ResidueLevelTask(j, systems.block_type(i, j), palette) + BlockLevelTask(j, systems.block_type(i, j), palette) for j in range(systems.max_n_blocks) if systems.is_real_block(i, j) ] @@ -131,11 +169,26 @@ def __init__(self, systems: PoseStack, palette: PackerPalette): ] def restrict_to_repacking(self): - for one_pose_rlts in self.rlts: - for rlt in one_pose_rlts: - rlt.restrict_to_repacking() - - def add_chi_sampler(self, sampler: ChiSampler): - for one_pose_rlts in self.rlts: - for rlt in one_pose_rlts: - rlt.add_chi_sampler(sampler) + for one_pose_blts in self.blts: + for blt in one_pose_blts: + blt.restrict_to_repacking() + + def add_conformer_sampler(self, sampler: ConformerSampler): + for one_pose_blts in self.blts: + for blt in one_pose_blts: + blt.add_conformer_sampler(sampler) + + def set_include_current(self): + for one_pose_blts in self.blts: + for blt in one_pose_blts: + blt.include_current = True + + def or_expand_chi(self, chi_ind: int): + for one_pose_blts in self.blts: + for blt in one_pose_blts: + blt.or_expand_chi(chi_ind) + + def or_expand_chi_to(self, chi_ind: int, sample_level: int): + for one_pose_blts in 
self.blts: + for blt in one_pose_blts: + blt.or_expand_chi_to(chi_ind, sample_level) diff --git a/tmol/pack/rotamer/bounding_spheres.py b/tmol/pack/rotamer/bounding_spheres.py index b622bb3e5..3b2f089f0 100644 --- a/tmol/pack/rotamer/bounding_spheres.py +++ b/tmol/pack/rotamer/bounding_spheres.py @@ -1,118 +1,118 @@ -import torch - -from tmol.utility.tensor.common_operations import stretch -from tmol.pose.pose_stack import PoseStack -from tmol.pack.rotamer.build_rotamers import RotamerSet - - -def create_rotamer_bounding_spheres(poses: PoseStack, rotamer_set: RotamerSet): - torch_device = poses.device - n_poses = poses.n_poses - max_n_blocks = poses.max_n_blocks - n_rots = rotamer_set.pose_for_rot.shape[0] - - bounding_spheres = torch.full( - (n_poses, max_n_blocks, 4), 0, dtype=torch.float32, device=torch_device - ) - - # what is the center of the smallest sphere that encloses all the rotamers? - # let's just take the center of mass for the rotamers - - global_block_ind_for_rot = ( - rotamer_set.pose_for_rot * max_n_blocks - + rotamer_set.block_ind_for_rot.to(torch.int64) - ) - max_n_block_atoms = poses.packed_block_types.max_n_atoms - centers_of_mass = torch.zeros( - (n_poses * max_n_blocks, 3), dtype=torch.float32, device=torch_device - ) - centers_of_mass.index_add_( - 0, - stretch(global_block_ind_for_rot, max_n_block_atoms), - rotamer_set.coords.reshape(-1, 3), - ) - n_ats_for_rot = poses.packed_block_types.n_atoms[rotamer_set.block_type_ind_for_rot] - n_ats = torch.zeros( - (n_poses * max_n_blocks,), dtype=torch.int32, device=torch_device - ) - n_ats.index_add_(0, global_block_ind_for_rot, n_ats_for_rot) - - centers_of_mass[n_ats != 0] = centers_of_mass[n_ats != 0] / n_ats[ - n_ats != 0 - ].unsqueeze(1).to(torch.float32) - # print("centers_of_mass[:10]") - # print(centers_of_mass[:10]) - at_is_real = torch.arange( - max_n_block_atoms, dtype=torch.int32, device=torch_device - ).repeat(n_rots).reshape(n_rots, max_n_block_atoms) < n_ats_for_rot.unsqueeze(dim=1) 
- diff_w_com = torch.zeros_like(rotamer_set.coords) - - diff_w_com[at_is_real] = ( - centers_of_mass[stretch(global_block_ind_for_rot, max_n_block_atoms)].reshape( - n_rots, max_n_block_atoms, 3 - )[at_is_real] - - rotamer_set.coords[at_is_real] - ) - atom_dist_to_com = torch.norm(diff_w_com, dim=2) - rot_dist_to_com = torch.max(atom_dist_to_com, dim=1)[0] - - # now I need to get the max for all the rotamers at a single position, and I - # don't know how to do that except 1) segmented scan on max in c++, or - # 2) create an overly-large tensor of n-poses x max-n-blocks x max-n-rots - # and then populate that tensor with rot_dist_to_com - max_n_rots_per_block = torch.max(rotamer_set.n_rots_for_block) - rot_dist_to_com_big = torch.zeros( - (n_poses * max_n_blocks, max_n_rots_per_block), - dtype=torch.float32, - device=torch_device, - ) - - rot_dist_to_com_big[ - global_block_ind_for_rot, - ( - torch.arange(n_rots, dtype=torch.int64, device=torch_device) - - rotamer_set.rot_offset_for_block.flatten()[global_block_ind_for_rot] - ), - ] = rot_dist_to_com - sphere_radius = torch.max(rot_dist_to_com_big, dim=1)[0] - bounding_spheres = torch.zeros( - (n_poses * max_n_blocks, 4), dtype=torch.float32, device=torch_device - ) - bounding_spheres[:, 3][n_ats != 0] = sphere_radius[n_ats != 0] - bounding_spheres[:, :3][n_ats != 0] = centers_of_mass[n_ats != 0] - - # load the coordinates for the poses into a 4D tensor out of the - # 3D tensor and then we can compute center of mass by just summing - # along the 3rd dimension - expanded_coords, real_expanded_pose_ats = poses.expand_coords() - - # and we can now sum along dimension 2 - background_centers_of_mass = torch.sum(expanded_coords, dim=2).reshape(-1, 3) - pbti = poses.block_type_ind.to(torch.int64).flatten() - background_n_ats = poses.n_ats_per_block.flatten() - - background_centers_of_mass[pbti != -1] = background_centers_of_mass[ - pbti != -1 - ] / background_n_ats[pbti != -1].unsqueeze(1).to(torch.float32) - 
pose_diff_w_com = torch.zeros_like(expanded_coords).reshape(-1, 3) - - pose_diff_w_com[real_expanded_pose_ats.view(-1)] = background_centers_of_mass[ - stretch( - torch.arange( - n_poses * max_n_blocks, dtype=torch.int64, device=torch_device - ), - max_n_block_atoms, - ) - ][real_expanded_pose_ats.view(-1)] - expanded_coords[real_expanded_pose_ats].view( - -1, 3 - ) - pose_diff_w_com = pose_diff_w_com.view(n_poses * max_n_blocks, max_n_block_atoms, 3) - pose_dist_to_com = torch.norm(pose_diff_w_com, dim=2) - pose_bounding_radius = torch.max(pose_dist_to_com, dim=1)[0] - - bounding_spheres[:, 3][n_ats == 0] = pose_bounding_radius[n_ats == 0] - bounding_spheres[:, :3][n_ats == 0] = background_centers_of_mass.reshape(-1, 3)[ - n_ats == 0 - ] - - return bounding_spheres.view(n_poses, max_n_blocks, 4) +# import torch +# +# from tmol.utility.tensor.common_operations import stretch +# from tmol.pose.pose_stack import PoseStack +# from tmol.pack.rotamer.build_rotamers import RotamerSet +# +# +# def create_rotamer_bounding_spheres(poses: PoseStack, rotamer_set: RotamerSet): +# torch_device = poses.device +# n_poses = poses.n_poses +# max_n_blocks = poses.max_n_blocks +# n_rots = rotamer_set.pose_for_rot.shape[0] +# +# bounding_spheres = torch.full( +# (n_poses, max_n_blocks, 4), 0, dtype=torch.float32, device=torch_device +# ) +# +# # what is the center of the smallest sphere that encloses all the rotamers? 
+# # let's just take the center of mass for the rotamers +# +# global_block_ind_for_rot = ( +# rotamer_set.pose_for_rot * max_n_blocks +# + rotamer_set.block_ind_for_rot.to(torch.int64) +# ) +# max_n_block_atoms = poses.packed_block_types.max_n_atoms +# centers_of_mass = torch.zeros( +# (n_poses * max_n_blocks, 3), dtype=torch.float32, device=torch_device +# ) +# centers_of_mass.index_add_( +# 0, +# stretch(global_block_ind_for_rot, max_n_block_atoms), +# rotamer_set.coords.reshape(-1, 3), +# ) +# n_ats_for_rot = poses.packed_block_types.n_atoms[rotamer_set.block_type_ind_for_rot] +# n_ats = torch.zeros( +# (n_poses * max_n_blocks,), dtype=torch.int32, device=torch_device +# ) +# n_ats.index_add_(0, global_block_ind_for_rot, n_ats_for_rot) +# +# centers_of_mass[n_ats != 0] = centers_of_mass[n_ats != 0] / n_ats[ +# n_ats != 0 +# ].unsqueeze(1).to(torch.float32) +# # print("centers_of_mass[:10]") +# # print(centers_of_mass[:10]) +# at_is_real = torch.arange( +# max_n_block_atoms, dtype=torch.int32, device=torch_device +# ).repeat(n_rots).reshape(n_rots, max_n_block_atoms) < n_ats_for_rot.unsqueeze(dim=1) +# diff_w_com = torch.zeros_like(rotamer_set.coords) +# +# diff_w_com[at_is_real] = ( +# centers_of_mass[stretch(global_block_ind_for_rot, max_n_block_atoms)].reshape( +# n_rots, max_n_block_atoms, 3 +# )[at_is_real] +# - rotamer_set.coords[at_is_real] +# ) +# atom_dist_to_com = torch.norm(diff_w_com, dim=2) +# rot_dist_to_com = torch.max(atom_dist_to_com, dim=1)[0] +# +# # now I need to get the max for all the rotamers at a single position, and I +# # don't know how to do that except 1) segmented scan on max in c++, or +# # 2) create an overly-large tensor of n-poses x max-n-blocks x max-n-rots +# # and then populate that tensor with rot_dist_to_com +# max_n_rots_per_block = torch.max(rotamer_set.n_rots_for_block) +# rot_dist_to_com_big = torch.zeros( +# (n_poses * max_n_blocks, max_n_rots_per_block), +# dtype=torch.float32, +# device=torch_device, +# ) +# +# 
rot_dist_to_com_big[ +# global_block_ind_for_rot, +# ( +# torch.arange(n_rots, dtype=torch.int64, device=torch_device) +# - rotamer_set.rot_offset_for_block.flatten()[global_block_ind_for_rot] +# ), +# ] = rot_dist_to_com +# sphere_radius = torch.max(rot_dist_to_com_big, dim=1)[0] +# bounding_spheres = torch.zeros( +# (n_poses * max_n_blocks, 4), dtype=torch.float32, device=torch_device +# ) +# bounding_spheres[:, 3][n_ats != 0] = sphere_radius[n_ats != 0] +# bounding_spheres[:, :3][n_ats != 0] = centers_of_mass[n_ats != 0] +# +# # load the coordinates for the poses into a 4D tensor out of the +# # 3D tensor and then we can compute center of mass by just summing +# # along the 3rd dimension +# expanded_coords, real_expanded_pose_ats = poses.expand_coords() +# +# # and we can now sum along dimension 2 +# background_centers_of_mass = torch.sum(expanded_coords, dim=2).reshape(-1, 3) +# pbti = poses.block_type_ind.to(torch.int64).flatten() +# background_n_ats = poses.n_ats_per_block.flatten() +# +# background_centers_of_mass[pbti != -1] = background_centers_of_mass[ +# pbti != -1 +# ] / background_n_ats[pbti != -1].unsqueeze(1).to(torch.float32) +# pose_diff_w_com = torch.zeros_like(expanded_coords).reshape(-1, 3) +# +# pose_diff_w_com[real_expanded_pose_ats.view(-1)] = background_centers_of_mass[ +# stretch( +# torch.arange( +# n_poses * max_n_blocks, dtype=torch.int64, device=torch_device +# ), +# max_n_block_atoms, +# ) +# ][real_expanded_pose_ats.view(-1)] - expanded_coords[real_expanded_pose_ats].view( +# -1, 3 +# ) +# pose_diff_w_com = pose_diff_w_com.view(n_poses * max_n_blocks, max_n_block_atoms, 3) +# pose_dist_to_com = torch.norm(pose_diff_w_com, dim=2) +# pose_bounding_radius = torch.max(pose_dist_to_com, dim=1)[0] +# +# bounding_spheres[:, 3][n_ats == 0] = pose_bounding_radius[n_ats == 0] +# bounding_spheres[:, :3][n_ats == 0] = background_centers_of_mass.reshape(-1, 3)[ +# n_ats == 0 +# ] +# +# return bounding_spheres.view(n_poses, max_n_blocks, 4) diff 
--git a/tmol/pack/rotamer/build_rotamers.py b/tmol/pack/rotamer/build_rotamers.py index 307e46779..792674e3a 100644 --- a/tmol/pack/rotamer/build_rotamers.py +++ b/tmol/pack/rotamer/build_rotamers.py @@ -4,7 +4,7 @@ import torch import attr -from typing import Tuple +from typing import List, Tuple from tmol.types.array import NDArray from tmol.types.attrs import ValidateAttrs @@ -45,26 +45,47 @@ class RotamerSet(ValidateAttrs): pose_for_rot: Tensor[torch.int64][:] block_type_ind_for_rot: Tensor[torch.int64][:] block_ind_for_rot: Tensor[torch.int32][:] - coords: Tensor[torch.float32][:, :, :] + coord_offset_for_rot: Tensor[torch.int32][:] + coords: Tensor[torch.float32][:, 3] + first_rot_block_type: Tensor[torch.int64][:, :] = attr.ib() -# from tmol.system.restype import RefinedResidueType + @first_rot_block_type.default + def _block_type_for_first_rot_for_block(self): + block_type_for_first_rot_for_block = torch.full_like( + self.rot_offset_for_block, -1 + ) + does_block_type_have_rots = self.n_rots_for_block != 0 + block_type_for_first_rot_for_block[does_block_type_have_rots] = ( + self.block_type_ind_for_rot[ + self.rot_offset_for_block[does_block_type_have_rots] + ] + ) + return block_type_for_first_rot_for_block + + max_n_rots_per_pose: int = attr.ib() + + @max_n_rots_per_pose.default + def _max_n_rots_per_pose(self): + return int(torch.max(self.n_rots_for_pose).cpu().item()) + pose_ind_for_atom: Tensor[torch.int64][:] = attr.ib() -# step 1: let the dunbrack library annotate the residue types -# step 2: let the dunbrack library annotate the condensed residue types -# step 3: flatten poses -# step 4: use the chi sampler to get the chi samples for all poses -# step 5: count the number of rotamers per pose -# step 5a: including rotamers that the dunbrack sampler does not provide (e.g. 
gly) -# step 6: allocate a n_poses x max_n_rotamers x max_n_atoms x 3 tensor -# step 7: create (n_poses * max_n_rotamers * max_n_atoms) x 3 view of coord tensor -# step 8: create parent indexing based on start-position offset + residue-type tree data -# step 9: build kinforest -# step 10: take starting coordinates from residue roots -# step 10a: take internal dofs from mainchain atoms -# step 10b: take internal dofs for other atoms from rt icoors -# step 11: refold + @pose_ind_for_atom.default + def _pose_ind_for_atom(self): + n_atoms = self.coords.shape[0] + pifa = torch.zeros((n_atoms,), dtype=torch.int64, device=self.coords.device) + # mark the first atom for the first rotamer in each pose after pose 0 + pifa[self.coord_offset_for_rot[self.rot_offset_for_pose[1:]]] = 1 + pifa = torch.cumsum(pifa, dim=0) + return pifa + + @property + def n_rotamers_total(self): + return self.block_ind_for_rot.shape[0] + + +# from tmol.system.restype import RefinedResidueType # residue type annotations: @@ -108,23 +129,23 @@ def rebuild_poses_if_necessary( all_restypes = {} samplers = set([]) - for one_pose_rlts in task.rlts: - for rlt in one_pose_rlts: - for sampler in rlt.chi_samplers: + for one_pose_blts in task.blts: + for blt in one_pose_blts: + for sampler in blt.conformer_samplers: samplers.add(sampler) - for rt in rlt.allowed_restypes: - if id(rt) not in all_restypes: - all_restypes[id(rt)] = rt + for bt in blt.considered_block_types: + if id(bt) not in all_restypes: + all_restypes[id(bt)] = bt samplers = tuple(samplers) # rebuild the poses, perhaps, if there are residue types in the task # that are absent from the poses' PBT - pose_rts = set([id(rt) for rt in poses.packed_block_types.active_block_types]) + pose_rts = set([id(bt) for bt in poses.packed_block_types.active_block_types]) needs_rebuilding = False - for rt_id in all_restypes: - if rt_id not in pose_rts: + for bt_id in all_restypes: + if bt_id not in pose_rts: needs_rebuilding = True break @@ -134,14 +155,14 @@ 
def rebuild_poses_if_necessary( for j in range(poses.max_n_blocks): if not poses.is_real_block(i, j): continue - rt = poses.block_type(i, j) - if id(rt) not in all_restypes: - all_restypes[id(rt)] = rt + bt = poses.block_type(i, j) + if id(bt) not in all_restypes: + all_restypes[id(bt)] = bt pbt = PackedBlockTypes.from_restype_list( poses.packed_block_types.chem_db, poses.packed_block_types.restype_set, - [rt for rt_id, rt in all_restypes.items()], + [bt for bt_id, bt in all_restypes.items()], poses.packed_block_types.device, ) @@ -228,14 +249,14 @@ def update_scan_starts( @validate_args -def construct_scans_for_rotamers( +def construct_scans_for_conformers( pbt: PackedBlockTypes, - block_type_ind_for_rot: NDArray[numpy.int32][:], - n_atoms_for_rot: Tensor[torch.int32][:], - n_atoms_offset_for_rot: NDArray[numpy.int64][:], + block_type_ind_for_conf: NDArray[numpy.int32][:], + n_atoms_for_conf: Tensor[torch.int32][:], + n_atoms_offset_for_conf: NDArray[numpy.int64][:], ): - scanStartsStack = pbt.rotamer_kinforest.scans[block_type_ind_for_rot] - genStartsStack = pbt.rotamer_kinforest.gens[block_type_ind_for_rot] + scanStartsStack = pbt.rotamer_kinforest.scans[block_type_ind_for_conf] + genStartsStack = pbt.rotamer_kinforest.gens[block_type_ind_for_conf] atomStartsStack = numpy.swapaxes(genStartsStack[:, :, 0], 0, 1) natomsPerGen = atomStartsStack[1:, :] - atomStartsStack[:-1, :] @@ -251,7 +272,7 @@ def construct_scans_for_rotamers( ) ngenStack = numpy.swapaxes( - pbt.rotamer_kinforest.n_scans_per_gen[block_type_ind_for_rot], 0, 1 + pbt.rotamer_kinforest.n_scans_per_gen[block_type_ind_for_conf], 0, 1 ) ngenStack[ngenStack < 0] = 0 ngenStackCumsum = numpy.cumsum(ngenStack.reshape(-1), axis=0) @@ -266,15 +287,21 @@ def construct_scans_for_rotamers( ngenStack, ) - nodes_orig = pbt.rotamer_kinforest.nodes[block_type_ind_for_rot].ravel() + nodes_orig = pbt.rotamer_kinforest.nodes[block_type_ind_for_conf].ravel() nodes_orig = nodes_orig[nodes_orig >= 0] + # 
print("nodes_orig") + # print(nodes_orig) - n_nodes_for_rot = pbt.rotamer_kinforest.n_nodes[block_type_ind_for_rot] - first_node_for_rot = numpy.cumsum(n_nodes_for_rot) - n_nodes_offset_for_rot = exc_cumsum_from_inc_cumsum(first_node_for_rot) + n_nodes_for_conf = pbt.rotamer_kinforest.n_nodes[block_type_ind_for_conf] + first_node_for_conf = numpy.cumsum(n_nodes_for_conf) + n_nodes_offset_for_conf = exc_cumsum_from_inc_cumsum(first_node_for_conf) + # print("n_nodes_offset_for_rot") + # print(n_nodes_offset_for_rot) + # print("n_atoms_offset_for_rot") + # print(n_atoms_offset_for_rot) nodes = update_nodes( - nodes_orig, genStartsStack, n_nodes_offset_for_rot, n_atoms_offset_for_rot + nodes_orig, genStartsStack, n_nodes_offset_for_conf, n_atoms_offset_for_conf ) gen_starts = numpy.sum(genStartsStack, axis=0) @@ -335,69 +362,75 @@ def load_rotamer_parents( @validate_args -def construct_kinforest_for_rotamers( +def construct_kinforest_for_conformers( pbt: PackedBlockTypes, - rot_block_type_ind: NDArray[numpy.int32][:], + conf_block_type_ind: NDArray[numpy.int32][:], n_atoms_total: int, - n_atoms_for_rot: Tensor[torch.int32][:], - block_offset_for_rot: NDArray[numpy.int32][:], + n_atoms_for_conf: Tensor[torch.int32][:], + block_offset_for_conf: NDArray[numpy.int64][:], device: torch.device, ): - """Construct a KinForest for a set of rotamers by stringing - together the kinforest data for individual rotamers. - The "block_ofset_for_rot" array is used to construct + """Construct a KinForest for a set of conformers by stringing + together the kinforest data for individual conformers. + The "block_ofset_for_conf" array is used to construct the "id" tensor in the KinForest, which maps to the atom indices; thus it should contain the atom-index offsets for the first atom in each rotamer in the coords tensor that will be used to construct the kinforest_coords tensor. 
""" - n_atoms_for_rot = n_atoms_for_rot.cpu().numpy() + n_atoms_for_conf = n_atoms_for_conf.cpu().numpy() # append a 1 for the root node and then treat # the resulting (inclusive) scan as if it # represents offsets - temp = numpy.concatenate((numpy.ones(1, dtype=numpy.int32), n_atoms_for_rot)) - n_atoms_offset_for_rot = numpy.cumsum(temp) + temp = numpy.concatenate((numpy.ones(1, dtype=numpy.int32), n_atoms_for_conf)) + n_atoms_offset_for_conf = numpy.cumsum(temp) def nab(func, arr): - return func(arr[rot_block_type_ind], n_atoms_total, n_atoms_for_rot) + return func(arr[conf_block_type_ind], n_atoms_total, n_atoms_for_conf) - def nab2(func, arr, rot_offset): - return func(arr[rot_block_type_ind], n_atoms_total, n_atoms_for_rot, rot_offset) + def nab2(func, arr, conf_offset): + return func( + arr[conf_block_type_ind], n_atoms_total, n_atoms_for_conf, conf_offset + ) def _t(arr): return torch.tensor(arr, dtype=torch.int32, device=device) id = _t( nab2( - load_from_rotamers_w_offsets, pbt.rotamer_kinforest.id, block_offset_for_rot + load_from_rotamers_w_offsets, + pbt.rotamer_kinforest.id, + block_offset_for_conf, ) ) id[0] = -1 doftype = _t(nab(load_from_rotamers, pbt.rotamer_kinforest.doftype)) parent = _t( - nab2(load_rotamer_parents, pbt.rotamer_kinforest.parent, n_atoms_offset_for_rot) + nab2( + load_rotamer_parents, pbt.rotamer_kinforest.parent, n_atoms_offset_for_conf + ) ) frame_x = _t( nab2( load_from_rotamers_w_offsets, pbt.rotamer_kinforest.frame_x, - n_atoms_offset_for_rot, + n_atoms_offset_for_conf, ) ) frame_y = _t( nab2( load_from_rotamers_w_offsets, pbt.rotamer_kinforest.frame_y, - n_atoms_offset_for_rot, + n_atoms_offset_for_conf, ) ) frame_z = _t( nab2( load_from_rotamers_w_offsets, pbt.rotamer_kinforest.frame_z, - n_atoms_offset_for_rot, + n_atoms_offset_for_conf, ) ) @@ -416,6 +449,11 @@ def measure_dofs_from_orig_coords( ): from tmol.kinematics.compiled.compiled_inverse_kin import inverse_kin + # print("coords") + # print(coords.shape) + # 
print("kinforest.id") + # print(kinforest.id) + kinforest_coords = coords.view(-1, 3)[kinforest.id.to(torch.int64)] kinforest_coords[0, :] = 0 # reset root @@ -437,27 +475,21 @@ def measure_pose_dofs(poses): pbt = poses.packed_block_types pbti = poses.block_type_ind.view(-1) - orig_res_block_type_ind = pbti[pbti != -1] real_poses_blocks = pbti != -1 + orig_res_block_type_ind = pbti[real_poses_blocks] - # old coordinate layout: n-poses x max-n-res x max-n-atoms x 3 - # nz_real_poses_blocks = torch.nonzero(real_poses_blocks).flatten() - # orig_atom_offset_for_poses_blocks = ( - # nz_real_poses_blocks.cpu().numpy().astype(numpy.int32) * pbt.max_n_atoms - # ) - - # new coordinate layout: n-poses x max-n-atoms-per-pose x 3 + # coordinate layout: n-poses x max-n-atoms-per-pose x 3 # offsets provided by the pose stack n_poses = poses.coords.shape[0] - max_n_atoms_per_pose = poses.coords.shape[1] - max_n_blocks_per_pose = poses.block_coord_offset.shape[1] + max_n_atoms_per_pose = poses.max_n_pose_atoms + max_n_blocks_per_pose = poses.max_n_blocks per_pose_offset = max_n_atoms_per_pose * stretch( - torch.arange(n_poses, dtype=torch.int32, device=poses.device), + torch.arange(n_poses, dtype=torch.int64, device=poses.device), max_n_blocks_per_pose, ) orig_atom_offset_for_poses_blocks = ( ( - poses.block_coord_offset.flatten()[real_poses_blocks] + poses.block_coord_offset.flatten()[real_poses_blocks].to(torch.int64) + per_pose_offset[real_poses_blocks] ) .cpu() @@ -469,7 +501,8 @@ def measure_pose_dofs(poses): n_atoms_offset_for_orig = n_atoms_offset_for_orig.cpu().numpy() n_orig_atoms_total = n_atoms_offset_for_orig[-1] - orig_kinforest = construct_kinforest_for_rotamers( + # print("orig_atom_offset_for_poses_blocks", orig_atom_offset_for_poses_blocks.dtype) + orig_kinforest = construct_kinforest_for_conformers( poses.packed_block_types, orig_res_block_type_ind.cpu().numpy(), int(n_orig_atoms_total), @@ -479,10 +512,29 @@ def measure_pose_dofs(poses): ) # orig_dofs returned 
in kinforest order - return measure_dofs_from_orig_coords(poses.coords, orig_kinforest) - - -def merge_chi_samples(chi_samples): + return orig_kinforest, measure_dofs_from_orig_coords(poses.coords, orig_kinforest) + + +def merge_conformer_samples( + conformer_samples, +) -> Tuple[ + Tensor[torch.int64][:], + Tensor[torch.int64][:], + Tensor[torch.int64][:], + List[Tensor[torch.bool][:]], + List[Tensor[torch.int64][:]], +]: + """Merge the lists of conformers as described by different conformer samplers. + + The conformer_samples variable is a list of tuples: + - elem 0: Tensor[int][:] <-- the number of rotamers for each pose for each block for each block type + where each buildable block type for each real residue is given a global index + - elem 1: Tensor[int][:] <-- the global block-type index for each rotamer + - elem 2+: Extra data that the chi sampler needs to preserve, where the first dimension + is rotamer index based on elem 1's rotamer indices; the mapping from orig rotamer indices + to merged rotamer indices will be constructed by this routine + """ + # deprecated notes: # chi_samples # 0. n_rots_for_rt # 1. rt_for_rotamer @@ -490,343 +542,126 @@ def merge_chi_samples(chi_samples): # 3. 
chi_for_rotamers # everything needs to be on the same device - for samples in chi_samples: - for i in range(1, len(samples)): - assert samples[0].device == samples[i].device - assert chi_samples[0][0].device == samples[0].device - - device = chi_samples[0][0].device - - rt_nrot_offsets = [] - for samples in chi_samples: - rt_nrot_offsets.append(exclusive_cumsum1d(samples[0]).to(torch.int64)) - - all_rt_for_rotamer_unsorted = torch.cat([samples[1] for samples in chi_samples]) - n_rotamers = all_rt_for_rotamer_unsorted.shape[0] - max_n_rotamers_per_rt = max(torch.max(samples[0]).item() for samples in chi_samples) - - for i, samples in enumerate(chi_samples): - rot_counter_for_rt = ( - torch.arange(samples[1].shape[0], dtype=torch.int64, device=device) - - rt_nrot_offsets[i][samples[1].to(torch.int64)] - ) - numpy.testing.assert_array_less( - rot_counter_for_rt.cpu().numpy(), max_n_rotamers_per_rt - ) + torch.set_printoptions(threshold=10000) + for samples in conformer_samples: + assert samples[0].device == samples[1].device + assert conformer_samples[0][0].device == samples[0].device + # print("samples", samples[0].shape, samples[1].shape) + # print("samples[0]") + # print(samples[0]) + # print("samples[1]") + # print(samples[1]) + + device = conformer_samples[0][0].device + + # pre-merge offsets for each gbt in the set of conformers from the same sampler + gbt_n_rot_offsets = [] # formerly rt_nrot_offsets + for samples in conformer_samples: + gbt_n_rot_offsets.append(exclusive_cumsum1d(samples[0]).to(torch.int64)) + + all_gbt_for_conformer_unsorted = torch.cat( + [samples[1] for samples in conformer_samples] + ) + max_n_conformers_per_gbt_per_sampler = max( + torch.max(samples[0]).item() for samples in conformer_samples + ) - sort_rt_for_rotamer = torch.cat( + # for i, samples in enumerate(conformer_samples): + # conf_counter_for_gbt = ( + # torch.arange(samples[1].shape[0], dtype=torch.int64, device=device) + # - rt_nrot_offsets[i][samples[1].to(torch.int64)] + # ) 
+ # numpy.testing.assert_array_less( + # rot_counter_for_rt.cpu().numpy(), max_n_rotamers_per_rt + # ) + + # create an "index" for each conformer on each GBT + # so that we can sort these indices and come up with an ordering + # of all of the conformers that will group all of the conformers + # belonging to a single GBT into a contiguous segment; + # This is accomplished by "spreading out" all of the rotamers for a single GBT + # by the maximum possible number of rotamers that could be built for any one + # GBT (i.e. n-samplers x max-n-confs-per-gbt-per-sampler x gbt-index), + # then finding which block of conformers for the given sampler + # (i.e. max-n-confs-per-gbt-per-sampler * sampler-index), + # and finally, incrementing each individual sample by its position in the + # list of rotamers for that GBT, which is readily computed as + # arange(sampler_n_rots) - gbt_n_rot_offsets[gbt_index] + # and note that gbt_index is what's stored in samples[1] + n_conformer_samplers = len(conformer_samples) + sort_index_for_conformer = torch.cat( [ - samples[1].to(torch.int64) * len(chi_samples) * max_n_rotamers_per_rt - + i * max_n_rotamers_per_rt + samples[1].to(torch.int64) + * n_conformer_samplers + * max_n_conformers_per_gbt_per_sampler + + i * max_n_conformers_per_gbt_per_sampler + torch.arange(samples[1].shape[0], dtype=torch.int64, device=device) - - rt_nrot_offsets[i][samples[1].to(torch.int64)] - for i, samples in enumerate(chi_samples) + - gbt_n_rot_offsets[i][samples[1].to(torch.int64)] + for i, samples in enumerate(conformer_samples) ] ) - sampler_for_rotamer_unsorted = torch.cat( + # temp + # torch.set_printoptions(threshold=10000) + # print("sort_index_for_conformer") + # print(sort_index_for_conformer) + + sampler_for_conformer_unsorted = torch.cat( [ torch.full((samples[1].shape[0],), i, dtype=torch.int64, device=device) - for i, samples in enumerate(chi_samples) + for i, samples in enumerate(conformer_samples) ] ) - sort_ind_for_rotamer = 
torch.argsort(sort_rt_for_rotamer) - sort_rt_for_rotamer_sorted = sort_rt_for_rotamer[sort_ind_for_rotamer] - uniq_sort_rt_for_rotamer = torch.unique(sort_rt_for_rotamer_sorted) - - assert uniq_sort_rt_for_rotamer.shape[0] == sort_rt_for_rotamer_sorted.shape[0] - - sampler_for_rotamer = sampler_for_rotamer_unsorted[sort_ind_for_rotamer] - - all_rt_for_rotamer = torch.cat([samples[1] for samples in chi_samples])[ - sort_ind_for_rotamer - ] - - max_n_chi_atoms = max(samples[2].shape[1] for samples in chi_samples) - all_chi_atoms = torch.full( - (n_rotamers, max_n_chi_atoms), -1, dtype=torch.int32, device=device - ) - all_chi = torch.full( - (n_rotamers, max_n_chi_atoms), -1, dtype=torch.float32, device=device - ) - offset = 0 - for samples in chi_samples: - assert samples[2].shape[0] == samples[3].shape[0] - all_chi_atoms[ - offset : (offset + samples[2].shape[0]), : samples[2].shape[1] - ] = samples[2] - all_chi[offset : (offset + samples[2].shape[0]), : samples[3].shape[1]] = ( - samples[3] - ) - offset += samples[2].shape[0] - - all_chi_atoms = all_chi_atoms[sort_ind_for_rotamer] - all_chi = all_chi[sort_ind_for_rotamer] - - # ok, now we need to figure out how many rotamers each rt is getting. 
- n_rots_for_rt = toolz.reduce(torch.add, [samples[0] for samples in chi_samples]) - - return ( - n_rots_for_rt, - sampler_for_rotamer, - all_rt_for_rotamer, - all_chi_atoms, - all_chi, - ) - - -@validate_args -def create_dof_inds_to_copy_from_orig_to_rotamers( - poses: PoseStack, - task: PackerTask, - samplers, # : Tuple[ChiSampler, ...], - rt_for_rot: Tensor[torch.int64][:], - block_type_ind_for_rot: Tensor[torch.int64][:], - sampler_for_rotamer: Tensor[torch.int64][:], - n_dof_atoms_offset_for_rot: Tensor[torch.int32][:], -) -> Tuple[Tensor[torch.int64][:], Tensor[torch.int64][:]]: - # we want to copy from the orig_dofs tensor into the - # rot_dofs tensor for the "mainchain" atoms in the - # original residues into the appropriate positions - # for the rotamers thta we are building at those - # residues. This requires a good deal of reindexing. - - pbt = poses.packed_block_types - n_rots = n_dof_atoms_offset_for_rot.shape[0] - - sampler_ind_mapping = torch.tensor( - [ - ( - pbt.mc_fingerprints.sampler_mapping[sampler.sampler_name()] - if sampler.sampler_name() in pbt.mc_fingerprints.sampler_mapping - else -1 - ) - for sampler in samplers - ], - dtype=torch.int64, - device=poses.device, - ) - - sampler_ind_for_rot = sampler_ind_mapping[sampler_for_rotamer] - orig_block_type_ind = ( - poses.block_type_ind[poses.block_type_ind != -1].view(-1).to(torch.int64) - ) - - poses_res_to_real_poses_res = torch.full( - (poses.block_type_ind.shape[0] * poses.block_type_ind.shape[1],), - -1, - dtype=torch.int64, - device=poses.device, - ) - poses_res_to_real_poses_res[poses.block_type_ind.view(-1) != -1] = torch.arange( - orig_block_type_ind.shape[0], dtype=torch.int64, device=poses.device - ) - - # get the residue index for each rotamer - max_n_blocks = poses.block_coord_offset.shape[1] - res_ind_for_rt = torch.tensor( - [ - i * max_n_blocks + j - for i, one_pose_rlts in enumerate(task.rlts) - for j, rlt in enumerate(one_pose_rlts) - for _ in rlt.allowed_restypes - ], - 
dtype=torch.int64, - device=poses.device, - ) - real_res_ind_for_rot = poses_res_to_real_poses_res[res_ind_for_rt[rt_for_rot]] - - # look up which mainchain fingerprint each - # original residue should use - - mcfp = pbt.mc_fingerprints - - sampler_ind_for_orig = mcfp.max_sampler[orig_block_type_ind] - orig_res_mcfp = mcfp.max_fingerprint[orig_block_type_ind] - orig_res_mcfp_for_rot = orig_res_mcfp[real_res_ind_for_rot] - - # now lets find the kinforest-ordered indices of the - # mainchain atoms for the rotamers that represents - # the destination for the dofs we're copying - max_n_mcfp_atoms = mcfp.atom_mapping.shape[3] - - rot_mcfp_at_inds_rto = mcfp.atom_mapping[ - sampler_ind_for_rot, orig_res_mcfp_for_rot, block_type_ind_for_rot, : - ].view(-1) - - real_rot_mcfp_at_inds_rto = rot_mcfp_at_inds_rto[rot_mcfp_at_inds_rto != -1] - - real_rot_block_type_ind_for_mcfp_ats = stretch( - block_type_ind_for_rot, max_n_mcfp_atoms - )[rot_mcfp_at_inds_rto != -1] - - rot_mcfp_at_inds_kto = torch.full_like(rot_mcfp_at_inds_rto, -1) - rot_mcfp_at_inds_kto[rot_mcfp_at_inds_rto != -1] = torch.tensor( - pbt.rotamer_kinforest.kinforest_idx[ - real_rot_block_type_ind_for_mcfp_ats.cpu().numpy(), - real_rot_mcfp_at_inds_rto.cpu().numpy(), - ], - dtype=torch.int64, - device=pbt.device, - ) - - rot_mcfp_at_inds_kto[rot_mcfp_at_inds_kto != -1] += n_dof_atoms_offset_for_rot[ - torch.div( - torch.arange( - n_rots * max_n_mcfp_atoms, dtype=torch.int64, device=poses.device - ), - max_n_mcfp_atoms, - rounding_mode="trunc", - )[rot_mcfp_at_inds_kto != -1] - ].to(torch.int64) - - # now get the indices in the orig_dofs array for the atoms to copy from. - # The steps: - # 1. get the mainchain atom indices for each of the original residues - # in residue-type order (rto) - # 2. sample 1. for each rotamer - # 3. find the real subset of these atoms - # 4. note the residue index for each of these real atoms - # 5. remap these to kinforest order (kto) - # 6. 
increment the indices with the original-residue dof-index offsets - - # orig_mcfp_at_inds_for_orig_rto: - # 1. these are the mainchain fingerprint atoms from the original - # residues on the pose - # 2. they are stored in residue-type order (rto) - # 3. they are indexed by original residue index - - orig_mcfp_at_inds_rto = mcfp.atom_mapping[ - sampler_ind_for_orig, orig_res_mcfp, orig_block_type_ind, : - ].view(-1) - - real_orig_block_type_ind_for_orig_mcfp_ats = stretch( - orig_block_type_ind, max_n_mcfp_atoms - )[orig_mcfp_at_inds_rto != -1] - - orig_dof_atom_offset = exclusive_cumsum1d(pbt.n_atoms[orig_block_type_ind]).to( - torch.int64 - ) - - orig_mcfp_at_inds_kto = torch.full_like(orig_mcfp_at_inds_rto, -1) - orig_mcfp_at_inds_kto[orig_mcfp_at_inds_rto != -1] = ( - torch.tensor( - pbt.rotamer_kinforest.kinforest_idx[ - real_orig_block_type_ind_for_orig_mcfp_ats.cpu().numpy(), - orig_mcfp_at_inds_rto[orig_mcfp_at_inds_rto != -1].cpu().numpy(), - ], - dtype=torch.int64, - device=pbt.device, - ) - + orig_dof_atom_offset[ - torch.floor_divide( - torch.arange( - orig_block_type_ind.shape[0] * max_n_mcfp_atoms, - dtype=torch.int64, - device=pbt.device, - ), - max_n_mcfp_atoms, - ) - ][orig_mcfp_at_inds_rto != -1] - ) - - orig_mcfp_at_inds_kto = orig_mcfp_at_inds_kto.view( - orig_block_type_ind.shape[0], max_n_mcfp_atoms - ) + argsort_ind_for_conformer = torch.argsort(sort_index_for_conformer) - orig_mcfp_at_inds_for_rot_kto = orig_mcfp_at_inds_kto[real_res_ind_for_rot, :].view( - -1 + # testing: remove this, probably + sort_ind_for_conformer_sorted = sort_index_for_conformer[argsort_ind_for_conformer] + uniq_sort_ind_for_conformer = torch.unique(sort_ind_for_conformer_sorted) + assert ( + uniq_sort_ind_for_conformer.shape[0] == sort_ind_for_conformer_sorted.shape[0] ) - # pare down the subset to those where the mc atom is present for - # both the original block type and the alternate block type; - # take the subset and also increment the indices of all the atoms - 
# by one to take into account the virtual root atom at the origin - - both_present = torch.logical_and( - rot_mcfp_at_inds_kto != -1, orig_mcfp_at_inds_for_rot_kto != -1 - ) - - rot_mcfp_at_inds_kto = rot_mcfp_at_inds_kto[both_present] + 1 - orig_mcfp_at_inds_for_rot_kto = orig_mcfp_at_inds_for_rot_kto[both_present] + 1 - - return rot_mcfp_at_inds_kto, orig_mcfp_at_inds_for_rot_kto - - -@validate_args -def copy_dofs_from_orig_to_rotamers( - poses: PoseStack, - task: PackerTask, - samplers, # : Tuple[ChiSampler, ...], - rt_for_rot: Tensor[torch.int64][:], - block_type_ind_for_rot: Tensor[torch.int64][:], - sampler_for_rotamer: Tensor[torch.int64][:], - n_dof_atoms_offset_for_rot: Tensor[torch.int32][:], - orig_dofs_kto: Tensor[torch.float32][:, 9], - rot_dofs_kto: Tensor[torch.float32][:, 9], -): - dst, src = create_dof_inds_to_copy_from_orig_to_rotamers( - poses, - task, - samplers, - rt_for_rot, - block_type_ind_for_rot, - sampler_for_rotamer, - n_dof_atoms_offset_for_rot, - ) - - rot_dofs_kto[dst, :] = orig_dofs_kto[src, :] - + sampler_for_conformer = sampler_for_conformer_unsorted[argsort_ind_for_conformer] + # print("sampler_for_conformer") + # print(sampler_for_conformer) -@validate_args -def assign_dofs_from_samples( - pbt: PackedBlockTypes, - rt_for_rot: Tensor[torch.int64][:], - block_type_ind_for_rot: Tensor[torch.int64][:], - chi_atoms: Tensor[torch.int32][:, :], - chi: Tensor[torch.float32][:, :], - rot_dofs_kto: Tensor[torch.float32][:, 9], -): - assert chi_atoms.shape == chi.shape - assert rt_for_rot.shape[0] == block_type_ind_for_rot.shape[0] - assert rt_for_rot.shape[0] == chi_atoms.shape[0] - - n_atoms = pbt.n_atoms[block_type_ind_for_rot] - n_rots = rt_for_rot.shape[0] - - atom_offset_for_rot = exclusive_cumsum1d(n_atoms) + # list of boolean tensors for each of the samplers: did you build the given rotamer + conformer_built_by_sampler = [ + sampler_for_conformer == i for i in range(n_conformer_samplers) + ] + # list of index tensors reporting the 
final index of the conformers built by the samplers + new_ind_for_sampler_rotamer = [ + torch.nonzero(built_by_sampler, as_tuple=True)[0] + for built_by_sampler in conformer_built_by_sampler + ] - max_n_chi_atoms = chi_atoms.shape[1] - real_atoms = chi_atoms.view(-1) != -1 + # for i, new_inds in enumerate(new_ind_for_sampler_rotamer): + # print("i", i, "new_inds") + # print(new_inds) - rot_ind_for_real_atom = torch.floor_divide( - torch.arange(max_n_chi_atoms * n_rots, dtype=torch.int64, device=pbt.device), - max_n_chi_atoms, - )[real_atoms] + all_gbt_for_conformer_sorted = all_gbt_for_conformer_unsorted[ + argsort_ind_for_conformer + ] + # print("all_gbt_for_conformer_sorted") + # print(all_gbt_for_conformer_sorted) - block_type_ind_for_rot_atom = ( - block_type_ind_for_rot[rot_ind_for_real_atom].cpu().numpy() + # ok, now we need to figure out how many rotamers each gbt is getting. + n_rots_for_gbt = toolz.reduce( + torch.add, [samples[0] for samples in conformer_samples] ) - rot_chi_atoms_kto = torch.tensor( - pbt.rotamer_kinforest.kinforest_idx[ - block_type_ind_for_rot_atom, chi_atoms.view(-1)[real_atoms].cpu().numpy() - ], - dtype=torch.int64, - device=pbt.device, + return ( + n_rots_for_gbt, + sampler_for_conformer, + all_gbt_for_conformer_sorted, + conformer_built_by_sampler, + new_ind_for_sampler_rotamer, ) - # increment with the atom offsets for the source rotamer and by - # one to include the virtual root - - rot_chi_atoms_kto += atom_offset_for_rot[rot_ind_for_real_atom].to(torch.int64) + 1 - - # overwrite the "downstream torsion" for the atoms that control - # each chi - rot_dofs_kto[rot_chi_atoms_kto, 3] = chi.view(-1)[real_atoms] def calculate_rotamer_coords( pbt: PackedBlockTypes, n_rots: int, + n_atoms_total: int, rot_kinforest: KinForest, nodes: NDArray[numpy.int32][:], scans: NDArray[numpy.int32][:], @@ -856,51 +691,60 @@ def _tcpu(t): ).to(pbt.device) ) + # temp + # n_atoms = 12765 + # print("rot_dofs_kto[:50]", rot_dofs_kto[:50]) + # 
print("rot_dofs_kto[(n_atoms-50):(n_atoms+50)]", rot_dofs_kto[(n_atoms-50):(n_atoms+50)]) + new_coords_kto = forward_only_op( rot_dofs_kto, _p(_t(nodes)), _p(_t(scans)), _p(_tcpu(gens)), kinforest_stack ) new_coords_rto = torch.zeros( - (n_rots * pbt.max_n_atoms, 3), dtype=torch.float32, device=pbt.device + (n_atoms_total, 3), dtype=torch.float32, device=pbt.device ) + # torch.set_printoptions(threshold=100000) + # print("id") + # print(rot_kinforest.id) - new_coords_rto[rot_kinforest.id.to(torch.int64)] = new_coords_kto - new_coords_rto = new_coords_rto.view(n_rots, pbt.max_n_atoms, 3) + new_coords_rto[rot_kinforest.id[1:].to(torch.int64)] = new_coords_kto[1:] + # new_coords_rto = new_coords_rto.view(n_rots, pbt.max_n_atoms, 3) + # print("new_coords_rto.shape", new_coords_rto.shape) return new_coords_rto -def get_rotamer_origin_data(task: PackerTask, rt_for_rot: Tensor[torch.int32][:]): - n_poses = len(task.rlts) - pose_for_rt = torch.tensor( +def get_rotamer_origin_data(task: PackerTask, gbt_for_rot: Tensor[torch.int32][:]): + n_poses = len(task.blts) + pose_for_gbt = torch.tensor( [ i - for i, one_pose_rlts in enumerate(task.rlts) - for rlts in one_pose_rlts - for rlt in rlts.allowed_restypes + for i, one_pose_blts in enumerate(task.blts) + for blts in one_pose_blts + for blt in blts.considered_block_types ], dtype=torch.int32, - device=rt_for_rot.device, + device=gbt_for_rot.device, ) block_ind_for_rt = torch.tensor( [ j - for one_pose_rlts in task.rlts - for j, rlts in enumerate(one_pose_rlts) - for rlt in rlts.allowed_restypes + for one_pose_blts in task.blts + for j, blts in enumerate(one_pose_blts) + for blt in blts.considered_block_types ], dtype=torch.int32, - device=rt_for_rot.device, + device=gbt_for_rot.device, ) - max_n_blocks = max(len(one_pose_rlts) for one_pose_rlts in task.rlts) + max_n_blocks = max(len(one_pose_blts) for one_pose_blts in task.blts) - rt_for_rot64 = rt_for_rot.to(torch.int64) - pose_for_rot = 
pose_for_rt[rt_for_rot64].to(torch.int64) - n_rots_for_pose = torch.bincount(pose_for_rot, minlength=len(task.rlts)) + gbt_for_rot64 = gbt_for_rot.to(torch.int64) + pose_for_rot = pose_for_gbt[gbt_for_rot64].to(torch.int64) + n_rots_for_pose = torch.bincount(pose_for_rot, minlength=len(task.blts)) rot_offset_for_pose = exclusive_cumsum1d(n_rots_for_pose) - block_ind_for_rot = block_ind_for_rt[rt_for_rot64] - block_ind_for_rt_global = max_n_blocks * pose_for_rt + block_ind_for_rt - block_ind_for_rot_global = block_ind_for_rt_global[rt_for_rot64] + block_ind_for_rot = block_ind_for_rt[gbt_for_rot64] + block_ind_for_rt_global = max_n_blocks * pose_for_gbt + block_ind_for_rt + block_ind_for_rot_global = block_ind_for_rt_global[gbt_for_rot64] n_rots_for_block = torch.bincount( block_ind_for_rot_global, minlength=n_poses * max_n_blocks ).reshape(n_poses, max_n_blocks) @@ -918,98 +762,357 @@ def get_rotamer_origin_data(task: PackerTask, rt_for_rot: Tensor[torch.int32][:] ) +# def build_rotamers(poses: PoseStack, task: PackerTask, chem_db: ChemicalDatabase): +# # step 0: replace the existing PBT in the Pose w/ a new one in case +# # there will possibly be new block types in the repacked Pose; +# # but since PoseStack should not be altered after construction, +# # what this really means is build an entirely new PoseStack +# # step 1: let the dunbrack library annotate the block types +# # step 2: let the dunbrack library annotate the packed block types +# # step 3: flatten poses +# # step 4: use the chi sampler to get the chi samples for all poses +# # step 5: count the number of rotamers per pose +# # step 5a: including rotamers that the dunbrack sampler does not provide (e.g. 
gly) +# # step 6: allocate a n_poses x max_n_rotamers x max_n_atoms x 3 tensor +# # step 7: create (n_poses * max_n_rotamers * max_n_atoms) x 3 view of coord tensor +# # step 8: create parent indexing based on start-position offset + residue-type tree data +# # step 9: build kinforest +# # step 10: take starting coordinates from residue roots +# # step 10a: take internal dofs from mainchain atoms +# # step 10b: take internal dofs for other atoms from rt icoors +# # step 11: refold +# +# poses, samplers = rebuild_poses_if_necessary(poses, task) +# pbt = poses.packed_block_types +# annotate_everything(chem_db, samplers, pbt) +# +# rt_names = [ +# rt.name +# for one_pose_blts in task.blts +# for blt in one_pose_blts +# for rt in blt.allowed_blocktypes +# ] +# # rt_block_type_ind: a mapping from the list of all block types at all +# # residues across all poses to the PBT-block-type index. We will use the +# # rt_block_type_ind as a way to refer to a particular block in a particular pose +# # as well as a particular block-type for that block. 
+# rt_block_type_ind = pbt.restype_index.get_indexer(rt_names).astype(numpy.int32) +# +# chi_samples = [sampler.sample_chi_for_poses(poses, task) for sampler in samplers] +# merged_samples = merge_chi_samples(chi_samples) +# n_rots_for_rt, sampler_for_rotamer, rt_for_rotamer, chi_atoms, chi = merged_samples +# +# # fd NOTE: THIS CODE FAILS IF n_rots_for_rt CONTAINS 0s +# assert 0 not in n_rots_for_rt +# +# n_rots = chi_atoms.shape[0] +# rt_for_rot = torch.zeros(n_rots, dtype=torch.int64, device=poses.device) +# n_rots_for_rt_cumsum = torch.cumsum(n_rots_for_rt, dim=0) +# rt_for_rot[n_rots_for_rt_cumsum[:-1]] = 1 +# rt_for_rot = torch.cumsum(rt_for_rot, dim=0).cpu().numpy() +# +# block_type_ind_for_rot = rt_block_type_ind[rt_for_rot] +# block_type_ind_for_rot_torch = torch.tensor( +# block_type_ind_for_rot, dtype=torch.int64, device=pbt.device +# ) +# n_atoms_for_rot = pbt.n_atoms[block_type_ind_for_rot_torch] +# n_atoms_offset_for_rot = torch.cumsum(n_atoms_for_rot, dim=0) +# n_atoms_offset_for_rot = n_atoms_offset_for_rot.cpu().numpy() +# n_atoms_total = n_atoms_offset_for_rot[-1] +# n_atoms_offset_for_rot = exc_cumsum_from_inc_cumsum(n_atoms_offset_for_rot) +# +# rot_kinforest = construct_kinforest_for_rotamers( +# pbt, +# block_type_ind_for_rot, +# int(n_atoms_total), +# torch.tensor(n_atoms_for_rot, dtype=torch.int32), +# numpy.arange(n_rots, dtype=numpy.int32) * pbt.max_n_atoms, +# pbt.device, +# ) +# +# nodes, scans, gens = construct_scans_for_rotamers( +# pbt, block_type_ind_for_rot, n_atoms_for_rot, n_atoms_offset_for_rot +# ) +# +# orig_kinforest, orig_dofs_kto = measure_pose_dofs(poses) +# +# n_rotamer_atoms = torch.sum(n_atoms_for_rot).item() +# +# rot_dofs_kto = torch.zeros( +# (n_rotamer_atoms + 1, 9), dtype=torch.float32, device=pbt.device +# ) +# +# rot_dofs_kto[1:] = torch.tensor( +# pbt.rotamer_kinforest.dofs_ideal[block_type_ind_for_rot].reshape((-1, 9))[ +# pbt.atom_is_real.cpu().numpy()[block_type_ind_for_rot].reshape(-1) != 0 +# ], +# 
dtype=torch.float32, +# device=pbt.device, +# ) +# +# rt_for_rot_torch = torch.tensor(rt_for_rot, dtype=torch.int64, device=pbt.device) +# +# copy_dofs_from_orig_to_rotamers( +# poses, +# task, +# samplers, +# rt_for_rot_torch, +# block_type_ind_for_rot_torch, +# sampler_for_rotamer, +# torch.tensor(n_atoms_offset_for_rot, dtype=torch.int32, device=pbt.device), +# orig_dofs_kto, +# rot_dofs_kto, +# ) +# +# assign_dofs_from_samples( +# pbt, +# rt_for_rot_torch, +# block_type_ind_for_rot_torch, +# chi_atoms, +# chi, +# rot_dofs_kto, +# ) +# +# rotamer_coords = calculate_rotamer_coords( +# pbt, n_rots, rot_kinforest, nodes, scans, gens, rot_dofs_kto +# ) +# +# ( +# n_rots_for_pose, +# rot_offset_for_pose, +# n_rots_for_block, +# rot_offset_for_block, +# pose_for_rot, +# block_ind_for_rot, +# ) = get_rotamer_origin_data(task, rt_for_rot_torch) +# +# return ( +# poses, +# RotamerSet( +# n_rots_for_pose=n_rots_for_pose, +# rot_offset_for_pose=rot_offset_for_pose, +# n_rots_for_block=n_rots_for_block, +# rot_offset_for_block=rot_offset_for_block, +# pose_for_rot=pose_for_rot, +# block_type_ind_for_rot=block_type_ind_for_rot_torch, +# block_ind_for_rot=block_ind_for_rot, +# coords=rotamer_coords, +# ), +# ) + + def build_rotamers(poses: PoseStack, task: PackerTask, chem_db: ChemicalDatabase): + # step 1: replace the existing PBT in the Pose w/ a new one in case + # there will possibly be new block types in the repacked Pose; + # but since PoseStack should not be altered after construction, + # what this really means is build an entirely new PoseStack + # step 2: let the dunbrack library annotate the block types and the packed block types + # step 3: get the block-type index for the "global block types" (gbts) + # step 4: use the conformer samplers to decide how many conformers they will build for + # each bt/block/pose + # step 5: merge the conformer samples from different samplers, so that different + # conformers for the same bt/block/pose will be in contiguous ranges in 
the + # rotamer set, and keeping track of the mapping back to the original set of + # conformer samples + # step 6: allocate a n_atoms_total x 3 tensor for rotamer coordinates and + # create the tensor of offsets + # step 7: build a kintree for all of the rotamers; initialize the DOFs to ideal + # step 8: build a kintree for the PoseStack residues + # step 9: measure the DOFs of the PoseStack residues + # step 9a: take starting coordinates from residue roots + # step 9b: take internal dofs from mainchain-fingerprint atoms + # step 10: ask the samplers to set the DOFs for everything else + # step 11: refold + + # Step 1 poses, samplers = rebuild_poses_if_necessary(poses, task) pbt = poses.packed_block_types + + # Step 2 annotate_everything(chem_db, samplers, pbt) - rt_names = [ - rt.name - for one_pose_rlts in task.rlts - for rlt in one_pose_rlts - for rt in rlt.allowed_restypes + # Step 3 + # create a list of the name of every considered block type at every block in every + # pose so that we can then create an integer version of that same data; + # the "global block type" (gbt) if you will. The order in which these block- + # types appear will be used as an index for talking about which rotamers are + # built where. This cannot be efficient. Perhaps worth thinking hard about the + # PackerTask's structure. 
+ gbt_names = [ + bt.name + for one_pose_blts in task.blts + for blt in one_pose_blts + for bt in blt.considered_block_types ] - rt_block_type_ind = pbt.restype_index.get_indexer(rt_names).astype(numpy.int32) + gbt_block_type_ind = pbt.restype_index.get_indexer(gbt_names).astype(numpy.int32) - chi_samples = [sampler.sample_chi_for_poses(poses, task) for sampler in samplers] - merged_samples = merge_chi_samples(chi_samples) - n_rots_for_rt, sampler_for_rotamer, rt_for_rotamer, chi_atoms, chi = merged_samples - - # fd NOTE: THIS CODE FAILS IF n_rots_for_rt CONTAINS 0s - assert 0 not in n_rots_for_rt - - n_rots = chi_atoms.shape[0] - rt_for_rot = torch.zeros(n_rots, dtype=torch.int64, device=poses.device) - n_rots_for_rt_cumsum = torch.cumsum(n_rots_for_rt, dim=0) - rt_for_rot[n_rots_for_rt_cumsum[:-1]] = 1 - rt_for_rot = torch.cumsum(rt_for_rot, dim=0).cpu().numpy() + # Step 4 + conformer_samples = [ + sampler.create_samples_for_poses(poses, task) for sampler in samplers + ] - block_type_ind_for_rot = rt_block_type_ind[rt_for_rot] - block_type_ind_for_rot_torch = torch.tensor( - block_type_ind_for_rot, dtype=torch.int64, device=pbt.device + # Step 5 + ( + n_rots_for_gbt, + sampler_for_conformer, + gbt_for_conformer, + conformer_built_by_sampler, + new_ind_for_sampler_rotamer, + ) = merge_conformer_samples(conformer_samples) + + # torch.set_printoptions(threshold=10000) + # print("n_rots_for_gbt") + # print(n_rots_for_gbt) + + def _t(t, dtype): + return torch.tensor(t, dtype=dtype, device=pbt.device) + + gbt_for_conformer_np = gbt_for_conformer.cpu().numpy() + + gbt_for_conformer_torch = _t(gbt_for_conformer, torch.int64) + + # apl: I hope to have fixed that + # fd NOTE: THIS CODE FAILS IF n_rots_for_gbt CONTAINS 0s + # assert 0 not in n_rots_for_gbt + + n_conformers = sampler_for_conformer.shape[0] + # gbt_for_rot = torch.zeros(n_conformers, dtype=torch.int64, device=poses.device) + # gbt_for_rot[n_rots_for_gbt_cumsum[:-1]] = 1 + # gbt_for_rot = 
torch.cumsum(gbt_for_rot, dim=0).cpu().numpy() + + block_type_ind_for_conformer = gbt_block_type_ind[gbt_for_conformer_np] + block_type_ind_for_conformer_torch = _t(block_type_ind_for_conformer, torch.int64) + + n_atoms_for_conformer = pbt.n_atoms[block_type_ind_for_conformer_torch] + n_atoms_offset_for_conformer = torch.cumsum(n_atoms_for_conformer, dim=0) + n_atoms_offset_for_conformer = n_atoms_offset_for_conformer.cpu().numpy() + n_atoms_total = n_atoms_offset_for_conformer[-1].item() + n_atoms_offset_for_conformer = exc_cumsum_from_inc_cumsum( + n_atoms_offset_for_conformer ) - n_atoms_for_rot = pbt.n_atoms[block_type_ind_for_rot_torch] - n_atoms_offset_for_rot = torch.cumsum(n_atoms_for_rot, dim=0) - n_atoms_offset_for_rot = n_atoms_offset_for_rot.cpu().numpy() - n_atoms_total = n_atoms_offset_for_rot[-1] - n_atoms_offset_for_rot = exc_cumsum_from_inc_cumsum(n_atoms_offset_for_rot) + n_atoms_offset_for_conformer_torch = _t(n_atoms_offset_for_conformer, torch.int64) - rot_kinforest = construct_kinforest_for_rotamers( + # Step 7 + conformer_kinforest = construct_kinforest_for_conformers( pbt, - block_type_ind_for_rot, - int(n_atoms_total), - torch.tensor(n_atoms_for_rot, dtype=torch.int32), - numpy.arange(n_rots, dtype=numpy.int32) * pbt.max_n_atoms, + block_type_ind_for_conformer, + n_atoms_total, + torch.tensor(n_atoms_for_conformer, dtype=torch.int32), + n_atoms_offset_for_conformer, pbt.device, ) - nodes, scans, gens = construct_scans_for_rotamers( - pbt, block_type_ind_for_rot, n_atoms_for_rot, n_atoms_offset_for_rot + nodes, scans, gens = construct_scans_for_conformers( + pbt, + block_type_ind_for_conformer, + n_atoms_for_conformer, + n_atoms_offset_for_conformer, ) - orig_dofs_kto = measure_pose_dofs(poses) + # Step 8 & 9 + orig_kinforest, orig_dofs_kto = measure_pose_dofs(poses) - n_rotamer_atoms = torch.sum(n_atoms_for_rot).item() - - rot_dofs_kto = torch.zeros( - (n_rotamer_atoms + 1, 9), dtype=torch.float32, device=pbt.device + # Step 9a + 
conf_dofs_kto = torch.zeros( + (n_atoms_total + 1, 9), dtype=torch.float32, device=pbt.device ) - - rot_dofs_kto[1:] = torch.tensor( - pbt.rotamer_kinforest.dofs_ideal[block_type_ind_for_rot].reshape((-1, 9))[ - pbt.atom_is_real.cpu().numpy()[block_type_ind_for_rot].reshape(-1) != 0 + conf_dofs_kto[1:] = torch.tensor( + pbt.rotamer_kinforest.dofs_ideal[block_type_ind_for_conformer].reshape((-1, 9))[ + pbt.atom_is_real.cpu().numpy()[block_type_ind_for_conformer].reshape(-1) + != 0 ], dtype=torch.float32, device=pbt.device, ) - rt_for_rot_torch = torch.tensor(rt_for_rot, dtype=torch.int64, device=pbt.device) - - copy_dofs_from_orig_to_rotamers( - poses, - task, - samplers, - rt_for_rot_torch, - block_type_ind_for_rot_torch, - sampler_for_rotamer, - torch.tensor(n_atoms_offset_for_rot, dtype=torch.int32, device=pbt.device), - orig_dofs_kto, - rot_dofs_kto, - ) + for i, sampler in enumerate(samplers): + sampler.fill_dofs_for_samples( + poses, + task, + orig_kinforest, + orig_dofs_kto, + gbt_for_conformer_torch, + block_type_ind_for_conformer_torch, + n_atoms_offset_for_conformer_torch, + conformer_built_by_sampler[i], + new_ind_for_sampler_rotamer[i], + conformer_samples[i][0], + conformer_samples[i][1], + conformer_samples[i][2], + conf_dofs_kto, + ) - assign_dofs_from_samples( - pbt, - rt_for_rot_torch, - block_type_ind_for_rot_torch, - chi_atoms, - chi, - rot_dofs_kto, - ) + # copy_dofs_from_orig_to_rotamers( + # poses, + # task, + # samplers, + # rt_for_rot_torch, + # block_type_ind_for_rot_torch, + # sampler_for_rotamer, + # torch.tensor(n_atoms_offset_for_rot, dtype=torch.int32, device=pbt.device), + # orig_dofs_kto, + # rot_dofs_kto, + # ) + # print("conf_dofs_kto") + # print( + # conf_dofs_kto[ + # torch.tensor( + # # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + # # 10, 11, 12, 13, 14, 15, 16, 17, 18], + # [ + # 14952, + # 14953, + # 14954, + # 14955, + # 14956, + # 14957, + # 14958, + # 14959, + # 14960, + # 14961, + # 14962, + # 14963, + # 14964, + # 14965, + # 14956, + 
# 14957, + # ], + # dtype=torch.int64, + # device=pbt.device, + # ), + # :, + # ] + # ) + # print("id:") + # print(torch.nonzero(is_pro_rot)) + # print("conformer_kinforest.parent") + # print(conformer_kinforest.parent[14953:14967]) + # print("conformer_kinforest.frame_x") + # print(conformer_kinforest.frame_x[14953:14967]) + # print("conformer_kinforest.frame_y") + # print(conformer_kinforest.frame_y[14953:14967]) + # print("conformer_kinforest.frame_z") + # print(conformer_kinforest.frame_z[14953:14967]) + + # assign_dofs_from_samples( + # pbt, + # rt_for_rot_torch, + # block_type_ind_for_rot_torch, + # chi_atoms, + # chi, + # rot_dofs_kto, + # ) rotamer_coords = calculate_rotamer_coords( - pbt, n_rots, rot_kinforest, nodes, scans, gens, rot_dofs_kto + pbt, + n_conformers, + n_atoms_total, + conformer_kinforest, + nodes, + scans, + gens, + conf_dofs_kto, ) - ( n_rots_for_pose, rot_offset_for_pose, @@ -1017,18 +1120,20 @@ def build_rotamers(poses: PoseStack, task: PackerTask, chem_db: ChemicalDatabase rot_offset_for_block, pose_for_rot, block_ind_for_rot, - ) = get_rotamer_origin_data(task, rt_for_rot_torch) + ) = get_rotamer_origin_data(task, gbt_for_conformer_torch) return ( poses, RotamerSet( + max_n_rots_per_pose=torch.max(n_rots_for_pose).item(), n_rots_for_pose=n_rots_for_pose, rot_offset_for_pose=rot_offset_for_pose, n_rots_for_block=n_rots_for_block, rot_offset_for_block=rot_offset_for_block, pose_for_rot=pose_for_rot, - block_type_ind_for_rot=block_type_ind_for_rot_torch, + block_type_ind_for_rot=block_type_ind_for_conformer_torch, block_ind_for_rot=block_ind_for_rot, + coord_offset_for_rot=n_atoms_offset_for_conformer_torch.to(torch.int32), coords=rotamer_coords, ), ) diff --git a/tmol/pack/rotamer/chi_sampler.py b/tmol/pack/rotamer/chi_sampler.py index a8d2f018c..9dc7edc53 100644 --- a/tmol/pack/rotamer/chi_sampler.py +++ b/tmol/pack/rotamer/chi_sampler.py @@ -6,13 +6,16 @@ from tmol.types.torch import Tensor from tmol.types.functional import 
validate_args +from tmol.utility.tensor.common_operations import exclusive_cumsum1d, stretch from tmol.chemical.restypes import RefinedResidueType from tmol.pose.packed_block_types import PackedBlockTypes from tmol.pose.pose_stack import PoseStack +from tmol.kinematics.datatypes import KinForest +from tmol.pack.rotamer.conformer_sampler import ConformerSampler @attr.s(auto_attribs=True) -class ChiSampler: +class ChiSampler(ConformerSampler): @classmethod def sampler_name(cls): raise NotImplementedError() @@ -33,6 +36,31 @@ def defines_rotamers_for_rt(self, rt: RefinedResidueType): def first_sc_atoms_for_rt(self, rt_name: str) -> Tuple[str, ...]: raise NotImplementedError() + def create_samples_for_poses( + self, + pose_stack: PoseStack, + task: "PackerTask", # noqa: 821 + ) -> Tuple[ # noqa F821 + Tensor[torch.int32][:], # n_rots_for_gbt + Tensor[torch.int32][:], # bt_for_rotamer + dict, # anything else the sampler wants to save for later + ]: + ( + n_rots_for_gbt, + gbt_for_rotamer, + chi_defining_atom_for_rotamer, + chi_for_rotamers, + ) = self.sample_chi_for_poses(pose_stack, task) + # print("Sampling:", self.sampler_name(), chi_for_rotamers.shape) + return ( + n_rots_for_gbt, + gbt_for_rotamer, + dict( + chi_defining_atom_for_rotamer=chi_defining_atom_for_rotamer, + chi_for_rotamers=chi_for_rotamers, + ), + ) + def sample_chi_for_poses( self, systems: PoseStack, task: "PackerTask" # noqa F821 ) -> Tuple[ @@ -42,3 +70,360 @@ def sample_chi_for_poses( Tensor[torch.float32][:, :], # chi_for_rotamers ]: raise NotImplementedError() + + def fill_dofs_for_samples( + self, + pose_stack: PoseStack, + task: "PackerTask", # noqa: 821 + orig_kinforest: KinForest, + orig_dofs_kto: Tensor[torch.float32][:, 9], + gbt_for_conformer: Tensor[torch.int64][:], + block_type_ind_for_conformer: Tensor[torch.int64][:], + n_dof_atoms_offset_for_conformer: Tensor[torch.int64][:], + # which of all conformers are built by this sampler + conformer_built_by_sampler: Tensor[torch.bool][:], + 
# mapping orig conformer samples to merged conformer samples for this sampler + conf_inds_for_sampler: Tensor[torch.int64][:], + sampler_n_rots_for_gbt: Tensor[torch.int32][:], + sampler_gbt_for_rotamer: Tensor[torch.int32][:], + sample_dict: dict, + conf_dofs_kto: Tensor[torch.float32][:, 9], + ): + copy_dofs_from_orig_to_rotamers_for_sampler( + pose_stack, + task, + self.sampler_name(), + gbt_for_conformer, + block_type_ind_for_conformer, + conf_inds_for_sampler, + sampler_n_rots_for_gbt, + sampler_gbt_for_rotamer, + n_dof_atoms_offset_for_conformer, + orig_dofs_kto, + conf_dofs_kto, + ) + + chi_atoms = sample_dict["chi_defining_atom_for_rotamer"] + chi = sample_dict["chi_for_rotamers"] + if chi.shape[0] == 0: + return + + assign_chi_dofs_from_samples( + pose_stack.packed_block_types, + block_type_ind_for_conformer, + conf_inds_for_sampler, + sampler_n_rots_for_gbt, + sampler_gbt_for_rotamer, + n_dof_atoms_offset_for_conformer, + chi_atoms, + chi, + conf_dofs_kto, + ) + + +@validate_args +def copy_dofs_from_orig_to_rotamers_for_sampler( + poses: PoseStack, + task, + sampler_name: str, + gbt_for_rot: Tensor[torch.int64][:], + block_type_ind_for_rot: Tensor[torch.int64][:], + conf_inds_for_sampler: Tensor[torch.int64][:], + sampler_n_rots_for_gbt: Tensor[torch.int32][:], + sampler_gbt_for_rotamer: Tensor[torch.int32][:], + n_dof_atoms_offset_for_rot: Tensor[torch.int64][:], + orig_dofs_kto: Tensor[torch.float32][:, 9], + rot_dofs_kto: Tensor[torch.float32][:, 9], +): + dst, src = create_dof_inds_to_copy_from_orig_to_rotamers_for_sampler( + poses, + task, + sampler_name, + gbt_for_rot, + block_type_ind_for_rot, + conf_inds_for_sampler, + sampler_n_rots_for_gbt, + sampler_gbt_for_rotamer, + n_dof_atoms_offset_for_rot, + ) + + rot_dofs_kto[dst, :] = orig_dofs_kto[src, :] + + +def create_dof_inds_to_copy_from_orig_to_rotamers_for_sampler( + poses: PoseStack, + task: "PackerTask", # noqa F821 + sampler_name: str, + gbt_for_rot: Tensor[torch.int64][:], # max-n-rots + 
block_type_ind_for_rot: Tensor[torch.int64][:], + conf_inds_for_sampler: Tensor[torch.int64][:], + sampler_n_rots_for_gbt: Tensor[torch.int32][:], + sampler_gbt_for_rotamer: Tensor[torch.int32][:], + n_dof_atoms_offset_for_rot: Tensor[torch.int64][:], +) -> Tuple[Tensor[torch.int64][:], Tensor[torch.int64][:]]: + # we want to copy from the orig_dofs tensor into the + # rot_dofs tensor for the "mainchain" atoms in the + # original residues into the appropriate positions + # for the rotamers thta we are building at those + # residues. This requires a good deal of reindexing. + + pbt = poses.packed_block_types + n_rots_for_sampler = sampler_gbt_for_rotamer.shape[0] + + # This could 100% be pre-computed + pbts_sampler_ind = pbt.mc_fingerprints.sampler_mapping[sampler_name] + + orig_block_type_ind = ( + poses.block_type_ind[poses.block_type_ind != -1].view(-1).to(torch.int64) + ) + + # consider making this an argument and passing in + # print("poses.block_type_ind.shape", poses.block_type_ind.shape) + poses_res_to_real_poses_res = torch.full( + (poses.block_type_ind.shape[0] * poses.block_type_ind.shape[1],), + -1, + dtype=torch.int64, + device=poses.device, + ) + # print("poses_res_to_real_poses_res") + # print(poses_res_to_real_poses_res.shape) + # print(poses_res_to_real_poses_res[-10:]) + poses_res_to_real_poses_res[poses.block_type_ind.view(-1) != -1] = torch.arange( + orig_block_type_ind.shape[0], dtype=torch.int64, device=poses.device + ) + + # get the residue index for each rotamer + max_n_blocks = poses.block_coord_offset.shape[1] + res_ind_for_gbt = torch.tensor( + [ + i * max_n_blocks + j + for i, one_pose_blts in enumerate(task.blts) + for j, blt in enumerate(one_pose_blts) + for _ in blt.considered_block_types + ], + dtype=torch.int64, + device=poses.device, + ) + # print("res_ind_for_gbt") + # print(res_ind_for_gbt) + gbt_for_samplers_rots = gbt_for_rot[conf_inds_for_sampler] + # torch.set_printoptions(threshold=10000) + # print("gbt_for_samplers_rots") + 
# print(gbt_for_samplers_rots) + res_ind_for_samplers_rots = res_ind_for_gbt[gbt_for_samplers_rots] + # print("res_ind_for_samplers_rots") + # print(res_ind_for_samplers_rots) + real_res_ind_for_samplers_rots = poses_res_to_real_poses_res[ + res_ind_for_samplers_rots + ] + # print("real_res_ind_for_samplers_rots") + # print(real_res_ind_for_samplers_rots) + block_type_ind_for_samplers_rots = block_type_ind_for_rot[conf_inds_for_sampler] + + # look up which mainchain fingerprint each + # original residue should use + + mcfp = pbt.mc_fingerprints + + sampler_ind_for_orig = mcfp.max_sampler[orig_block_type_ind] + orig_res_mcfp = mcfp.max_fingerprint[orig_block_type_ind] + orig_res_mcfp_for_samplers_rots = orig_res_mcfp[real_res_ind_for_samplers_rots] + + # now lets find the kinforest-ordered indices of the + # mainchain atoms for the rotamers that represents + # the destination for the dofs we're copying + max_n_mcfp_atoms = mcfp.atom_mapping.shape[3] + + samplers_rots_mcfp_at_inds_rto = mcfp.atom_mapping[ + pbts_sampler_ind, + orig_res_mcfp_for_samplers_rots, + block_type_ind_for_samplers_rots, + :, + ].view(-1) + + is_samplers_rots_mcfp_at_inds_rto_real = samplers_rots_mcfp_at_inds_rto != -1 + real_samplers_rots_mcfp_at_inds_rto = samplers_rots_mcfp_at_inds_rto[ + is_samplers_rots_mcfp_at_inds_rto_real + ] + + # print("block_type_ind_for_samplers_rots", block_type_ind_for_samplers_rots.shape) + real_samplers_rots_block_type_ind_for_mcfp_ats = stretch( + block_type_ind_for_samplers_rots, max_n_mcfp_atoms + )[is_samplers_rots_mcfp_at_inds_rto_real] + + samplers_rots_mcfp_at_inds_kto = torch.full_like(samplers_rots_mcfp_at_inds_rto, -1) + samplers_rots_mcfp_at_inds_kto[is_samplers_rots_mcfp_at_inds_rto_real] = ( + torch.tensor( + pbt.rotamer_kinforest.kinforest_idx[ + real_samplers_rots_block_type_ind_for_mcfp_ats.cpu().numpy(), + real_samplers_rots_mcfp_at_inds_rto.cpu().numpy(), + ], + dtype=torch.int64, + device=pbt.device, + ) + ) + # print( + # 
"real_samplers_rots_block_type_ind_for_mcfp_ats", + # real_samplers_rots_block_type_ind_for_mcfp_ats.shape, + # ) + + is_samplers_rots_mcfp_at_inds_kto_real = samplers_rots_mcfp_at_inds_kto != -1 + # print( + # "is_samplers_rots_mcfp_at_inds_kto_real", + # is_samplers_rots_mcfp_at_inds_kto_real.shape, + # ) + # print( + # "n_rots_for_sampler * max_n_mcfp_atoms", n_rots_for_sampler * max_n_mcfp_atoms + # ) + n_dof_atoms_offset_for_samplers_rot = n_dof_atoms_offset_for_rot[ + conf_inds_for_sampler + ] + samplers_rots_mcfp_at_inds_kto[ + is_samplers_rots_mcfp_at_inds_kto_real + ] += n_dof_atoms_offset_for_samplers_rot[ + torch.div( # to do: replace with expand + torch.arange( + n_rots_for_sampler * max_n_mcfp_atoms, + dtype=torch.int64, + device=poses.device, + ), + max_n_mcfp_atoms, + rounding_mode="trunc", + )[is_samplers_rots_mcfp_at_inds_kto_real] + ] + + # now get the indices in the orig_dofs array for the atoms to copy from. + # The steps: + # 1. get the mainchain atom indices for each of the original residues + # in residue-type order (rto) + # 2. sample 1. for each rotamer + # 3. find the real subset of these atoms + # 4. note the residue index for each of these real atoms + # 5. remap these to kinforest order (kto) + # 6. increment the indices with the original-residue dof-index offsets + + # orig_mcfp_at_inds_for_orig_rto: + # 1. these are the mainchain fingerprint atoms from the original + # residues on the pose + # 2. they are stored in residue-type order (rto) + # 3. 
they are indexed by original residue index + + orig_mcfp_at_inds_rto = mcfp.atom_mapping[ + sampler_ind_for_orig, orig_res_mcfp, orig_block_type_ind, : + ].view(-1) + + real_orig_block_type_ind_for_orig_mcfp_ats = stretch( + orig_block_type_ind, max_n_mcfp_atoms + )[orig_mcfp_at_inds_rto != -1] + + orig_dof_atom_offset = exclusive_cumsum1d(pbt.n_atoms[orig_block_type_ind]).to( + torch.int64 + ) + + orig_mcfp_at_inds_kto = torch.full_like(orig_mcfp_at_inds_rto, -1) + orig_mcfp_at_inds_kto[orig_mcfp_at_inds_rto != -1] = ( + torch.tensor( + pbt.rotamer_kinforest.kinforest_idx[ + real_orig_block_type_ind_for_orig_mcfp_ats.cpu().numpy(), + orig_mcfp_at_inds_rto[orig_mcfp_at_inds_rto != -1].cpu().numpy(), + ], + dtype=torch.int64, + device=pbt.device, + ) + + orig_dof_atom_offset[ + torch.floor_divide( # to do: replace w/ expand + torch.arange( + orig_block_type_ind.shape[0] * max_n_mcfp_atoms, + dtype=torch.int64, + device=pbt.device, + ), + max_n_mcfp_atoms, + ) + ][orig_mcfp_at_inds_rto != -1] + ) + + orig_mcfp_at_inds_kto = orig_mcfp_at_inds_kto.view( + orig_block_type_ind.shape[0], max_n_mcfp_atoms + ) + + orig_mcfp_at_inds_for_samplers_rots_kto = orig_mcfp_at_inds_kto[ + real_res_ind_for_samplers_rots, : + ].view(-1) + + # pare down the subset to those where the mc atom is present for + # both the original block type and the alternate block type; + # take the subset and also increment the indices of all the atoms + # by one to take into account the virtual root atom at the origin + + both_present = torch.logical_and( + samplers_rots_mcfp_at_inds_kto != -1, + orig_mcfp_at_inds_for_samplers_rots_kto != -1, + ) + + # add one for the virtual root + samplers_rots_mcfp_at_inds_kto = samplers_rots_mcfp_at_inds_kto[both_present] + 1 + orig_mcfp_at_inds_for_samplers_rots_kto = ( + orig_mcfp_at_inds_for_samplers_rots_kto[both_present] + 1 + ) + + # print("samplers_rots_mcfp_at_inds_kto") + # print(samplers_rots_mcfp_at_inds_kto.shape) + # 
print(samplers_rots_mcfp_at_inds_kto[:30]) + # print("orig_mcfp_at_inds_for_samplers_rots_kto") + # print(orig_mcfp_at_inds_for_samplers_rots_kto.shape) + # print(orig_mcfp_at_inds_for_samplers_rots_kto[:30]) + + return samplers_rots_mcfp_at_inds_kto, orig_mcfp_at_inds_for_samplers_rots_kto + + +@validate_args +def assign_chi_dofs_from_samples( + pbt: PackedBlockTypes, + block_type_ind_for_rot: Tensor[torch.int64][:], + conf_inds_for_sampler: Tensor[torch.int64][:], + sampler_n_rots_for_bt: Tensor[torch.int32][:], + sampler_gbt_for_rotamer: Tensor[torch.int32][:], + n_dof_atoms_offset_for_rot: Tensor[torch.int64][:], + chi_atoms: Tensor[torch.int32][:, :], + chi: Tensor[torch.float32][:, :], + rot_dofs_kto: Tensor[torch.float32][:, 9], +): + assert chi_atoms.shape == chi.shape + + n_rots_for_sampler = sampler_gbt_for_rotamer.shape[0] + + max_n_chi_atoms = chi_atoms.shape[1] + real_atoms = chi_atoms.view(-1) != -1 + + sampler_rot_ind_for_real_atom = torch.floor_divide( # to do: replace w/ expand + torch.arange( + max_n_chi_atoms * n_rots_for_sampler, dtype=torch.int64, device=pbt.device + ), + max_n_chi_atoms, + )[real_atoms] + global_rot_ind_for_real_atom = conf_inds_for_sampler[sampler_rot_ind_for_real_atom] + + block_type_ind_for_rot_atom = ( + block_type_ind_for_rot[global_rot_ind_for_real_atom].cpu().numpy() + ) + + rot_chi_atoms_kto = torch.tensor( + pbt.rotamer_kinforest.kinforest_idx[ + block_type_ind_for_rot_atom, chi_atoms.view(-1)[real_atoms].cpu().numpy() + ], + dtype=torch.int64, + device=pbt.device, + ) + + # increment with the atom offsets for the source rotamer and by + # one to include the virtual root + rot_chi_atoms_kto += ( + n_dof_atoms_offset_for_rot[global_rot_ind_for_real_atom].to(torch.int64) + 1 + ) + + # overwrite the "downstream torsion" for the atoms that control + # each chi + rot_dofs_kto[rot_chi_atoms_kto, 3] = chi.view(-1)[real_atoms] + + # print("rot_chi_atoms_kto", rot_chi_atoms_kto[:10]) + # print("chi", 
@attr.s(auto_attribs=True)
class ConformerSampler:
    """Base class for objects that propose conformers for the packer.

    Every method here is either a no-op or raises ``NotImplementedError``;
    concrete samplers override them.
    """

    @classmethod
    def sampler_name(cls):
        """Return the sampler's name; must be overridden."""
        raise NotImplementedError()

    @validate_args
    def annotate_residue_type(self, rt: RefinedResidueType):
        """Hook for attaching sampler data to ``rt``; no-op by default."""

    @validate_args
    def annotate_packed_block_types(self, packed_block_types: PackedBlockTypes):
        """Hook for attaching sampler data to ``packed_block_types``; no-op
        by default."""

    @validate_args
    def defines_rotamers_for_rt(self, rt: RefinedResidueType):
        """Report whether this sampler builds rotamers for ``rt``; must be
        overridden."""
        raise NotImplementedError()

    @validate_args
    def first_sc_atoms_for_rt(self, rt: RefinedResidueType) -> Tuple[str, ...]:
        """Return a tuple of atom names for ``rt``; must be overridden."""
        raise NotImplementedError()

    def create_samples_for_poses(
        self,
        pose_stack: PoseStack,
        task: "PackerTask",  # noqa: F821
    ) -> Tuple[  # noqa: F821
        Tensor[torch.int32][:],  # n_rots_for_bt
        Tensor[torch.int32][:],  # bt_for_rotamer
        dict,  # anything else the sampler wants to save for later
    ]:
        """Produce this sampler's samples for ``pose_stack`` under ``task``.

        Must be overridden.  The annotated return value is a 3-tuple: a
        per-block-type rotamer count, a block-type index per rotamer, and a
        dict of anything else the sampler wants to keep for later.
        """
        raise NotImplementedError()

    def fill_dofs_for_samples(
        self,
        pose_stack: PoseStack,
        task: "PackerTask",  # noqa: F821
        orig_kinforest: KinForest,
        orig_dofs_kto: Tensor[torch.float32][:, 9],
        gbt_for_conformer: Tensor[torch.int64][:],
        block_type_ind_for_conformer: Tensor[torch.int64][:],
        n_dof_atoms_offset_for_conformer: Tensor[torch.int64][:],
        # which of all conformers are built by this sampler
        conformer_built_by_sampler: Tensor[torch.bool][:],
        # mapping orig conformer samples to merged conformer samples for this sampler
        conf_inds_for_sampler: Tensor[torch.int64][:],
        sampler_n_rots_for_gbt: Tensor[torch.int32][:],
        sampler_gbt_for_rotamer: Tensor[torch.int32][:],
        sample_dict: dict,
        conf_dofs_kto: Tensor[torch.float32][:, 9],
    ):
        """Fill ``conf_dofs_kto`` for the conformers this sampler built.

        Must be overridden.  ``sample_dict`` is presumably the dict returned
        by ``create_samples_for_poses`` -- confirm against the caller.
        """
        raise NotImplementedError()
the buildable-block-types: the allowed block types at all positions that + # contain this DunbrackChiSampler in their set of conformer samplers that + # the DunbrackChiSampler will sample rotamers for (bbt) + + # the subset of blocktypes which are allowed at the positions and + # for which the block-level tasks include this DunbrackChiSampler + # the "Dunbrack allowed" restypes + dun_allowed_blocktypes = numpy.array( [ - rt - for one_pose_rlts in task.rlts - for rlt in one_pose_rlts - for rt in rlt.allowed_restypes - if self in rlt.chi_samplers + bt + for one_pose_blts in task.blts + for blt in one_pose_blts + for i, bt in enumerate(blt.considered_block_types) + if self in blt.conformer_samplers and blt.block_type_allowed[i] ], dtype=object, ) + # print("dun_allowed_blocktypes shape", dun_allowed_blocktypes.shape) + is_gbt_dun_allowed = numpy.array( + [ + self in blt.conformer_samplers and blt.block_type_allowed[i] + for one_pose_blts in task.blts + for blt in one_pose_blts + for i, bt in enumerate(blt.considered_block_types) + ], + dtype=bool, + ) + n_gbt_total = is_gbt_dun_allowed.shape[0] + # print("is_gbt_dun_allowed shape", is_gbt_dun_allowed.shape) + # equiv: numpy.nonzero(is_gbt_dun_allowed) + dun_allowed_bt_to_gbt = numpy.arange(n_gbt_total, dtype=numpy.int64)[ + is_gbt_dun_allowed + ] + dun_allowed_bt_to_gbt_torch = torch.tensor( + dun_allowed_bt_to_gbt, device=self.device + ) - # n_allowed_per_pose = torch.tensor( - # [ - # len(rlt.allowed_restypes) - # for one_pose_rlts in task.rlts - # for rlt in one_pose_rlts - # if self in rlt.chi_samplers - # ], - # dtype=torch.int32, - # device=self.device, - # ) - - rt_names = numpy.array([rt.name for rt in dun_allowed_restypes], dtype=object) - rt_base_names = numpy.array( - [rt.name.partition(":")[0] for rt in dun_allowed_restypes], dtype=object + dun_allowed_bt_names = numpy.array( + [bt.name for bt in dun_allowed_blocktypes], dtype=object ) + dun_allowed_bt_base_names = numpy.array( + [bt.name.partition(":")[0] 
for bt in dun_allowed_blocktypes], dtype=object + ) + # print("dun_allowed_bt_base_names[:20]", dun_allowed_bt_base_names[:20]) pbt = pose_stack.packed_block_types - rt_res = torch.tensor( + # the source block for each dun-allowed block type + dun_allowed_bt_block = torch.tensor( [ i * max_n_blocks + j - for i, one_pose_rlts in enumerate(task.rlts) - for j, rlt in enumerate(one_pose_rlts) - for rt in rlt.allowed_restypes + for i, one_pose_blts in enumerate(task.blts) + for j, blt in enumerate(one_pose_blts) + for k, _ in enumerate(blt.considered_block_types) + if blt.block_type_allowed[k] and self in blt.conformer_samplers ], dtype=torch.int32, device=self.device, ) - dun_rot_inds_for_rts = self.dun_param_resolver._indices_from_names( + dun_rot_inds_for_dun_allowed_bts = self.dun_param_resolver._indices_from_names( self.dun_param_resolver.all_table_indices, - rt_base_names[None, :], + dun_allowed_bt_base_names[None, :], # ??? torch.device("cpu"), self.device, ).squeeze() - block_type_ind_for_brt = torch.tensor( + # the pbt-assigned block-type indices for each buildable block type + # the subset of dun_rot_inds_for_dun_allowed_bts with a non-sentinel + # value represents the buildable block types + block_type_ind_for_bbt = torch.tensor( pbt.restype_index.get_indexer( - rt_names[dun_rot_inds_for_rts.cpu().numpy() != -1] + dun_allowed_bt_names[ + dun_rot_inds_for_dun_allowed_bts.cpu().numpy() != -1 + ] ), dtype=torch.int64, device=self.device, @@ -367,15 +395,12 @@ def sample_chi_for_poses( -1, 4 ) - # fd unused - # phi_psi_inds = torch.cat( - # (inds_of_phi.reshape(-1, 4), inds_of_psi.reshape(-1, 4)), dim=1 - # ) - # phi_psi_inds = phi_psi_inds.reshape(-1, 4) - - nonzero_dunrot_inds_for_rts = torch.nonzero(dun_rot_inds_for_rts != -1) - rottable_set_for_buildable_restype = dun_rot_inds_for_rts[ - nonzero_dunrot_inds_for_rts + is_dun_allowed_bt_bbt = dun_rot_inds_for_dun_allowed_bts != -1 + dun_allowed_bt_that_are_bbt = torch.nonzero(is_dun_allowed_bt_bbt)[:, 0] + # 
print("dun_allowed_bt_that_are_bbt", dun_allowed_bt_that_are_bbt.shape) + bbt_to_gbt_torch = dun_allowed_bt_to_gbt_torch[dun_allowed_bt_that_are_bbt] + rottable_set_for_bbt = dun_rot_inds_for_dun_allowed_bts[ + dun_allowed_bt_that_are_bbt ] # the "indices" of the blocks that the block types we will be building come @@ -383,19 +408,17 @@ def sample_chi_for_poses( # numbering. We will need to keep this array as it will be used by the # caller to understand what block types we are defining samples for. # We will shortly be renumbering the residues to talk about only the ones - # that we will build rotamers for - orig_residue_for_buildable_restype = rt_res[nonzero_dunrot_inds_for_rts] + # that we will build rotamers for: BRT = "buildable residue type" + block_for_bbt = dun_allowed_bt_block[dun_allowed_bt_that_are_bbt] - uniq_res_for_brt, uniq_inds = torch.unique( - orig_residue_for_buildable_restype, return_inverse=True - ) - uniq_res_for_brt = uniq_res_for_brt.to(torch.int64) + uniq_block_for_bbt, uniq_inds = torch.unique(block_for_bbt, return_inverse=True) + uniq_block_for_bbt = uniq_block_for_bbt.to(torch.int64) - rottable_set_for_buildable_restype = torch.tensor( + rottable_set_for_bbt = torch.tensor( torch.cat( ( uniq_inds.reshape(-1, 1), - rottable_set_for_buildable_restype.reshape(-1, 1), + rottable_set_for_bbt.reshape(-1, 1), ), dim=1, ), @@ -405,83 +428,99 @@ def sample_chi_for_poses( # phi_psi_res_inds = numpy.arange(n_sys * max_n_blocks, dtype=numpy.int32) - n_sampling_res = uniq_res_for_brt.shape[0] + n_sampling_blocks = uniq_block_for_bbt.shape[0] # map the residue-numbered list of dihedral angles to their positions in - # the set of residues that the Dunbrack library will provice chi samples for + # the set of residues that the Dunbrack library will provide chi samples for dihedral_atom_inds = torch.full( - (2 * n_sampling_res, 4), -1, dtype=torch.int32, device=self.device + (2 * n_sampling_blocks, 4), -1, dtype=torch.int32, device=self.device ) 
dihedral_atom_inds[ - 2 * torch.arange(n_sampling_res, dtype=torch.int64, device=self.device), : - ] = inds_of_phi[uniq_res_for_brt, :] + 2 * torch.arange(n_sampling_blocks, dtype=torch.int64, device=self.device), + :, + ] = inds_of_phi[uniq_block_for_bbt, :] dihedral_atom_inds[ - 2 * torch.arange(n_sampling_res, dtype=torch.int64, device=self.device) + 1, + 2 * torch.arange(n_sampling_blocks, dtype=torch.int64, device=self.device) + + 1, :, - ] = inds_of_psi[uniq_res_for_brt, :] + ] = inds_of_psi[uniq_block_for_bbt, :] - ndihe_for_res = torch.full( - (n_sampling_res,), 2, dtype=torch.int32, device=self.device + n_dihe_for_block = torch.full( + (n_sampling_blocks,), 2, dtype=torch.int32, device=self.device ) - dihedral_offset_for_res = 2 * torch.arange( - n_sampling_res, dtype=torch.int32, device=self.device + dihedral_offset_for_block = 2 * torch.arange( + n_sampling_blocks, dtype=torch.int32, device=self.device ) - n_brts = nonzero_dunrot_inds_for_rts.shape[0] + n_bbts = dun_allowed_bt_that_are_bbt.shape[0] max_n_chi = pose_stack.packed_block_types.dun_sampler_cache.max_n_chi - chi_expansion_for_buildable_restype = torch.full( - (n_brts, max_n_chi), 0, dtype=torch.int32, device=self.device + chi_expansion_for_bbt = torch.full( + (n_bbts, max_n_chi), 0, dtype=torch.int32, device=self.device ) + chi_expansion_for_gbt = torch.cat( + [ + torch.tensor(blt.chi_expansion) + for one_pose_blts in task.blts + for blt in one_pose_blts + ], + ).to(self.device) + chi_expansion_for_bbt = (chi_expansion_for_gbt[is_gbt_dun_allowed])[ + is_dun_allowed_bt_bbt + ] + # chi_expansion_for_bbt = chi_expansion_for_bbt - # ok, we'll go to the residue types and look at their protonation - # state expansions aand we'll put that information into the + # ok, we'll go to the block types and look at their protonation + # state expansions and we'll put that information into the # chi_expansions_for_buildable_restype tensor sampling_db = self.dun_param_resolver.sampling_db - 
nchi_for_buildable_restype = sampling_db.nchi_for_table_set[ - rottable_set_for_buildable_restype[:, 1].to(torch.int64) + n_chi_for_bbt = sampling_db.nchi_for_table_set[ + rottable_set_for_bbt[:, 1].to(torch.int64) ] - non_dunbrack_expansion_counts_for_buildable_restype = torch.zeros( - (n_brts, max_n_chi), dtype=torch.int32, device=self.device + non_dunbrack_expansion_counts_for_bbt = torch.zeros( + (n_bbts, max_n_chi), dtype=torch.int32, device=self.device ) - # max_chi_samples = 0 # TEMP! Treat everything as exposed (0) sc = pbt.dun_sampler_cache - ndecfbr = sc.non_dunbrack_sample_counts[block_type_ind_for_brt, 0] - non_dunbrack_expansion_counts_for_buildable_restype = ndecfbr + ndecfbbt = sc.non_dunbrack_sample_counts[block_type_ind_for_bbt, 0] + non_dunbrack_expansion_counts_for_bbt = ndecfbbt # TEMP! Treat everything as exposed (0) - non_dunbrack_expansion_for_buildable_restype = sc.non_dunbrack_samples[ - block_type_ind_for_brt, 0 + non_dunbrack_expansion_for_bbt = sc.non_dunbrack_samples[ + block_type_ind_for_bbt, 0 ] # treat all residues as if they are exposed - prob_cumsum_limit_for_buildable_restype = torch.full( - (n_brts,), 0.95, dtype=torch.float32, device=self.device + prob_cumsum_limit_for_bbt = torch.full( + (n_bbts,), 0.95, dtype=torch.float32, device=self.device ) + # the sampled chi returned are a tuple containing info for BBTs: + # these have to be mapped back to info for GBTs, which is handled + # in the next step sampled_chi = self.launch_rotamer_building( pose_stack.coords.reshape(-1, 3), - ndihe_for_res, - dihedral_offset_for_res, + n_dihe_for_block, + dihedral_offset_for_block, dihedral_atom_inds, - rottable_set_for_buildable_restype, - chi_expansion_for_buildable_restype, - non_dunbrack_expansion_for_buildable_restype, - non_dunbrack_expansion_counts_for_buildable_restype, - prob_cumsum_limit_for_buildable_restype, - nchi_for_buildable_restype, + rottable_set_for_bbt, + chi_expansion_for_bbt, + non_dunbrack_expansion_for_bbt, + 
non_dunbrack_expansion_counts_for_bbt, + prob_cumsum_limit_for_bbt, + n_chi_for_bbt, ) return self.package_samples_for_output( pbt, task, - block_type_ind_for_brt, + n_gbt_total, + bbt_to_gbt_torch, + block_type_ind_for_bbt, max_n_chi, - nonzero_dunrot_inds_for_rts, sampled_chi, ) @@ -611,33 +650,52 @@ def package_samples_for_output( self, pbt: PackedBlockTypes, task: PackerTask, + n_gbt_total: int, + bbt_to_gbt: Tensor[torch.int64][:], block_type_ind_for_brt: Tensor[torch.int64][:], max_n_chi: int, - nonzero_dunrot_inds_for_rts: Tensor[torch.int64][:, :], sampled_chi, ): - restype_is_allowed_for_dun = torch.tensor( - [ - True if self in rlt.chi_samplers else False - for one_pose_rlts in task.rlts - for rlt in one_pose_rlts - for rt in rlt.allowed_restypes - ], - dtype=torch.uint8, - device=self.device, - ) - n_restypes_total = restype_is_allowed_for_dun.shape[0] - dun_allowed_inds = torch.nonzero(restype_is_allowed_for_dun)[:, 0] - dun_brt_global_inds = dun_allowed_inds[nonzero_dunrot_inds_for_rts[:, 0]].to( - self.device - ) + n_bbt = bbt_to_gbt.shape[0] + # print("n_bbt", n_bbt) + # print("sampled_chi[0].shape[0]", sampled_chi[0].shape[0]) + # print("sampled_chi[1].shape[0]", sampled_chi[1].shape[0]) + assert sampled_chi[0].shape[0] == n_bbt + assert sampled_chi[1].shape[0] == n_bbt + + # restype_is_allowed_for_dun = torch.tensor( + # [ + # ( + # True + # if self in blt.conformer_samplers and blt.block_type_allowed[i] + # else False + # ) + # for one_pose_blts in task.blts + # for blt in one_pose_blts + # for i, bt in enumerate(blt.considered_block_types) + # ], + # dtype=torch.uint8, + # device=self.device, + # ) + # n_restypes_total = restype_is_allowed_for_dun.shape[0] + # dun_allowed_inds = torch.nonzero(restype_is_allowed_for_dun)[:, 0] + + # temp! 
+ # torch.set_printoptions(threshold=10000) + # print("nonzero_dunrot_inds_for_rts[:, 0]") + # print(nonzero_dunrot_inds_for_rts[:, 0]) + # dun_brt_global_inds = dun_allowed_inds[nonzero_dunrot_inds_for_rts[:, 0]].to( + # self.device + # ) + # print("dun_brt_global_inds") + # print(dun_brt_global_inds) n_rots_for_rt = torch.zeros( - (n_restypes_total,), dtype=torch.int32, device=self.device + (n_gbt_total,), dtype=torch.int32, device=self.device ) - n_rots_for_rt[dun_brt_global_inds] = sampled_chi[0] + n_rots_for_rt[bbt_to_gbt] = sampled_chi[0] n_rots_for_rt_offsets = torch.zeros_like(n_rots_for_rt) - n_rots_for_rt_offsets[dun_brt_global_inds] = sampled_chi[1] + n_rots_for_rt_offsets[bbt_to_gbt] = sampled_chi[1] # n_rots_for_brt = sampled_chi[0] # n_rots_for_brt_offsets = sampled_chi[1] @@ -647,17 +705,24 @@ def package_samples_for_output( # Now lets map back to the original set of rts per block type. # lots of reindxing below # max_n_rts = max( - # len(rts.allowed_restypes) - # for one_pose_rlts in task.rlts - # for rts in one_pose_rlts + # len(rts.allowed_blocktypes) + # for one_pose_blts in task.blts + # for rts in one_pose_blts # ) - rt_global_index = torch.arange( - n_restypes_total, dtype=torch.int32, device=self.device - ) - global_rt_ind_for_brt = rt_global_index[nonzero_dunrot_inds_for_rts.squeeze()] + # rt_global_index = torch.arange( + # n_restypes_total, dtype=torch.int32, device=self.device + # ) + # global_rt_ind_for_brt = rt_global_index[nonzero_dunrot_inds_for_rts.squeeze()] - rt_for_rotamer = global_rt_ind_for_brt[brt_for_rotamer.to(torch.int64)] + # print("brt_for_rotamer") + # print(brt_for_rotamer) + # rt_for_rotamer = global_rt_ind_for_brt[brt_for_rotamer.to(torch.int64)] + gbt_for_rotamer = bbt_to_gbt[brt_for_rotamer].to(torch.int32) + # print("gbt_for_rotamer") + # print(gbt_for_rotamer) + # print("global_rt_ind_for_brt") + # print(global_rt_ind_for_brt) pbt_cda = pbt.dun_sampler_cache.chi_defining_atom chi_defining_atom_for_rotamer = 
torch.full( @@ -683,10 +748,9 @@ def package_samples_for_output( chi_defining_atom_for_rotamer[:, :max_n_chi] = pbt_cda[ block_type_ind_for_brt[brt_for_rotamer.to(torch.int64)], :max_n_chi ] - return ( n_rots_for_rt, - rt_for_rotamer, + gbt_for_rotamer, chi_defining_atom_for_rotamer, chi_for_rotamers, ) diff --git a/tmol/pack/rotamer/fixed_aa_chi_sampler.py b/tmol/pack/rotamer/fixed_aa_chi_sampler.py index 0c0a41159..68fb297a8 100644 --- a/tmol/pack/rotamer/fixed_aa_chi_sampler.py +++ b/tmol/pack/rotamer/fixed_aa_chi_sampler.py @@ -52,32 +52,48 @@ def sample_chi_for_poses( ]: all_restypes = numpy.array( [ - rt - for one_pose_rlts in task.rlts - for rlt in one_pose_rlts - for rt in rlt.allowed_restypes - if self in rlt.chi_samplers + bt + for one_pose_blts in task.blts + for blt in one_pose_blts + for i, bt in enumerate(blt.considered_block_types) + if self in blt.conformer_samplers ], dtype=object, ) + restype_allowed = torch.tensor( + [ + (self in blt.conformer_samplers and bool(blt.block_type_allowed[i])) + for one_pose_blts in task.blts + for blt in one_pose_blts + for i, bt in enumerate(blt.considered_block_types) + ], + dtype=bool, + device=poses.device, + ) rt_base_names = numpy.array([rt.base_name for rt in all_restypes], dtype=object) n_rots_for_rt = torch.zeros( len(all_restypes), dtype=torch.int32, device=poses.device ) - is_ala_rt = torch.tensor( - (rt_base_names == "ALA"), - dtype=torch.bool, - device=poses.device, + is_allowed_ala_rt = torch.logical_and( + torch.tensor( + (rt_base_names == "ALA"), + dtype=torch.bool, + device=poses.device, + ), + restype_allowed, ) - is_gly_rt = torch.tensor( - (rt_base_names == "GLY"), - dtype=torch.bool, - device=poses.device, + is_allowed_gly_rt = torch.logical_and( + torch.tensor( + (rt_base_names == "GLY"), + dtype=torch.bool, + device=poses.device, + ), + restype_allowed, ) - n_rots_for_rt[is_ala_rt] += 1 - n_rots_for_rt[is_gly_rt] += 1 - either_ala_or_gly = torch.logical_or(is_ala_rt, is_gly_rt) + 
@attr.s(auto_attribs=True, frozen=True)
class IncludeCurrentSampler(ConformerSampler):
    """Sampler that contributes the conformation already present in the pose.

    It emits at most one rotamer per considered block type, and only for
    the block type that is the block's original one.
    """

    @classmethod
    def sampler_name(cls):
        """Return the fixed name ``"IncludeCurrentSampler"``."""
        return "IncludeCurrentSampler"

    @validate_args
    def annotate_residue_type(self, rt: RefinedResidueType):
        """No per-residue-type annotation is needed by this sampler."""

    @validate_args
    def annotate_packed_block_types(self, packed_block_types: PackedBlockTypes):
        """No packed-block-types annotation is needed by this sampler."""

    @validate_args
    def defines_rotamers_for_rt(self, rt: RefinedResidueType):
        """Always ``True``: any residue type can keep its current conformer."""
        return True

    @validate_args
    def first_sc_atoms_for_rt(self, rt: RefinedResidueType) -> Tuple[str, ...]:
        """Return a one-element tuple holding ``rt.default_jump_connection_atom``."""
        jump_atom = rt.default_jump_connection_atom
        return (jump_atom,)

    def create_samples_for_poses(
        self,
        pose_stack: PoseStack,
        task: "PackerTask",  # noqa: F821
    ) -> Tuple[  # noqa: F821
        Tensor[torch.int32][:],  # n_rots_for_gbt
        Tensor[torch.int32][:],  # gbt_for_rotamer
        dict,  # anything else the sampler wants to save for later
    ]:
        """Count one rotamer for each block's original block type when kept.

        Walking ``task.blts`` pose by pose, block by block, and then over
        each block's ``considered_block_types``, an entry gets a count of 1
        when it is the same object as ``blt.original_block_type`` and either
        ``blt.include_current`` is truthy or no entry of
        ``blt.block_type_allowed`` is truthy; every other entry gets 0.

        Returns ``(n_rots_for_gbt, gbt_for_rotamer, {})`` where
        ``gbt_for_rotamer`` holds the positions of the non-zero counts.
        """
        counts = []
        for one_pose_blts in task.blts:
            for blt in one_pose_blts:
                for bt in blt.considered_block_types:
                    keep_current = bt is blt.original_block_type and (
                        blt.include_current or not numpy.any(blt.block_type_allowed)
                    )
                    counts.append(1 if keep_current else 0)

        n_rots_for_gbt = torch.tensor(
            counts, dtype=torch.int32, device=pose_stack.device
        )
        # NOTE(review): torch.nonzero yields int64 indices, whereas the
        # return annotation advertises int32 -- confirm what callers expect.
        (gbt_for_rotamer,) = torch.nonzero(n_rots_for_gbt, as_tuple=True)
        return (n_rots_for_gbt, gbt_for_rotamer, {})

    def fill_dofs_for_samples(
        self,
        pose_stack: PoseStack,
        task: "PackerTask",  # noqa: F821
        orig_kinforest: KinForest,
        orig_dofs_kto: Tensor[torch.float32][:, 9],
        gbt_for_conformer: Tensor[torch.int64][:],
        block_type_ind_for_conformer: Tensor[torch.int64][:],
        n_dof_atoms_offset_for_conformer: Tensor[torch.int64][:],
        # which of all conformers are built by this sampler
        conformer_built_by_sampler: Tensor[torch.bool][:],
        # mapping orig conformer samples to merged conformer samples for this sampler
        conf_inds_for_sampler: Tensor[torch.int64][:],
        sampler_n_rots_for_gbt: Tensor[torch.int32][:],
        sampler_gbt_for_rotamer: Tensor[torch.int32][:],
        sample_dict: dict,
        conf_dofs_kto: Tensor[torch.float32][:, 9],
    ):
        """Copy the original DOF rows into this sampler's conformer rows.

        Does nothing when the sampler produced no rotamers.  Otherwise the
        destination and source atom indices are computed by the module-level
        helper, and the rows of ``orig_dofs_kto`` at ``src + 1`` are copied
        into ``conf_dofs_kto`` at ``dst + 1``, in place.
        """
        if sampler_gbt_for_rotamer.shape[0] == 0:
            return

        # NOTE(review): the device-wide synchronizations around the copy
        # look like debugging aids; kept as-is -- confirm they are needed.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        index_pair = (
            create_full_dof_inds_to_copy_from_orig_to_rotamers_for_include_current_sampler(
                pose_stack,
                task,
                gbt_for_conformer,
                block_type_ind_for_conformer,
                conf_inds_for_sampler,
                sampler_n_rots_for_gbt,
                sampler_gbt_for_rotamer,
                n_dof_atoms_offset_for_conformer,
            )
        )
        dst, src = index_pair

        conf_dofs_kto[dst + 1, :] = orig_dofs_kto[src + 1, :]

        if torch.cuda.is_available():
            torch.cuda.synchronize()
def create_full_dof_inds_to_copy_from_orig_to_rotamers_for_include_current_sampler(
    poses: PoseStack,
    task: "PackerTask",  # noqa: F821
    gbt_for_rot: Tensor[torch.int64][:],  # max-n-rots
    block_type_ind_for_rot: Tensor[torch.int64][:],
    conf_inds_for_sampler: Tensor[torch.int64][:],
    sampler_n_rots_for_gbt: Tensor[torch.int32][:],
    sampler_gbt_for_rotamer: Tensor[torch.int32][:],
    n_dof_atoms_offset_for_rot: Tensor[torch.int64][:],
) -> Tuple[Tensor[torch.int64][:], Tensor[torch.int64][:]]:
    """Build paired atom indices for copying original DOFs into rotamers.

    We want to copy from the orig_dofs tensor into the rot_dofs tensor, for
    every real atom of every rotamer that this sampler is building, from
    the residue the rotamer came from.  This requires a good deal of
    reindexing.

    Returns ``(rot_atom_inds, orig_atom_inds)``: two 1-D tensors of equal
    length with one entry per real atom (atom index below the rotamer block
    type's ``pbt.n_atoms``).  ``rot_atom_inds`` is the rotamer's entry of
    ``n_dof_atoms_offset_for_rot`` plus the atom index;
    ``orig_atom_inds`` is the exclusive cumulative atom count of the real
    blocks preceding the source block plus the atom index.  Neither
    includes a shift for a root atom.

    ``sampler_n_rots_for_gbt`` is accepted but not read;
    ``sampler_gbt_for_rotamer`` is used only for its length, which has to
    equal the length of ``conf_inds_for_sampler``.
    """
    pbt = poses.packed_block_types
    device = poses.device
    n_rots_for_sampler = sampler_gbt_for_rotamer.shape[0]

    # Atom offset of every real (non -1) block, in flattened pose order.
    # TO DO: pass this in as an input parameter as each Sampler needs it
    block_is_real = poses.block_type_ind != -1
    orig_block_type_ind = poses.block_type_ind[block_is_real].view(-1).to(torch.int64)
    orig_dof_atom_offset = exclusive_cumsum1d(pbt.n_atoms[orig_block_type_ind]).to(
        torch.int64
    )

    # Map "pose * max_n_blocks + block" to the index among real blocks;
    # -1 for padding blocks.
    # TO DO: make this an argument and pass it in
    poses_res_to_real_poses_res = torch.full(
        (poses.block_type_ind.shape[0] * poses.block_type_ind.shape[1],),
        -1,
        dtype=torch.int64,
        device=device,
    )
    poses_res_to_real_poses_res[poses.block_type_ind.view(-1) != -1] = torch.arange(
        orig_block_type_ind.shape[0], dtype=torch.int64, device=device
    )

    # Flattened residue index for every considered block type.
    # NOTE(review): max_n_blocks is read from block_coord_offset while the
    # table above is sized from block_type_ind; presumably the same value.
    max_n_blocks = poses.block_coord_offset.shape[1]
    res_ind_for_gbt = torch.tensor(
        [
            i * max_n_blocks + j
            for i, one_pose_blts in enumerate(task.blts)
            for j, blt in enumerate(one_pose_blts)
            for _ in blt.considered_block_types
        ],
        dtype=torch.int64,
        device=device,
    )

    # Restrict to the rotamers built by this sampler.
    gbt_for_samplers_rots = gbt_for_rot[conf_inds_for_sampler]
    res_ind_for_samplers_rots = res_ind_for_gbt[gbt_for_samplers_rots]
    block_type_ind_for_samplers_rots = block_type_ind_for_rot[conf_inds_for_sampler]

    # Number of atoms for each rotamer / original residue.
    orig_res_n_atoms = pbt.n_atoms[block_type_ind_for_samplers_rots]

    # Note which (rotamer, atom-slot) pairs are real atoms.
    dummy_rotamer_atom_inds = (
        torch.arange(pbt.max_n_atoms, dtype=torch.int64, device=device)
        .view(1, pbt.max_n_atoms)
        .expand(n_rots_for_sampler, -1)
    )
    atom_is_real_for_rot = dummy_rotamer_atom_inds < orig_res_n_atoms.unsqueeze(
        1
    ).expand(n_rots_for_sampler, pbt.max_n_atoms)

    orig_atom_inds = (
        orig_dof_atom_offset[poses_res_to_real_poses_res[res_ind_for_samplers_rots]]
        .unsqueeze(1)
        .expand(-1, pbt.max_n_atoms)
        + dummy_rotamer_atom_inds
    )[atom_is_real_for_rot]

    rot_atom_inds = (
        n_dof_atoms_offset_for_rot[conf_inds_for_sampler]
        .unsqueeze(1)
        .expand(-1, pbt.max_n_atoms)
        + dummy_rotamer_atom_inds
    )[atom_is_real_for_rot]
    return rot_atom_inds, orig_atom_inds
print("frame_y\n", kinforest.frame_y) + # print("frame_z\n", kinforest.frame_z) + else: # print("bonds") # print(restype.bond_indices.shape) diff --git a/tmol/pack/simulated_annealing.py b/tmol/pack/simulated_annealing.py index 1e7f8bee9..44013f5a0 100644 --- a/tmol/pack/simulated_annealing.py +++ b/tmol/pack/simulated_annealing.py @@ -8,15 +8,16 @@ def run_simulated_annealing( energy_tables: PackerEnergyTables, ): return pack_anneal( + energy_tables.max_n_rotamers_per_pose, + energy_tables.pose_n_res, + energy_tables.pose_n_rotamers, + energy_tables.pose_rotamer_offset, energy_tables.nrotamers_for_res, energy_tables.oneb_offsets, energy_tables.res_for_rot, - energy_tables.respair_nenergies, energy_tables.chunk_size, energy_tables.chunk_offset_offsets, - energy_tables.twob_offsets, - energy_tables.fine_chunk_offsets, + energy_tables.chunk_offsets, energy_tables.energy1b, energy_tables.energy2b, - 0, ) diff --git a/tmol/score/common/scoring_module.py b/tmol/score/common/scoring_module.py index d4ca0eff5..78666ff1a 100644 --- a/tmol/score/common/scoring_module.py +++ b/tmol/score/common/scoring_module.py @@ -4,13 +4,13 @@ from tmol.score.common.convert_float64 import convert_float64 -class ScoringModule(torch.nn.Module): +class TermScoringModule(torch.nn.Module): def __init__( self, term_parameters, term_score_poses, ): - super(ScoringModule, self).__init__() + super(TermScoringModule, self).__init__() self.term_parameters = [] @@ -39,14 +39,14 @@ def _p(t): table += [_p(param) if type(param) is torch.Tensor else param] -class PoseScoringModule(ScoringModule): +class TermPoseScoringModule(TermScoringModule): def __init__( self, pose_stack, term_parameters, term_score_poses, ): - super(PoseScoringModule, self).__init__(term_parameters, term_score_poses) + super(TermPoseScoringModule, self).__init__(term_parameters, term_score_poses) self.common_parameters = [] @@ -56,7 +56,7 @@ def __init__( pose_stack.rot_coord_offset, pose_stack.pose_ind_for_atom, 
pose_stack.first_rot_for_block, - pose_stack.first_rot_for_block, + pose_stack.block_type_ind, # block_type for first rot for block pose_stack.block_ind_for_rot, pose_stack.pose_ind_for_rot, pose_stack.block_type_ind_for_rot, @@ -69,7 +69,7 @@ def __init__( ) -class WholePoseScoringModule(PoseScoringModule): +class TermWholePoseScoringModule(TermPoseScoringModule): def forward( self, coords, @@ -79,7 +79,7 @@ def forward( return scores -class BlockPairScoringModule(PoseScoringModule): +class TermBlockPairScoringModule(TermPoseScoringModule): def forward( self, coords, @@ -96,32 +96,38 @@ def forward( return sparse_result -class RotamerScoringModule(ScoringModule): +class TermRotamerScoringModule(TermScoringModule): def __init__( self, rotamer_set, term_parameters, term_score_poses, ): - super(ScoringModule, self).__init__(term_parameters, term_score_poses) + super(TermRotamerScoringModule, self).__init__( + term_parameters, term_score_poses + ) self.common_parameters = [] + def _i32(x): + return x if isinstance(x, int) else x.to(torch.int32) + self.add_parameters( self.common_parameters, [ - i.to(torch.int32) - for i in [ - rotamer_set.rot_coord_offset, - rotamer_set.first_rot_for_block, - rotamer_set.first_rot_for_block, + _i32(t) + for t in [ + rotamer_set.coord_offset_for_rot, # rot coord offset + rotamer_set.pose_ind_for_atom, # pose_ind_for_atom?? unused + rotamer_set.rot_offset_for_block, # first rot for block + rotamer_set.first_rot_block_type, # first rot block type rotamer_set.block_ind_for_rot, - rotamer_set.pose_ind_for_rot, + rotamer_set.pose_for_rot, rotamer_set.block_type_ind_for_rot, rotamer_set.n_rots_for_pose, rotamer_set.rot_offset_for_pose, rotamer_set.n_rots_for_block, - rotamer_set.rot_offset_for_block, + rotamer_set.rot_offset_for_block, # three times?! 
rotamer_set.max_n_rots_per_pose, ] ], @@ -133,4 +139,10 @@ def forward( ): scores, indices = self.term_score_poses(*self.format_arguments(coords, True)) - return scores + sparse_result = torch.stack( + [ + torch.sparse_coo_tensor(indices, scores[subterm, :]) + for subterm in range(scores.size(0)) + ] + ) + return sparse_result diff --git a/tmol/score/common/sphere_overlap.impl.hh b/tmol/score/common/sphere_overlap.impl.hh index ca5001e47..4093f0db1 100644 --- a/tmol/score/common/sphere_overlap.impl.hh +++ b/tmol/score/common/sphere_overlap.impl.hh @@ -231,8 +231,8 @@ struct detect_rot_neighbors { int const rot_ind1 = rot_pair_ind / max_n_rots; int const rot_ind2 = rot_pair_ind % max_n_rots; - if (rot_ind1 > n_rots_for_pose[pose_ind] - || rot_ind2 > n_rots_for_pose[pose_ind]) { + if (rot_ind1 >= n_rots_for_pose[pose_ind] + || rot_ind2 >= n_rots_for_pose[pose_ind]) { return; } diff --git a/tmol/score/elec/elec_energy_term.py b/tmol/score/elec/elec_energy_term.py index e4468d2bb..c495fd98f 100644 --- a/tmol/score/elec/elec_energy_term.py +++ b/tmol/score/elec/elec_energy_term.py @@ -6,6 +6,7 @@ from tmol.database import ParameterDatabase from tmol.score.elec.params import ElecParamResolver, ElecGlobalParams from tmol.score.elec.elec_whole_pose_module import ElecWholePoseScoringModule +from tmol.score.elec.potentials.compiled import elec_pose_scores from tmol.chemical.restypes import RefinedResidueType from tmol.pose.packed_block_types import PackedBlockTypes @@ -121,18 +122,34 @@ def _tf(arr): def setup_poses(self, poses: PoseStack): super(ElecEnergyTerm, self).setup_poses(poses) - def render_whole_pose_scoring_module(self, pose_stack: PoseStack): - pbt = pose_stack.packed_block_types - return ElecWholePoseScoringModule( - pose_stack_block_coord_offset=pose_stack.block_coord_offset, - pose_stack_block_types=pose_stack.block_type_ind, - pose_stack_min_block_bondsep=pose_stack.min_block_bondsep, - pose_stack_inter_block_bondsep=pose_stack.inter_block_bondsep, - 
def get_score_term_function(self):
    """Return ``elec_pose_scores``, the scoring function for this term."""
    return elec_pose_scores


def get_score_term_attributes(self, pose_stack):
    """Collect the arguments passed to the score function for ``pose_stack``.

    Returns a list of nine entries, in this order: the pose stack's
    ``min_block_bondsep`` and ``inter_block_bondsep``; the packed block
    types' ``n_atoms``, ``elec_partial_charge``, ``n_conn``, ``conn_atom``,
    ``elec_inter_repr_path_distance`` and ``elec_intra_repr_path_distance``;
    and a float32 tensor of shape ``(1, 5)`` on ``pose_stack.device``
    holding the global parameters D, D0, S, min_dis and max_dis.
    """
    gp = self.global_params
    global_params = torch.tensor(
        [
            gp.elec_sigmoidal_die_D,
            gp.elec_sigmoidal_die_D0,
            gp.elec_sigmoidal_die_S,
            gp.elec_min_dis,
            gp.elec_max_dis,
        ],
        dtype=torch.float32,
        device=pose_stack.device,
    )[None, :]  # leading axis of size one

    pbt = pose_stack.packed_block_types
    return [
        pose_stack.min_block_bondsep,
        pose_stack.inter_block_bondsep,
        pbt.n_atoms,
        pbt.elec_partial_charge,
        pbt.n_conn,
        pbt.conn_atom,
        pbt.elec_inter_repr_path_distance,
        pbt.elec_intra_repr_path_distance,
        global_params,
    ]