You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* implement tensornet warp ops
Implements tensornet warp kernels copied from materialyzeai/matgl#709 as originally implemented by @zubatyuk.
To work with the warp kernels the tensornet code has been refactored to use shapes [N,3,3,F] instead of the original [N,F,3,3].
This change required reshaping of weights from models trained by previous code. Older checkpoints are currently auto-detected using the presence of the check-errors flag which was removed in a recent commit. The loading method can also be set with a new compatibility_load=True|False flag.
If the warp kernels fail to load the pure torch functions will be used. These have been refactored to match the call signatures and shapes of the warp kernels.
The speedup of the warp kernels is approximately 3x for inference and training
* pass test_model
TensorNet and TensorNet2 checkpoints trained with older versions of the code used a different
93
+
internal tensor layout (`[N, F, 3, 3]` instead of the current `[N, 3, 3, F]`). When loading
94
+
such a checkpoint, the affected weight matrices must be remapped before the state dict can be
95
+
applied.
96
+
97
+
**This is handled automatically.** Old-format checkpoints always contain a `check_errors`
98
+
key in their saved hyper-parameters (a parameter that was removed in newer code); `load_model`
99
+
detects this and applies the remapping transparently, emitting a `UserWarning` to let you know.
100
+
All currently released AceFF checkpoints (1.0, 1.1, 2.0) are old-format and are handled this way.
101
+
102
+
If you need to override the automatic detection you can pass `compatibility_load=True` (force
103
+
remap) or `compatibility_load=False` (suppress remap) explicitly to either `load_model` or
104
+
`TMDNETCalculator`.
69
105
70
106
71
107
To load your own trained models see [here](https://github.com/torchmd/torchmd-net/tree/main/examples#loading-checkpoints) for instructions on how to load pretrained models.
0 commit comments