Skip to content

Commit 7fdb313

Browse files
authored
Implement tensornet warp kernels (#384)
* implement tensornet warp ops. Implements tensornet warp kernels copied from materialyzeai/matgl#709, as originally implemented by @zubatyuk. To work with the warp kernels, the tensornet code has been refactored to use shapes [N,3,3,F] instead of the original [N,F,3,3]. This change required reshaping of weights from models trained by previous code. Older checkpoints are currently auto-detected using the presence of the `check_errors` flag, which was removed in a recent commit. The loading method can also be set with a new compatibility_load=True|False flag. If the warp kernels fail to load, the pure torch functions will be used. These have been refactored to match the call signatures and shapes of the warp kernels. The speedup of the warp kernels is approximately 3x for inference and training. * pass test_model
1 parent a0aa111 commit 7fdb313

31 files changed

+5303
-224
lines changed

README.md

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,43 @@ Run `torchmd-train --help` to see all available options and their descriptions.
6565

6666
### AceFF
6767
Trained [AceFF models](https://huggingface.co/collections/Acellera/aceff-machine-learning-potentials) can be loaded and used for inference.
68-
please see [here](https://github.com/torchmd/torchmd-net/tree/main/examples/aceff_examples)
68+
Please see [here](https://github.com/torchmd/torchmd-net/tree/main/examples/aceff_examples) for example scripts.
69+
70+
#### Loading AceFF models with `load_model`
71+
72+
```python
73+
from huggingface_hub import hf_hub_download
74+
from torchmdnet.models.model import load_model
75+
76+
model_file_path = hf_hub_download(repo_id="Acellera/AceFF-1.1", filename="aceff_v1.1.ckpt")
77+
model = load_model(model_file_path, derivative=True)
78+
```
79+
80+
#### Loading AceFF models with the ASE calculator
81+
82+
```python
83+
from huggingface_hub import hf_hub_download
84+
from torchmdnet.calculators import TMDNETCalculator
85+
86+
model_file_path = hf_hub_download(repo_id="Acellera/AceFF-1.1", filename="aceff_v1.1.ckpt")
87+
calc = TMDNETCalculator(model_file_path, device="cuda")
88+
```
89+
90+
#### `compatibility_load` flag
91+
92+
TensorNet and TensorNet2 checkpoints trained with older versions of the code used a different
93+
internal tensor layout (`[N, F, 3, 3]` instead of the current `[N, 3, 3, F]`). When loading
94+
such a checkpoint, the affected weight matrices must be remapped before the state dict can be
95+
applied.
96+
97+
**This is handled automatically.** Old-format checkpoints always contain a `check_errors`
98+
key in their saved hyper-parameters (a parameter that was removed in newer code); `load_model`
99+
detects this and applies the remapping transparently, emitting a `UserWarning` to let you know.
100+
All currently released AceFF checkpoints (1.0, 1.1, 2.0) are old-format and are handled this way.
101+
102+
If you need to override the automatic detection you can pass `compatibility_load=True` (force
103+
remap) or `compatibility_load=False` (suppress remap) explicitly to either `load_model` or
104+
`TMDNETCalculator`.
69105

70106

71107
To load your own trained models see [here](https://github.com/torchmd/torchmd-net/tree/main/examples#loading-checkpoints) for instructions on how to load pretrained models.

benchmarks/inference.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
2525
config_file = join(
2626
dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml"
2727
)
28+
elif model_name == "tensornet2":
29+
config_file = join(
30+
dirname(dirname(__file__)), "examples", "TensorNet2-QM9.yaml"
31+
)
2832
else:
2933
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
3034
with open(config_file, "r") as f:
@@ -64,22 +68,24 @@ def benchmark_pdb(pdb_file, **kwargs):
6468
molecule = None
6569
torch.cuda.nvtx.range_push("Initialization")
6670
args = load_example_args(
67-
"tensornet",
68-
config_file="../examples/TensorNet-rMD17.yaml",
71+
kwargs["model"],
6972
remove_prior=True,
7073
output_model="Scalar",
7174
derivative=False,
7275
max_z=int(atomic_numbers.max() + 1),
73-
max_num_neighbors=32,
76+
max_num_neighbors=64,
7477
**kwargs,
7578
)
7679
model = create_model(args)
7780
z = atomic_numbers
7881
pos = positions
7982
batch = torch.zeros_like(z).to("cuda")
83+
model.representation_model.setup_for_inference(
84+
z.cpu(), batch.cpu()
85+
) # setup for inference
8086
model = model.to("cuda")
81-
torch.cuda.nvtx.range_pop()
82-
torch.cuda.nvtx.range_push("Warmup")
87+
# torch.cuda.nvtx.range_pop()
88+
# torch.cuda.nvtx.range_push("Warmup")
8389
for i in range(3):
8490
pred, _ = model(z, pos, batch)
8591
pred.sum().backward()
@@ -88,38 +94,27 @@ def benchmark_pdb(pdb_file, **kwargs):
8894
for i in range(10):
8995
pred, _ = model(z, pos, batch)
9096
pred.sum().backward()
91-
torch.cuda.nvtx.range_pop()
92-
torch.cuda.nvtx.range_push("Benchmark")
97+
# torch.cuda.nvtx.range_pop()
98+
# torch.cuda.nvtx.range_push("Benchmark")
9399
nbench = 100
94-
times = np.zeros(nbench)
95-
stream = torch.cuda.Stream()
96100
torch.cuda.synchronize()
97-
with GpuTimer() as timer:
98-
with torch.cuda.stream(stream):
99-
for i in range(nbench):
100-
# torch.cuda.synchronize()
101-
# with GpuTimer() as timer2:
102-
# torch.cuda.nvtx.range_push("Step")
103-
pred, _ = model(z, pos, batch)
104-
# torch.cuda.nvtx.range_push("derivative")
105-
pred.sum().backward()
106-
# torch.cuda.nvtx.range_pop()
107-
# torch.cuda.nvtx.range_pop()
108-
# torch.cuda.synchronize()
109-
# times[i] = timer2.interval
110-
torch.cuda.synchronize()
111-
# torch.cuda.nvtx.range_pop()
112-
return len(atomic_numbers), timer.interval / nbench
101+
t1 = time.perf_counter()
102+
for i in range(nbench):
103+
pred, _ = model(z, pos, batch)
104+
pred.sum().backward()
105+
torch.cuda.synchronize()
106+
t2 = time.perf_counter()
107+
return len(atomic_numbers), (t2 - t1) * 1000 / nbench
113108

114109

115110
from tabulate import tabulate
116111

117112
# List of cases to benchmark, arbitrary parameters can be overriden here
118113
cases = {
119-
"0L": {"num_layers": 0, "embedding_dimension": 128},
120-
"1L": {"num_layers": 1, "embedding_dimension": 128},
121-
"2L": {"num_layers": 2, "embedding_dimension": 128},
122-
"2L emb 64": {"num_layers": 2, "embedding_dimension": 64},
114+
"0L": {"model": "tensornet", "num_layers": 0, "embedding_dimension": 128},
115+
"1L": {"model": "tensornet", "num_layers": 1, "embedding_dimension": 128},
116+
"2L": {"model": "tensornet", "num_layers": 2, "embedding_dimension": 128},
117+
"2L emb 64": {"model": "tensornet", "num_layers": 2, "embedding_dimension": 64},
123118
}
124119

125120

@@ -134,8 +129,6 @@ def benchmark_all():
134129
for pdb_file in os.listdir("systems"):
135130
if not pdb_file.endswith(".pdb"):
136131
continue
137-
if pdb_file == "stmv.pdb": # Does not fit in a 4090
138-
continue
139132
times = {}
140133
num_atoms = 0
141134
for name, kwargs in cases.items():

examples/TensorNet-QM9.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ ngpus: -1
3636
num_epochs: 3000
3737
num_layers: 3
3838
num_nodes: 1
39-
num_rbf: 64
39+
num_rbf: 32
4040
num_workers: 6
4141
output_model: Scalar
4242
precision: 32
@@ -57,3 +57,5 @@ weight_decay: 0.0
5757
box_vecs: null
5858
charge: false
5959
spin: false
60+
static_shapes: false
61+
check_errors: true

examples/aceff_examples/ase_aceff.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
# We create the ASE calculator by supplying the path to the model and specifying the device and dtype
18-
calc = TMDNETCalculator(model_file_path, device="cuda")
18+
calc = TMDNETCalculator(model_file_path, device="cuda", max_num_neighbors=24)
1919
atoms = read("caffeine.pdb")
2020
print(atoms)
2121

@@ -77,6 +77,7 @@
7777
atoms.calc = calc
7878

7979
# Run more dynamics
80+
dyn.run(steps=10) # warmup before timing
8081
t1 = time.perf_counter()
8182
dyn.run(steps=nsteps)
8283
t2 = time.perf_counter()

examples/aceff_examples/ase_aceff_PBC.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515

1616
# We create the ASE calculator by supplying the path to the model and specifying the device and dtype
1717
# we provided a cutoff for the coulomb term so we can use PBCs
18-
calc = TMDNETCalculator(model_file_path, device="cuda", coulomb_cutoff=10.0)
18+
calc = TMDNETCalculator(
19+
model_file_path,
20+
device="cuda",
21+
coulomb_cutoff=10.0,
22+
)
1923
atoms = read("alanine-dipeptide-explicit.pdb")
2024

25+
2126
print(atoms)
2227

2328
atoms.calc = calc
@@ -39,7 +44,7 @@
3944

4045
# setup MD
4146
temperature_K: float = 300
42-
timestep: float = 1.0 * units.fs
47+
timestep: float = 0.5 * units.fs
4348
friction: float = 0.01 / units.fs
4449
traj_interval: int = 10
4550
log_interval: int = 10
@@ -54,5 +59,3 @@
5459
t1 = time.perf_counter()
5560
dyn.run(steps=nsteps)
5661
t2 = time.perf_counter()
57-
58-
print(f"Completed MD in {t2 - t1:.1f} s ({(t2 - t1)*1000 / nsteps:.3f} ms/step)")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
"triton; sys_platform == 'linux' and platform_machine != 'aarch64'",
2222
"triton-windows; sys_platform == 'win32'",
2323
"ase",
24+
"warp-lang>=1.10.1",
2425
"setuptools>=82.0.0",
2526
]
2627

tests/expected.pkl

2.2 KB
Binary file not shown.

tests/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):
3838
@mark.parametrize("precision", [32, 64])
3939
def test_forward_output_modules(model_name, output_model, precision):
4040
z, pos, batch = create_example_batch()
41+
pos = pos.to(dtype=dtype_mapping[precision])
4142
args = load_example_args(
4243
model_name, remove_prior=True, output_model=output_model, precision=precision
4344
)

tests/test_warp_ops.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Tests that warp-ops and pure-Python TensorNet paths produce identical results."""
2+
3+
import os
4+
import pytest
5+
import torch
6+
from torch.testing import assert_close
7+
from os.path import dirname, join
8+
9+
import torchmdnet.models.tensornet as _tn
10+
11+
CURR_DIR = dirname(__file__)
12+
CKPT = join(CURR_DIR, "example_tensornet.ckpt")
13+
CAFFEINE_PDB = join(CURR_DIR, "caffeine.pdb")
14+
15+
16+
# ---------------------------------------------------------------------------
17+
# Helpers
18+
# ---------------------------------------------------------------------------
19+
20+
21+
def _load_model(device):
22+
from torchmdnet.models.model import load_model
23+
24+
return load_model(CKPT, derivative=True).to(device)
25+
26+
27+
def _caffeine_tensors(device):
28+
from ase.io import read
29+
from ase.data import atomic_numbers
30+
31+
atoms = read(CAFFEINE_PDB)
32+
z = torch.tensor(
33+
[atomic_numbers[s] for s in atoms.get_chemical_symbols()], dtype=torch.long
34+
).to(device)
35+
pos = torch.tensor(atoms.get_positions(), dtype=torch.float32).to(device)
36+
return z, pos
37+
38+
39+
def _run(model, z, pos):
40+
energy, forces = model(z, pos)
41+
return energy.detach(), forces.detach()
42+
43+
44+
def _set_opt(model, value: bool):
45+
"""Set .opt on the TensorNet representation model and all its submodules."""
46+
rep = model.representation_model
47+
rep.opt = value
48+
rep.tensor_embedding.opt = value
49+
for layer in rep.layers:
50+
layer.opt = value
51+
52+
53+
def _patch_nonopt(monkeypatch, model):
54+
"""Switch model to pure-Python ops by patching module-level ops and .opt flags."""
55+
# Module-level ops are used as globals inside forward() bodies, so they
56+
# still need to be swapped even though branching is now done via self.opt.
57+
monkeypatch.setattr(_tn, "compose_tensor", _tn._compose_tensor)
58+
monkeypatch.setattr(_tn, "decompose_tensor", _tn._decompose_tensor)
59+
monkeypatch.setattr(_tn, "tensor_matmul_o3", _tn._tensor_matmul_o3)
60+
monkeypatch.setattr(_tn, "tensor_matmul_so3", _tn._tensor_matmul_so3)
61+
_set_opt(model, False)
62+
63+
64+
# ---------------------------------------------------------------------------
65+
# Tests
66+
# ---------------------------------------------------------------------------
67+
68+
69+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
70+
def test_warp_vs_python(device, monkeypatch):
71+
"""Warp-ops and pure-Python paths must produce identical energy and forces."""
72+
if device == "cuda" and not torch.cuda.is_available():
73+
pytest.skip("CUDA not available")
74+
if not _tn.OPT:
75+
pytest.skip("warp-ops not available")
76+
77+
model = _load_model(device)
78+
z, pos = _caffeine_tensors(device)
79+
80+
energy_opt, forces_opt = _run(model, z, pos)
81+
82+
_patch_nonopt(monkeypatch, model)
83+
energy_py, forces_py = _run(model, z, pos)
84+
85+
assert_close(energy_opt, energy_py, rtol=1e-4, atol=1e-4)
86+
assert_close(forces_opt, forces_py, rtol=1e-4, atol=1e-4)
87+
88+
89+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
90+
def test_nonopt_runs(device, monkeypatch):
91+
"""Pure-Python (opt=False) path must produce finite energy and forces."""
92+
if device == "cuda" and not torch.cuda.is_available():
93+
pytest.skip("CUDA not available")
94+
95+
model = _load_model(device)
96+
z, pos = _caffeine_tensors(device)
97+
_patch_nonopt(monkeypatch, model)
98+
99+
energy, forces = _run(model, z, pos)
100+
101+
assert torch.isfinite(energy).all(), "Energy contains non-finite values"
102+
assert torch.isfinite(forces).all(), "Forces contain non-finite values"
103+
assert forces.shape == pos.shape
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# 1. Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# 2. Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# 3. Neither the name of the copyright holder nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
"""Warp GPU kernels for TensorNet operations."""
29+
30+
from __future__ import annotations
31+
32+
import warp as wp
33+
34+
from .compose_tensor import generate_compose_tensor
35+
from .decompose_tensor import generate_decompose_tensor
36+
from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3
37+
from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3
38+
from .graph_transform import convert_to_sparse, count_row_col
39+
from .tensor_norm3 import generate_tensor_norm3
40+
from .tensornet_mp import generate_message_passing
41+
from .tensornet_radial_mp import generate_radial_message_passing
42+
from .utils import add_module, get_module, get_stream
43+
44+
wp.init()
45+
46+
47+
__all__ = [
48+
"add_module",
49+
"add_module",
50+
"convert_to_sparse",
51+
"convert_to_sparse",
52+
"count_row_col",
53+
"count_row_col",
54+
"generate_compose_tensor",
55+
"generate_decompose_tensor",
56+
"generate_message_passing",
57+
"generate_message_passing",
58+
"generate_radial_message_passing",
59+
"generate_radial_message_passing",
60+
"generate_tensor_matmul_o3_3x3",
61+
"generate_tensor_matmul_so3_3x3",
62+
"generate_tensor_norm3",
63+
"get_module",
64+
"get_stream",
65+
]

0 commit comments

Comments (0)