@@ -46,7 +46,7 @@ def _init(
46
46
self ._has_spans = 0
47
47
self ._has_images = 0
48
48
self ._has_preference_spans = False
49
-
49
+
50
50
with self ._prefix .with_suffix (".idx" ).open ("rb" ) as stream :
51
51
Assert .eq (stream .read (9 ), MEMMAP_INDEX_HEADER , msg = f"File: { stream .name } " )
52
52
self ._version = struct .unpack ("<Q" , stream .read (8 ))[0 ]
@@ -55,14 +55,12 @@ def _init(
55
55
self ._has_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
56
56
if self ._version >= 3 :
57
57
self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
58
-
59
- if self ._version >= 3 :
60
- self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
61
-
62
58
if self ._version >= 4 :
63
59
self ._has_images = struct .unpack ("<B" , stream .read (1 ))[0 ]
60
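+ # NOTE(review): unclear which flag this extra byte encodes; it must still be consumed so the reads that follow stay aligned. It overwrites the `_has_preference_spans` value read for version >= 3 above -- confirm the intended attribute.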
61
+ self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
64
62
65
- self ._dtype = MEMMAP_DTYPES [struct .unpack ("<B" , stream .read (1 ))[0 ]].numpy
63
+ self ._dtype = MEMMAP_DTYPES [struct .unpack ("<B" , stream .read (1 ))[0 ]].numpy
66
64
self ._num_documents = struct .unpack ("<Q" , stream .read (8 ))[0 ]
67
65
_ = struct .unpack ("<Q" , stream .read (8 ))[0 ]
68
66
offset = stream .tell ()
@@ -112,8 +110,7 @@ def _init(
112
110
offset += (
113
111
self ._num_spans .nbytes
114
112
+ self ._num_spans .sum () * 2 * np .dtype (np .int32 ).itemsize
115
- + sum ([x .nbytes for x in self ._spans ])
116
- )
113
+ )
117
114
# read preference spans
118
115
self ._chosen_spans = None
119
116
self ._rejected_spans = None
@@ -146,34 +143,58 @@ def _init(
146
143
self ._image_lengths = None
147
144
self ._image_positions = None
148
145
if self ._has_images and self ._version >= 4 :
146
+ # Read number of images per document
149
147
self ._n_images = np .frombuffer (
150
148
self ._index_bin_buffer , dtype = np .int32 , count = self ._num_documents , offset = offset
151
149
)
152
- self ._image_lengths = []
153
- self ._image_positions = []
154
- images_seen = 0
155
- for n_images in self ._n_images :
156
- self ._image_lengths .append (
157
- np .frombuffer (
150
+ offset += self ._n_images .nbytes
151
+ # Read image dimensions
152
+ total_images = self ._n_images .sum ()
153
+ if total_images > 0 :
154
+ image_lengths_flat = np .frombuffer (
155
+ self ._index_bin_buffer ,
156
+ dtype = np .int32 ,
157
+ count = total_images * 2 ,
158
+ offset = offset
159
+ ).reshape (- 1 , 2 )
160
+ offset += image_lengths_flat .nbytes
161
+
162
+ # Split image lengths by document
163
+ self ._image_lengths = []
164
+ img_start = 0
165
+ for n_images in self ._n_images :
166
+ if n_images > 0 :
167
+ self ._image_lengths .append (image_lengths_flat [img_start :img_start + n_images ])
168
+ self ._num_pixels += self ._image_lengths [- 1 ].prod (axis = 1 , initial = 3 ).sum ()
169
+ img_start += n_images
170
+ else :
171
+ self ._image_lengths .append (np .array ([], dtype = np .int32 ).reshape (0 , 2 ))
172
+
173
+ # Read padded image positions
174
+ max_images_per_doc = self ._n_images .max () if len (self ._n_images ) > 0 else 0
175
+ if max_images_per_doc > 0 :
176
+ padded_positions = np .frombuffer (
158
177
self ._index_bin_buffer ,
159
178
dtype = np .int32 ,
160
- count = n_images * 2 ,
161
- offset = offset + self ._n_images .nbytes + 2 * images_seen * np .dtype (np .int32 ).itemsize ,
162
- ).reshape (- 1 , 2 )
163
- )
164
- self ._num_pixels += self ._image_lengths [- 1 ].prod (axis = 1 , initial = 3 ).sum ()
165
- self ._image_positions .append (
166
- np .frombuffer (
167
- self ._index_bin_buffer ,
168
- dtype = np .int32 ,
169
- count = n_images ,
170
- offset = offset
171
- + self ._n_images .nbytes
172
- + 2 * self ._n_images .sum () * np .dtype (np .int32 ).itemsize
173
- + images_seen * np .dtype (np .int32 ).itemsize ,
174
- )
175
- )
176
- images_seen += n_images
179
+ count = self ._num_documents * max_images_per_doc ,
180
+ offset = offset ,
181
+ ).reshape (self ._num_documents , max_images_per_doc )
182
+
183
+ # Filter out padding (-1 values) to get actual positions
184
+ self ._image_positions = []
185
+ for doc_idx , n_images in enumerate (self ._n_images ):
186
+ if n_images > 0 :
187
+ actual_positions = padded_positions [doc_idx ][:n_images ]
188
+ # Remove any -1 padding that might exist
189
+ actual_positions = actual_positions [actual_positions != - 1 ]
190
+ self ._image_positions .append (actual_positions )
191
+ else :
192
+ self ._image_positions .append (np .array ([], dtype = np .int32 ))
193
+ else :
194
+ self ._image_positions = [np .array ([], dtype = np .int32 ) for _ in range (self ._num_documents )]
195
+ else :
196
+ self ._image_lengths = [np .array ([], dtype = np .int32 ).reshape (0 , 2 ) for _ in range (self ._num_documents )]
197
+ self ._image_positions = [np .array ([], dtype = np .int32 ) for _ in range (self ._num_documents )]
177
198
178
199
self ._bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".bin" ), mode = "r" , order = "C" )
179
200
self ._bin_buffer = memoryview (self ._bin_buffer_mmap )
@@ -217,18 +238,29 @@ def get(
217
238
if self ._has_images :
218
239
image_positions = self ._image_positions [idx ]
219
240
# Truncations with images are not yet supported, so we get all images from the document
220
- pixels = np .frombuffer (
221
- self ._bin_buffer ,
222
- dtype = np .dtype (np .uint8 ),
223
- count = self ._image_lengths [idx ].prod (initial = 3 ),
224
- offset = self ._pointers [idx ] + self ._document_sizes [idx ] * np .dtype (self ._dtype ).itemsize ,
225
- )
226
- images = []
227
- start = 0
228
- for image_length in self ._image_lengths [idx ]:
229
- n_pixels = image_length .prod (initial = 3 )
230
- images .append (pixels [start : start + n_pixels ].reshape (3 , image_length [0 ], image_length [1 ]))
231
- start += n_pixels
241
+ if len (self ._image_lengths [idx ]) > 0 :
242
+ total_pixels_needed = sum (
243
+ length [0 ] * length [1 ] * 3 for length in self ._image_lengths [idx ]
244
+ )
245
+
246
+ pixels = np .frombuffer (
247
+ self ._bin_buffer ,
248
+ dtype = np .dtype (np .uint8 ),
249
+ count = total_pixels_needed ,
250
+ offset = self ._pointers [idx ] + self ._document_sizes [idx ] * np .dtype (self ._dtype ).itemsize ,
251
+ )
252
+
253
+ images = []
254
+ start = 0
255
+ for image_length in self ._image_lengths [idx ]:
256
+ height , width = image_length [0 ], image_length [1 ]
257
+ n_pixels = height * width * 3
258
+ image_data = pixels [start : start + n_pixels ].reshape (3 , height , width )
259
+ images .append (image_data )
260
+ start += n_pixels
261
+ else :
262
+ images = []
263
+
232
264
sample_spans = None
233
265
if use_loss_masking_spans and self ._spans is not None :
234
266
sample_spans = self ._spans [idx ]
@@ -358,6 +390,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
358
390
bin_stream .write (pixels .tobytes (order = "C" ))
359
391
total_im_size += pixels .size
360
392
im_positions .append (document .image_positions )
393
+ else :
394
+ n_images .append (0 )
395
+ im_positions .append ([])
361
396
362
397
# Update metadata
363
398
doc_length = len (document .token_ids )
@@ -387,7 +422,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
387
422
if total_images :
388
423
n_images = np .array (n_images , dtype = np .int32 )
389
424
image_lengths = np .stack (image_lengths , dtype = np .int32 )
390
- im_positions = np .array (im_positions , dtype = np .int32 )
425
+
426
+ # Pad im_positions to make them equal length
427
+ max_images = max (len (pos_list ) for pos_list in im_positions )
428
+ padded_im_positions = []
429
+ for pos_list in im_positions :
430
+ padded_pos = pos_list + [- 1 ] * (max_images - len (pos_list ))
431
+ padded_im_positions .append (padded_pos )
432
+ im_positions = np .array (padded_im_positions , dtype = np .int32 )
391
433
else :
392
434
n_images = np .array ([])
393
435
image_lengths = np .array ([])
0 commit comments