diff --git a/caduceus_distill/distill.py b/caduceus_distill/distill.py
index d67e86d..d65ec98 100644
--- a/caduceus_distill/distill.py
+++ b/caduceus_distill/distill.py
@@ -144,6 +144,7 @@ def __init__(
         self.zarr_path = zarr_path
         self.ds: xr.Dataset | None = None
         self.skip_batches = skip_batches if skip_batches is not None else set()
+        self._in_memory_cache: dict[int, EXAMPLE_T] = {}
 
         with self.__maybe_open_zarr() as temp_ds:
             total_samples = len(temp_ds.sample)
@@ -178,7 +179,18 @@ def __maybe_open_zarr(self) -> xr.Dataset:
     def __len__(self) -> int:
         return self._len
 
+    def warmup(self, n: int) -> None:
+        """
+        Pre-load the first `n` samples into memory to speed up fetching; useful for
+        intermediate validation with a limited number of batches. `n` is clamped to `len(self)`.
+        """
+        for i in range(min(n, len(self))):
+            self._in_memory_cache[i] = self[i]
+
     def __getitem__(self, idx: int) -> EXAMPLE_T:
+        if idx in self._in_memory_cache:
+            return self._in_memory_cache[idx]
+
         # TODO: better doc, is there an idiomatic way to skip a batch?
         if idx in self.skip_batches:
             new_idx = np.random.randint(0, len(self))
@@ -533,6 +545,8 @@ def main(
     # NOTE: skip due to https://github.com/Open-Athena/caduceus-distill/issues/38
     train_dataset = DistillationDataset(zarr_path_train, skip_batches={8590})
     val_dataset = DistillationDataset(zarr_path_val)
+    # TODO: is this always safe to call? Memory use grows with max_val_batches.
+    val_dataset.warmup(max_val_batches)
 
     # Create data loaders
    train_loader, val_loader, test_loader = [
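
For context, a minimal usage sketch of the warmup path introduced by this patch. `DistillationDataset`, `zarr_path_val`, and `max_val_batches` are taken from the patch itself; everything else here is illustrative, not the repository's actual setup:

```python
# Hypothetical usage mirroring main() in the patch; `zarr_path_val` and
# `max_val_batches` are assumed to be defined as in the surrounding code.
val_dataset = DistillationDataset(zarr_path_val)

# One-time warmup: eagerly materialize the first `max_val_batches` samples.
# After this, __getitem__(i) for a warmed index is a plain dict lookup
# instead of a zarr read.
val_dataset.warmup(max_val_batches)

first = val_dataset[0]                    # served from _in_memory_cache
last = val_dataset[len(val_dataset) - 1]  # cache miss, falls back to zarr
```

Note that `warmup(n)` caches the first `n` *samples*, while `max_val_batches` counts *batches*; with a batch size above 1, the warmed cache covers only the first `max_val_batches` samples of the validated data, and the remainder still goes through the zarr-backed path.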