diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 37241524..acc7914f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -56,11 +56,10 @@ def _init( if self._version >= 3: self._has_preference_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack(" typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - if sample.image_positions: + has_images = sample.image_positions is not None + if has_image_positions: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens @@ -593,7 +594,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx = 0 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) while image_position < loss_masking_span[0]: @@ -601,7 +602,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) span_image_tokens = 0 @@ -610,7 +611,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) loss_masking_span[0] += prev_image_tokens diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index ad3dd449..0b680310 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -158,13 +158,13 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ( np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) if self._config.dataset.loss_masking_spans else None ), - item["images"] if self._config.dataset.images else None, - item["image_positions"] if self._config.dataset.image_positions else None, item.get("chosen_token_spans", None), item.get("rejected_token_spans", None), )