Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions caduceus_distill/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ class DistillationDataset(Dataset[EXAMPLE_T]):
def __init__(
self,
zarr_path: str,
skip_batches: set[int] | None = None,
) -> None:
self.zarr_path = zarr_path
self.ds: xr.Dataset | None = None
self.skip_batches = skip_batches if skip_batches is not None else set()

with self.__maybe_open_zarr() as temp_ds:
total_samples = len(temp_ds.sample)
Expand Down Expand Up @@ -177,6 +179,12 @@ def __len__(self) -> int:
return self._len

def __getitem__(self, idx: int) -> EXAMPLE_T:
# TODO: better doc, is there an idiomatic way to skip a batch?
if idx in self.skip_batches:
new_idx = np.random.randint(0, len(self))
logger.debug(f"Using batch {new_idx} instead of {idx}")
idx = new_idx

self.ds = self.__maybe_open_zarr()
sample = self.ds.isel(sample=idx)
input_ids = torch.tensor(sample.input_ids.values, dtype=torch.long)
Expand Down Expand Up @@ -522,8 +530,8 @@ def main(
) -> None:
L.seed_everything(42, workers=True)

# Initialize datasets
train_dataset = DistillationDataset(zarr_path_train)
# NOTE: skip due to https://github.com/Open-Athena/caduceus-distill/issues/38
train_dataset = DistillationDataset(zarr_path_train, skip_batches={8590})
val_dataset = DistillationDataset(zarr_path_val)

# Create data loaders
Expand Down