Commit 4382dab

add fixes for mm sft

1 parent 9e4f14f commit 4382dab

File tree

3 files changed: +89 -47 lines

fast_llm/data/dataset/gpt/memmap.py

Lines changed: 86 additions & 44 deletions
@@ -46,7 +46,7 @@ def _init(
         self._has_spans = 0
         self._has_images = 0
         self._has_preference_spans = False
-
+
         with self._prefix.with_suffix(".idx").open("rb") as stream:
             Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
             self._version = struct.unpack("<Q", stream.read(8))[0]
@@ -55,14 +55,12 @@ def _init(
                 self._has_spans = struct.unpack("<B", stream.read(1))[0]
             if self._version >= 3:
                 self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
-
-            if self._version >= 3:
-                self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
-
             if self._version >= 4:
                 self._has_images = struct.unpack("<B", stream.read(1))[0]
+                # Not sure this assignment is right, but a byte has to be read here to keep later offsets aligned
+                self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
 
-                self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
+            self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
             self._num_documents = struct.unpack("<Q", stream.read(8))[0]
             _ = struct.unpack("<Q", stream.read(8))[0]
             offset = stream.tell()
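
Side note on this hunk: the .idx header is parsed strictly sequentially, so every version gate must consume exactly the bytes that version wrote, even when (as the in-line comment admits) the destination of the value is unclear; skipping the read would shift every later field. A minimal sketch of the pattern with a hypothetical write/read pair (the field layout is illustrative, not the exact Fast-LLM index format):

```python
import io
import struct

# Hypothetical writer: each format version appends more flag bytes.
def write_header(stream: io.BytesIO, version: int, flags: list[int]) -> None:
    stream.write(struct.pack("<Q", version))
    for flag in flags:
        stream.write(struct.pack("<B", flag))

# The reader must mirror the writer byte for byte: skipping a
# version-gated read would desynchronize every field after it.
def read_header(stream: io.BytesIO) -> tuple[int, list[int]]:
    version = struct.unpack("<Q", stream.read(8))[0]
    n_flags = {2: 1, 3: 2, 4: 4}.get(version, 0)  # illustrative counts per version
    flags = [struct.unpack("<B", stream.read(1))[0] for _ in range(n_flags)]
    return version, flags

buf = io.BytesIO()
write_header(buf, 4, [1, 0, 1, 1])
buf.seek(0)
assert read_header(buf) == (4, [1, 0, 1, 1])
```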
@@ -112,8 +110,7 @@ def _init(
                 offset += (
                     self._num_spans.nbytes
                     + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize
-                    + sum([x.nbytes for x in self._spans])
-                )
+                )
             # read preference spans
             self._chosen_spans = None
             self._rejected_spans = None
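
Why dropping `sum([x.nbytes for x in self._spans])` is plausible: if the index stores one int32 span count per document followed by all `(start, end)` int32 pairs, then `self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize` already accounts for the span arrays, so adding their `nbytes` on top would double-count them and push the offset past the spans block. A toy check of that arithmetic, under that assumed layout:

```python
import numpy as np

# Assumed layout: one int32 span count per document, then all (start, end) pairs.
num_spans = np.array([2, 0, 1], dtype=np.int32)        # 3 documents, 3 spans total
pairs = np.zeros(num_spans.sum() * 2, dtype=np.int32)  # 3 spans -> 6 int32 values
block = num_spans.tobytes() + pairs.tobytes()

# The fixed expression measures the whole block exactly...
size = num_spans.nbytes + num_spans.sum() * 2 * np.dtype(np.int32).itemsize
assert size == len(block)
# ...so also adding the span arrays' nbytes, as the deleted line did, overshoots.
assert size + pairs.nbytes > len(block)
```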
@@ -146,34 +143,58 @@ def _init(
         self._image_lengths = None
         self._image_positions = None
         if self._has_images and self._version >= 4:
+            # Read number of images per document
             self._n_images = np.frombuffer(
                 self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset
             )
-            self._image_lengths = []
-            self._image_positions = []
-            images_seen = 0
-            for n_images in self._n_images:
-                self._image_lengths.append(
-                    np.frombuffer(
+            offset += self._n_images.nbytes
+            # Read image dimensions
+            total_images = self._n_images.sum()
+            if total_images > 0:
+                image_lengths_flat = np.frombuffer(
+                    self._index_bin_buffer,
+                    dtype=np.int32,
+                    count=total_images * 2,
+                    offset=offset
+                ).reshape(-1, 2)
+                offset += image_lengths_flat.nbytes
+
+                # Split image lengths by document
+                self._image_lengths = []
+                img_start = 0
+                for n_images in self._n_images:
+                    if n_images > 0:
+                        self._image_lengths.append(image_lengths_flat[img_start:img_start + n_images])
+                        self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum()
+                        img_start += n_images
+                    else:
+                        self._image_lengths.append(np.array([], dtype=np.int32).reshape(0, 2))
+
+                # Read padded image positions
+                max_images_per_doc = self._n_images.max() if len(self._n_images) > 0 else 0
+                if max_images_per_doc > 0:
+                    padded_positions = np.frombuffer(
                         self._index_bin_buffer,
                         dtype=np.int32,
-                        count=n_images * 2,
-                        offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize,
-                    ).reshape(-1, 2)
-                )
-                self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum()
-                self._image_positions.append(
-                    np.frombuffer(
-                        self._index_bin_buffer,
-                        dtype=np.int32,
-                        count=n_images,
-                        offset=offset
-                        + self._n_images.nbytes
-                        + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize
-                        + images_seen * np.dtype(np.int32).itemsize,
-                    )
-                )
-                images_seen += n_images
+                        count=self._num_documents * max_images_per_doc,
+                        offset=offset,
+                    ).reshape(self._num_documents, max_images_per_doc)
+
+                    # Filter out padding (-1 values) to get actual positions
+                    self._image_positions = []
+                    for doc_idx, n_images in enumerate(self._n_images):
+                        if n_images > 0:
+                            actual_positions = padded_positions[doc_idx][:n_images]
+                            # Remove any -1 padding that might exist
+                            actual_positions = actual_positions[actual_positions != -1]
+                            self._image_positions.append(actual_positions)
+                        else:
+                            self._image_positions.append(np.array([], dtype=np.int32))
+                else:
+                    self._image_positions = [np.array([], dtype=np.int32) for _ in range(self._num_documents)]
+            else:
+                self._image_lengths = [np.array([], dtype=np.int32).reshape(0, 2) for _ in range(self._num_documents)]
+                self._image_positions = [np.array([], dtype=np.int32) for _ in range(self._num_documents)]
 
         self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
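
The new reader materializes the whole `(total_images, 2)` length table once and slices it per document, instead of issuing one `np.frombuffer` per document with hand-computed offsets. The same split can be sanity-checked with `np.cumsum`/`np.split`; a small sketch with toy values (variable names are mine, not the module's):

```python
import numpy as np

n_images = np.array([2, 0, 1], dtype=np.int32)           # images per document
lengths_flat = np.array([[32, 48], [64, 64], [16, 16]])  # (total_images, 2) as (h, w)

# Partition the flat table at the running sum of per-document counts;
# documents without images naturally get empty (0, 2) slices.
per_doc = np.split(lengths_flat, np.cumsum(n_images)[:-1])
assert [len(x) for x in per_doc] == [2, 0, 1]

# Pixel accounting matches the loop above: 3 channels per (h, w) image.
num_pixels = sum(doc.prod(axis=1, initial=3).sum() for doc in per_doc)
assert num_pixels == 3 * (32 * 48 + 64 * 64 + 16 * 16)
```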
@@ -217,18 +238,29 @@ def get(
         if self._has_images:
             image_positions = self._image_positions[idx]
             # Truncations with images are not yet supported, so we get all images from the document
-            pixels = np.frombuffer(
-                self._bin_buffer,
-                dtype=np.dtype(np.uint8),
-                count=self._image_lengths[idx].prod(initial=3),
-                offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize,
-            )
-            images = []
-            start = 0
-            for image_length in self._image_lengths[idx]:
-                n_pixels = image_length.prod(initial=3)
-                images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1]))
-                start += n_pixels
+            if len(self._image_lengths[idx]) > 0:
+                total_pixels_needed = sum(
+                    length[0] * length[1] * 3 for length in self._image_lengths[idx]
+                )
+
+                pixels = np.frombuffer(
+                    self._bin_buffer,
+                    dtype=np.dtype(np.uint8),
+                    count=total_pixels_needed,
+                    offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize,
+                )
+
+                images = []
+                start = 0
+                for image_length in self._image_lengths[idx]:
+                    height, width = image_length[0], image_length[1]
+                    n_pixels = height * width * 3
+                    image_data = pixels[start : start + n_pixels].reshape(3, height, width)
+                    images.append(image_data)
+                    start += n_pixels
+            else:
+                images = []
+
 
         sample_spans = None
         if use_loss_masking_spans and self._spans is not None:
             sample_spans = self._spans[idx]
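
The `get()` path assumes each document's images are stored channel-first as one contiguous `uint8` run immediately after its token ids, which is what `reshape(3, height, width)` relies on. A self-contained round-trip sketch of that layout (toy data; nothing beyond the layout shown in the diff is assumed):

```python
import numpy as np

# Two toy CHW images flattened into one buffer, as the .bin stores them.
img_a = np.arange(3 * 2 * 4, dtype=np.uint8).reshape(3, 2, 4)
img_b = np.arange(3 * 1 * 2, dtype=np.uint8).reshape(3, 1, 2)
buffer = np.concatenate([img_a.ravel(), img_b.ravel()])

image_lengths = np.array([[2, 4], [1, 2]])  # (height, width) per image

images, start = [], 0
for height, width in image_lengths:
    n_pixels = 3 * height * width
    images.append(buffer[start : start + n_pixels].reshape(3, height, width))
    start += n_pixels

assert np.array_equal(images[0], img_a) and np.array_equal(images[1], img_b)
```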
@@ -358,6 +390,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
                     bin_stream.write(pixels.tobytes(order="C"))
                     total_im_size += pixels.size
                 im_positions.append(document.image_positions)
+            else:
+                n_images.append(0)
+                im_positions.append([])
 
             # Update metadata
             doc_length = len(document.token_ids)
@@ -387,7 +422,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]
         if total_images:
             n_images = np.array(n_images, dtype=np.int32)
             image_lengths = np.stack(image_lengths, dtype=np.int32)
-            im_positions = np.array(im_positions, dtype=np.int32)
+
+            # Pad im_positions to make them equal length
+            max_images = max(len(pos_list) for pos_list in im_positions)
+            padded_im_positions = []
+            for pos_list in im_positions:
+                padded_pos = pos_list + [-1] * (max_images - len(pos_list))
+                padded_im_positions.append(padded_pos)
+            im_positions = np.array(padded_im_positions, dtype=np.int32)
         else:
            n_images = np.array([])
            image_lengths = np.array([])
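
The write side and the read side now share an implicit padding contract: positions are padded with `-1` up to the longest per-document list when writing, and trimmed back with `[:n_images]` plus a defensive `!= -1` filter when reading. `-1` is only a safe sentinel because real positions are non-negative token indices. A round-trip sketch of the contract:

```python
import numpy as np

im_positions = [[5, 90], [], [12]]  # ragged: token index of each image, per document
n_images = [len(p) for p in im_positions]

# Write side: pad every list with -1 up to the longest one.
max_images = max(len(p) for p in im_positions)
padded = np.array(
    [p + [-1] * (max_images - len(p)) for p in im_positions], dtype=np.int32
)

# Read side: trim by the stored count, then drop any stray -1 padding.
recovered = []
for doc_idx, count in enumerate(n_images):
    row = padded[doc_idx][:count]
    recovered.append(row[row != -1])

assert [r.tolist() for r in recovered] == im_positions
```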

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 1 addition & 1 deletion
@@ -549,7 +549,7 @@ def __getitem__(self, index: int) -> typing.Any:
                 use_loss_masking_spans=self._parameters.use_loss_masking_spans,
             )
             start_pos = 0
-            if sample.image_positions:
+            if len(sample.image_positions) > 0:
                 for idx, im_position in enumerate(sample.image_positions):
                     # image_positions.append(im_positions + len(token_ids) + image_tokens_added)
                     # Add placeholders for image tokens
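
This one-line change matters because `image_positions` is a NumPy array here: `if arr:` raises `ValueError` for arrays with more than one element, and evaluates to `False` for a lone image at token position 0, while `len(arr) > 0` only asks whether any images exist. A minimal illustration:

```python
import numpy as np

positions = np.array([3, 17], dtype=np.int32)
try:
    if positions:  # ambiguous for arrays with more than one element
        pass
except ValueError as error:
    print(error)  # "The truth value of an array with more than one element is ambiguous..."

# A single image at token position 0 is falsy under truthiness but clearly present:
single = np.array([0], dtype=np.int32)
assert not bool(single)
assert len(single) > 0  # the fixed check still sees it
```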

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 2 additions & 2 deletions
@@ -158,13 +158,13 @@ def _document_generator():
                 for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
                     yield GPTSample(
                         np.array(item["input_ids"], dtype=self._data_type.numpy),
+                        item["images"] if self._config.dataset.images else None,
+                        item["image_positions"] if self._config.dataset.image_positions else None,
                         (
                             np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2)
                             if self._config.dataset.loss_masking_spans
                             else None
                         ),
-                        item["images"] if self._config.dataset.images else None,
-                        item["image_positions"] if self._config.dataset.image_positions else None,
                         item.get("chosen_token_spans", None),
                         item.get("rejected_token_spans", None),
                     )
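
The reorder only fixes things because `GPTSample` is constructed positionally, so the argument order must mirror the constructor's field order exactly; a mismatch binds images into the spans slot without any error. A toy sketch of that failure mode and the keyword-argument alternative (the field names here are hypothetical, not GPTSample's real ones):

```python
import dataclasses

# Toy stand-in for GPTSample; these field names are assumptions.
@dataclasses.dataclass
class Sample:
    token_ids: list[int]
    images: list | None = None
    image_positions: list[int] | None = None
    loss_masking_spans: list | None = None

# Positional construction silently re-binds values when the call site and the
# field order disagree, which is the bug class this reordering fixes:
wrong = Sample([1, 2, 3], [[0, 2]])                     # spans landed in `images`
right = Sample([1, 2, 3], loss_masking_spans=[[0, 2]])  # keywords are order-proof
assert wrong.images == right.loss_masking_spans
```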
