Skip to content

Commit ad6f494

Browse files
authored
Rename ParallelMLIPPredictUnit (#1582)
* Update README.md * rename ParallelMLIPPredictUnit * update
1 parent d8c69a7 commit ad6f494

8 files changed

Lines changed: 11 additions & 11 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ If you have multiple gpus (or multiple nodes), we handle all the parallelism for
187187
pip install fairchem-core[extras]
188188
```
189189

190-
```
190+
```python
191191
from ase import units
192192
from ase.md.langevin import Langevin
193193
from fairchem.core import pretrained_mlip, FAIRChemCalculator

docs/core/common_tasks/ase_calculator.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ predictor = pretrained_mlip.get_predict_unit(
106106

107107
UMA supports Graph Parallel inference natively. The graph is chunked into each rank and both the forward and backwards communication is handled by the built-in graph parallel algorithm with torch distributed. Because Multi-GPU inference requires special setup of communication protocols within a node and across nodes, we leverage [ray](https://www.ray.io/) to launch Ray Actors for each GPU-rank under the hood. This allows us to seemlessly scale to any infrastructure that can run Ray.
108108

109-
To make things simple for the user that wants to run multi-gpu inference locally, we provide a drop-in replacement for MLIPPredictUnit, called [ParallelMLIPPredictUnitRay](https://github.com/facebookresearch/fairchem/blob/85bd83535fedbc1d99eee4c12e175603ccc44ef7/src/fairchem/core/units/mlip_unit/predict.py#L415)
109+
To make things simple for the user that wants to run multi-gpu inference locally, we provide a drop-in replacement for MLIPPredictUnit, called [ParallelMLIPPredictUnit](https://github.com/facebookresearch/fairchem/blob/85bd83535fedbc1d99eee4c12e175603ccc44ef7/src/fairchem/core/units/mlip_unit/predict.py#L415)
110110

111111
To enable this you need to install Ray manually or through the fairchem extra dependencies option
112112

docs/core/common_tasks/lammps.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ lmp_fc lmp_in="lammps_in_example.file" task_name="omol"
4747
```
4848

4949
## Multi-GPU parallelism
50-
Our LAMMPs integration is fully compatible out the box with our Multi-GPU inference API. The only change required is to pass it the ParallelMLIPPredictUnitRay [here](https://github.com/facebookresearch/fairchem/blob/main/src/fairchem/lammps/lammps_fc_config.yaml#L20) instead of the regular predict unit when initializing the lammps fairchem script. No need to install anything new such as Kokkos or add communication code.
50+
Our LAMMPs integration is fully compatible out the box with our Multi-GPU inference API. The only change required is to pass it the ParallelMLIPPredictUnit [here](https://github.com/facebookresearch/fairchem/blob/main/src/fairchem/lammps/lammps_fc_config.yaml#L20) instead of the regular predict unit when initializing the lammps fairchem script. No need to install anything new such as Kokkos or add communication code.
5151

5252
For example:
5353
```

src/fairchem/core/calculate/pretrained_mlip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def get_predict_unit(
8585
device: Optional torch device to load the model onto. If None, uses the default device.
8686
cache_dir: Path to folder where model files will be stored. Default is "~/.cache/fairchem"
8787
workers: Number of parallel workers for prediction unit. Default is 1. If greater than 1,
88-
we will instantiate a ParallelMLIPPredictUnitRay instead of the normal predict unit.
88+
we will instantiate a ParallelMLIPPredictUnit instead of the normal predict unit.
8989
9090
Returns:
9191
An initialized MLIPPredictUnit ready for making predictions.

src/fairchem/core/units/mlip_unit/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def load_predict_unit(
4141
device: Optional torch device to load the model onto.
4242
atom_refs: Optional dictionary of isolated atom reference energies.
4343
workers: Number of parallel workers for prediction unit. Default is 1. If greater than 1,
44-
we will instantiate a ParallelMLIPPredictUnitRay instead of the normal predict unit.
44+
we will instantiate a ParallelMLIPPredictUnit instead of the normal predict unit.
4545
4646
Returns:
4747
A MLIPPredictUnit instance ready for inference
@@ -54,9 +54,9 @@ def load_predict_unit(
5454
inference_settings = guess_inference_settings(inference_settings)
5555
overrides = overrides or {"backbone": {"always_use_pbc": False}}
5656
if workers > 1:
57-
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnitRay
57+
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnit
5858

59-
return ParallelMLIPPredictUnitRay(
59+
return ParallelMLIPPredictUnit(
6060
path,
6161
device=device,
6262
inference_settings=inference_settings,

src/fairchem/core/units/mlip_unit/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def predict(self, data: AtomicData) -> dict[str, torch.tensor] | None:
412412

413413

414414
@requires(ray_installed, message="Requires `ray` to be installed")
415-
class ParallelMLIPPredictUnitRay(MLIPPredictUnitProtocol):
415+
class ParallelMLIPPredictUnit(MLIPPredictUnitProtocol):
416416
def __init__(
417417
self,
418418
inference_model_path: str,

src/fairchem/lammps/lammps_fc_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ local_predict_unit:
1818

1919
# Use parallel predict unit for multi-gpu
2020
parallel_predict_unit:
21-
_target_: fairchem.core.units.mlip_unit.predict.ParallelMLIPPredictUnitRay
21+
_target_: fairchem.core.units.mlip_unit.predict.ParallelMLIPPredictUnit
2222
inference_model_path:
2323
_target_: fairchem.core.calculate.pretrained_mlip.pretrained_checkpoint_path_from_name
2424
model_name: "uma-s-1p1"

tests/core/units/mlip_unit/test_predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fairchem.core.calculate.pretrained_mlip import pretrained_checkpoint_path_from_name
1111
from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch
1212
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
13-
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnitRay
13+
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnit
1414
from tests.conftest import seed_everywhere
1515

1616
ATOL = 1e-5
@@ -144,7 +144,7 @@ def test_parallel_predict_unit(workers, device):
144144
atomic_data = AtomicData.from_ase(atoms, task_name=["omat"])
145145

146146
seed_everywhere(seed)
147-
ppunit = ParallelMLIPPredictUnitRay(
147+
ppunit = ParallelMLIPPredictUnit(
148148
inference_model_path=model_path,
149149
device=device,
150150
inference_settings=ifsets,

0 commit comments

Comments
 (0)