14 changes: 14 additions & 0 deletions caduceus_distill/distill.py
@@ -144,6 +144,7 @@ def __init__(
        self.zarr_path = zarr_path
        self.ds: xr.Dataset | None = None
        self.skip_batches = skip_batches if skip_batches is not None else set()
        self._in_memory_cache: dict[int, EXAMPLE_T] = {}

        with self.__maybe_open_zarr() as temp_ds:
            total_samples = len(temp_ds.sample)
@@ -178,7 +179,18 @@ def __maybe_open_zarr(self) -> xr.Dataset:
    def __len__(self) -> int:
        return self._len

    def warmup(self, n: int) -> None:
        """
        Pre-load the first `n` samples into memory to speed up fetching. This is
        useful for intermediate validation with a limited number of batches.
        """
        for i in range(n):
            self._in_memory_cache[i] = self[i]

    def __getitem__(self, idx: int) -> EXAMPLE_T:
        if idx in self._in_memory_cache:
            return self._in_memory_cache[idx]

        # TODO: better doc, is there an idiomatic way to skip a batch?
        if idx in self.skip_batches:
            new_idx = np.random.randint(0, len(self))
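A possible answer to the TODO above, offered only as a non-authoritative sketch: re-draw until the replacement index is not itself in `skip_batches`, so a skipped sample can never be returned. The `_resample_index` helper below is hypothetical and not part of this diff.

    def _resample_index(self, idx: int) -> int:
        # Hypothetical helper (sketch only): keep re-drawing until the index
        # falls outside `skip_batches`; with a single skipped batch this loops
        # at most a handful of times in expectation.
        while idx in self.skip_batches:
            idx = int(np.random.randint(0, len(self)))
        return idx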
@@ -533,6 +545,8 @@ def main(
    # NOTE: skip due to https://github.com/Open-Athena/caduceus-distill/issues/38
    train_dataset = DistillationDataset(zarr_path_train, skip_batches={8590})
    val_dataset = DistillationDataset(zarr_path_val)
    # TODO: is this always safe to call? What if `max_val_batches` is large?
    val_dataset.warmup(max_val_batches)

    # Create data loaders
    train_loader, val_loader, test_loader = [
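On the second TODO: `warmup(n)` indexes `self[i]` for every `i` in `range(n)`, so an `n` larger than the dataset length would read past the end. A guarded call, offered only as a sketch (the clamping is an assumption, not part of this PR):

    # Sketch: never warm more samples than the dataset actually contains.
    val_dataset.warmup(min(max_val_batches, len(val_dataset)))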