Skip to content

pixtral SFT #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions fast_llm/data/dataset/gpt/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ def _init(
if self._version >= 3:
self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]

if self._version >= 3:
self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]

if self._version >= 4:
self._has_images = struct.unpack("<B", stream.read(1))[0]
# not sure of assignment, reading flag to indicate whether preference loss-masking spans are present
self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
Comment on lines +61 to +62
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's already read above; why read it again here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another flag written after the images flag; this line reads that one, but I'm not sure which attribute it should be assigned to.

# Placeholder flag for preference spans
idx_stream.write(struct.pack("<B", 0))
# Flag to indicate whether images are present
idx_stream.write(struct.pack("<B", 1 if total_images > 0 else 0))
# Flag to indicate whether preference loss-masking spans are present
idx_stream.write(struct.pack("<B", 1 if chosen_spans.size > 0 and rejected_spans.size > 0 else 0))

Copy link
Member

@sohamparikh sohamparikh Jun 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, that order should be flipped: the chosen_spans byte should be written before the total_images byte. I'll fix it in my branch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would break files with version == 3, right?

It seems to me that we should rather fix the order in which those flags are dumped in the idx file below

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RaymondLi0 yes, I'm planning to fix it in #227


self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
self._num_documents = struct.unpack("<Q", stream.read(8))[0]
Expand Down Expand Up @@ -112,7 +111,6 @@ def _init(
offset += (
self._num_spans.nbytes
+ self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize
+ sum([x.nbytes for x in self._spans])
)
# read preference spans
self._chosen_spans = None
Expand Down Expand Up @@ -216,11 +214,12 @@ def get(
image_positions = None
if self._has_images:
image_positions = self._image_positions[idx]

# Truncations with images are not yet supported, so we get all images from the document
pixels = np.frombuffer(
self._bin_buffer,
dtype=np.dtype(np.uint8),
count=self._image_lengths[idx].prod(initial=3),
count=self._image_lengths[idx].prod(initial=3, axis=1).sum(),
offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize,
)
images = []
Expand Down Expand Up @@ -357,7 +356,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
image_lengths.append(np.array(pixels.shape[1:]))
bin_stream.write(pixels.tobytes(order="C"))
total_im_size += pixels.size
im_positions.append(document.image_positions)
im_positions.extend(document.image_positions)

# Update metadata
doc_length = len(document.token_ids)
Expand Down
9 changes: 5 additions & 4 deletions fast_llm/data/dataset/gpt/sampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,8 @@ def __getitem__(self, index: int) -> typing.Any:
use_loss_masking_spans=self._parameters.use_loss_masking_spans,
)
start_pos = 0
if sample.image_positions:
has_images = sample.image_positions is not None
if has_image_positions:
for idx, im_position in enumerate(sample.image_positions):
# image_positions.append(im_positions + len(token_ids) + image_tokens_added)
# Add placeholders for image tokens
Expand Down Expand Up @@ -593,15 +594,15 @@ def __getitem__(self, index: int) -> typing.Any:
image_idx = 0
image_position = (
sample.image_positions[image_idx]
if image_idx < len(sample.image_positions)
if has_images and image_idx < len(sample.image_positions)
else float("inf")
)
while image_position < loss_masking_span[0]:
prev_image_tokens += image_sizes[image_idx]
image_idx += 1
image_position = (
sample.image_positions[image_idx]
if image_idx < len(sample.image_positions)
if has_images and image_idx < len(sample.image_positions)
else float("inf")
)
span_image_tokens = 0
Expand All @@ -610,7 +611,7 @@ def __getitem__(self, index: int) -> typing.Any:
image_idx += 1
image_position = (
sample.image_positions[image_idx]
if image_idx < len(sample.image_positions)
if has_images and image_idx < len(sample.image_positions)
else float("inf")
)
loss_masking_span[0] += prev_image_tokens
Expand Down
4 changes: 2 additions & 2 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ def _document_generator():
for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
yield GPTSample(
np.array(item["input_ids"], dtype=self._data_type.numpy),
item["images"] if self._config.dataset.images else None,
item["image_positions"] if self._config.dataset.image_positions else None,
(
np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2)
if self._config.dataset.loss_masking_spans
else None
),
item["images"] if self._config.dataset.images else None,
item["image_positions"] if self._config.dataset.image_positions else None,
item.get("chosen_token_spans", None),
item.get("rejected_token_spans", None),
)
Expand Down