move load calls, rename test for better readability (#601)
### Description
Makes the test_model fixture use the converted NeMo2 checkpoint for 650M, and renames the test that still requires the NeMo1 checkpoint so its purpose is clearer.

### Type of changes
<!-- Mark the relevant option with an [x] -->

- [x]  Bug fix (non-breaking change which fixes an issue)
- [ ]  New feature (non-breaking change which adds functionality)
- [x]  Refactor
- [ ]  Documentation update
- [ ]  Other (please describe):

### CI Pipeline Configuration
Configure CI behavior by checking relevant boxes below. This will
automatically apply labels.

- [ ] [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [ ] [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest

> [!NOTE]
> By default, the notebook validation tests are skipped unless explicitly enabled.

### Usage
<!--- How does a user interact with the changed code -->
```python
TODO: Add code snippet
```
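The gist of the change is that checkpoint `load()` calls move from module scope into the fixture body, so the download happens only when a test actually requests the fixture. A minimal, self-contained sketch of that pattern (every name below is a stand-in for the real bionemo API, not the actual implementation):

```python
# Sketch of the pattern this PR applies: load() runs inside the fixture,
# not at module import time. All names here are hypothetical stand-ins.
from contextlib import contextmanager

CALLS = []  # records each checkpoint fetch so we can observe when it happens


def load(tag: str) -> str:
    """Stand-in for bionemo's resource loader; pretends to download a checkpoint."""
    CALLS.append(tag)
    return f"/cache/{tag}"


class ESM2Config:
    """Stand-in config that accepts a checkpoint path, like the real ESM2Config."""

    def __init__(self, initial_ckpt_path=None):
        self.initial_ckpt_path = initial_ckpt_path


@contextmanager
def distributed_model_parallel_state():
    # The real helper sets up and tears down megatron parallel state.
    yield


def esm2_650M_config_w_ckpt():
    # load() is called here, inside the fixture body, so merely importing
    # the test module no longer triggers a checkpoint download.
    with distributed_model_parallel_state():
        yield ESM2Config(initial_ckpt_path=load("esm2/650m:2.0"))


assert CALLS == []  # defining the fixture triggers no download
gen = esm2_650M_config_w_ckpt()
config = next(gen)  # pytest would do this when a test requests the fixture
assert CALLS == ["esm2/650m:2.0"]
```

The same reasoning applies to the renamed NeMo1 test: it calls `load("esm2/nv_650m:1.0")` directly in its body rather than relying on a module-level variable.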

### Pre-submit Checklist
<!--- Ensure all items are completed before submitting -->

 - [x] I have tested these changes locally
 - [ ] I have updated the documentation accordingly
 - [x] I have added/updated tests as needed
 - [ ] All existing tests pass successfully

Signed-off-by: Peter St. John <[email protected]>
pstjohn authored Jan 15, 2025
1 parent 93233ff commit f59c602
Showing 1 changed file with 2 additions and 6 deletions.
```diff
@@ -17,7 +17,6 @@
 import io
 import tarfile
 from copy import deepcopy
-from pathlib import Path
 from typing import List, Tuple
 from unittest import mock

@@ -39,9 +38,6 @@
 from bionemo.testing import megatron_parallel_state_utils


-nemo1_checkpoint_path: Path = load("esm2/nv_650m:1.0")
-
-
 def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor:
     """reduce last layer's hidden values to embeddings
@@ -72,7 +68,7 @@ def esm2_config() -> ESM2Config:
 @pytest.fixture(scope="module")
 def esm2_650M_config_w_ckpt() -> ESM2Config:
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        yield ESM2Config(nemo1_ckpt_path=nemo1_checkpoint_path)
+        yield ESM2Config(initial_ckpt_path=load("esm2/650m:2.0"))


 @pytest.fixture(scope="module")
@@ -141,7 +137,7 @@ def test_esm2_model_initialized(esm2_model):


 def test_esm2_650m_checkpoint(esm2_model):
-    with tarfile.open(nemo1_checkpoint_path, "r") as ckpt, torch.no_grad():
+    with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad():
         ckpt_file = ckpt.extractfile("./model_weights.ckpt")

         old_state_dict = torch.load(ckpt_file)
```
