[speechm2] Support indexed sharegpt JSONL and webdataset formats #15410
@@ -47,6 +47,7 @@
    NeMoMultimodalConversation,
    NeMoMultimodalConversationJsonlAdapter,
    NeMoMultimodalConversationShareGPTJsonlAdapter,
    NeMoMultimodalConversationShareGPTWebdatasetAdapter,
    NeMoSFTJsonlAdapter,
    TextTurn,
)

@@ -324,6 +325,7 @@ def read_share_gpt_as_conversation(config) -> tuple[CutSet, bool]:
            tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
            audio_locator_tag=config.audio_locator_tag,
            audio_placeholders=config.audio_placeholders,
            audio_root=config.get("audio_root"),
            token_equivalent_duration=config.get("token_equivalent_duration"),
            shuffle_shards=config.shuffle,
            shard_seed=config.shard_seed,

@@ -335,6 +337,24 @@ def read_share_gpt_as_conversation(config) -> tuple[CutSet, bool]:
    return cuts, True


@data_type_parser(["share_gpt_webdataset"])
def read_share_gpt_webdataset_as_conversation(config) -> tuple[CutSet, bool]:
    """Read ShareGPT conversations from WebDataset tar archives."""
    cuts = CutSet(
        NeMoMultimodalConversationShareGPTWebdatasetAdapter(
            data_dir=config.data_dir,
            audio_locator_tag=config.audio_locator_tag,
            audio_placeholders=config.get("audio_placeholders"),
            token_equivalent_duration=config.get("token_equivalent_duration"),
            shuffle_shards=config.shuffle,
            shard_seed=config.shard_seed,
        )
    )
    if not config.get("force_finite", False):
        cuts = cuts.repeat(preserve_id=True)
    return cuts, True

Collaborator
where do we need this bool for compatibility? curious if the same thing could be achieved by just checking the config
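As a usage sketch (not part of the diff): the keys below are exactly the ones the new parser reads, and the registered name comes from the @data_type_parser decorator. The concrete values and the use of OmegaConf are assumptions; the parser only needs attribute access plus .get on the config object.

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "data_dir": "/data/sharegpt_webdataset",  # placeholder path
        "audio_locator_tag": "<audio>",           # placeholder tag
        "audio_placeholders": None,
        "token_equivalent_duration": None,
        "shuffle": True,
        "shard_seed": 42,
        "force_finite": False,                    # keep the infinite repeat
    }
)
cuts, flag = read_share_gpt_webdataset_as_conversation(config)  # `flag` is the bool questioned above
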
def _resolve_shar_inputs(path: Union[str, Path], only_metadata: bool) -> dict:
    if only_metadata:
        return dict(fields={"cuts": sorted(Path(path).glob("cuts.*"))})

@@ -0,0 +1,287 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import struct
import tarfile
from pathlib import Path

import numpy as np


class LazyShuffledRange:
    """
    Generates a permutation of ``range(n)`` lazily using a Feistel cipher,
    without materializing the full index list. Each element is computed on
    the fly in O(1) time and the object itself uses O(1) memory regardless
    of ``n``.

    The technique is known as *cycle-walking* format-preserving encryption:
    a Feistel network is a bijection on ``[0, 2^k)``, and repeatedly applying
    it until the output falls within ``[0, n)`` restricts it to a bijection
    on the desired domain.

    Args:
        n: Size of the range to permute.
        rng: A ``random.Random`` instance used to derive round keys.
    """

    def __init__(self, n: int, rng: random.Random):
        self.n = n

Collaborator
rename self.n to something like self.width or another semantically meaningful name.
        if n <= 1:
            return
        bits = (n - 1).bit_length()
        if bits < 2:
            bits = 2
        if bits % 2:
            bits += 1
        self._half = bits // 2
        self._mask = (1 << self._half) - 1
        self._rounds = 6

Collaborator
make this an arg in init.
        self._keys = [rng.getrandbits(64) for _ in range(self._rounds)]

    def _permute_one(self, x: int) -> int:

Collaborator
are there no numpy methods you can crib for this? or is the native Cython bitshift efficient enough?
        left = (x >> self._half) & self._mask
        right = x & self._mask
        for key in self._keys:
            left, right = right, left ^ (((right * 2654435761) ^ key) >> 32 & self._mask)

Collaborator
make this a global var at the top of the file.
        return (left << self._half) | right

    def __len__(self) -> int:
        return self.n

    def __iter__(self):
        n = self.n
        if n <= 0:
            return
        if n == 1:
            yield 0
            return
        for i in range(n):
            x = i
            while True:
                x = self._permute_one(x)
                if x < n:
                    yield x
                    break

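A quick illustration of the cycle-walking behavior described in the docstring (not part of the diff; the numbers are arbitrary): every index in range(n) appears exactly once, in a pseudo-random order, without ever materializing a list of n elements.

rng = random.Random(0)
order = LazyShuffledRange(10, rng)
print(list(order))                         # a shuffled permutation of 0..9
assert sorted(order) == list(range(10))    # it is a bijection on [0, n)
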
def _load_index(data_path: str, idx_path: str | None = None):
    """
    Load a memmap'd offset index for *data_path*.
    Returns ``(offsets_memmap, num_samples, data_file_size)``.
    Handles the optional sentinel (last entry == file size).
    """
    if idx_path is None:
        idx_path = data_path + '.idx'
    assert os.path.exists(data_path), f"Data file not found: {data_path}"

Collaborator
pretty sure both these assertions will arise from the following calls anyway, no?
    assert os.path.exists(idx_path), f"Index file not found: {idx_path}"
    offsets = np.memmap(idx_path, dtype=np.dtype('<u8'), mode='r')
    data_size = os.path.getsize(data_path)
    length = offsets.shape[0] - 1 if offsets[-1] == data_size else offsets.shape[0]

Collaborator
strikes me as clunky keeping an extra arg just for a potential off-by-one error. Where is this critical, and can we move it outside this function?
    return offsets, length, data_size

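A tiny sanity sketch of the sentinel handling (not part of the diff; file names and contents are made up): the trailing entry equal to the file size is excluded from the reported length.

with open("tiny.jsonl", "wb") as f:
    f.write(b'{"a":1}\n{"b":2}\n{"c":3}\n')          # three 8-byte lines, 24 bytes total
with open("tiny.jsonl.idx", "wb") as f:
    f.write(struct.pack('<4Q', 0, 8, 16, 24))        # line offsets plus file-size sentinel
offsets, length, data_size = _load_index("tiny.jsonl")
assert (length, data_size) == (3, 24)
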
def _resolve_idx(idx: int, length: int) -> int:
    if idx < 0:
        idx += length
    if idx < 0 or idx >= length:
        raise IndexError("Index out of bounds")
    return idx


class IndexedJSONLReader:
    def __init__(self, jsonl_path: Path | str, idx_path: Path | str | None = None):
        self.data_path = str(jsonl_path)
        self.offsets, self._len, self._data_size = _load_index(self.data_path, str(idx_path) if idx_path else None)

Collaborator
oh yeah, just make the length aspect part of the init. makes _load_index more conceptually clear and won't add clutter to the two inits

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        idx = _resolve_idx(idx, self._len)
        start = int(self.offsets[idx])
        end = int(self.offsets[idx + 1]) if idx + 1 < self.offsets.shape[0] else self._data_size

Collaborator
isn't this just reversing the logic for the _load_index length arg?
        with open(self.data_path, 'rb') as f:
            f.seek(start)
            data = f.read(end - start)
        return json.loads(data.decode('utf-8'))

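A usage sketch (not part of the diff; paths are placeholders), assuming a manifest.jsonl.idx has already been written with create_index from the bottom of this file:

reader = IndexedJSONLReader("manifest.jsonl")   # picks up manifest.jsonl.idx by default
print(len(reader))                              # number of JSON lines
first, last = reader[0], reader[-1]             # negative indices go through _resolve_idx
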
def _split_json_audio_pair(name_a, bytes_a, name_b, bytes_b):
    """Classify two tar members into (json_data_dict, audio_bytes, audio_name) regardless of order."""
    if name_a.endswith('.json'):
Collaborator
any strong thoughts on just making a data class to manage this tuple? (i'm imagining the awkwardness of accidentally indexing the audio_bytes when you just wanted the name)
        return json.loads(bytes_a), bytes_b, name_b
    if name_b.endswith('.json'):
        return json.loads(bytes_b), bytes_a, name_a
    raise ValueError(f"Expected one .json member in tar sample pair, got: {name_a}, {name_b}")

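One possible shape for the data class suggested above (an assumption, not part of the diff): a small NamedTuple keeps the three fields addressable by name instead of position.

from typing import Any, NamedTuple

class TarSample(NamedTuple):           # hypothetical name
    json_data: dict[str, Any]
    audio_bytes: bytes
    audio_name: str
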
class IndexedTarSampleReader:
    """
    Random access to WebDataset tar samples (``N.json`` + ``N.<audio>``) via an index file.
    Index format is identical to ``IndexedJSONLReader``: little-endian uint64 offsets,
    optionally followed by a sentinel equal to the tar file size.
    """

    def __init__(self, tar_path: str | Path, idx_path: str | Path | None = None):
        self.data_path = str(tar_path)
        self.offsets, self._len, self._data_size = _load_index(self.data_path, str(idx_path) if idx_path else None)

Collaborator
just noting, as above, that you can move the length logic from _load_index to here to make the function more semantically shallow
        self._validate_index()

    def _validate_index(self):
        if self._len == 0:
            return
        max_offset = int(max(self.offsets[i] for i in range(self._len)))
        if max_offset >= self._data_size:
Collaborator
All of this information is coming from _load_index. Shouldn't the check be there as well? Remove from the class the burden of what should be managed by the function.
            raise ValueError(
                f"Tar index for {self.data_path} contains offset {max_offset} "
                f"beyond file size {self._data_size}. "
                f"The .idx file may have been created by an incompatible tool "
                f"or for a different file."
            )
        # Validate first offset is a valid tar header.
Collaborator
Likewise, I think this part of the check could come from _load_index. But it also makes sense to only add the tar check in a tar-related class, so I'm more ambivalent here.
        self._check_offset_is_tar_header(int(self.offsets[0]), label="first")
        # Strip trailing sentinels: some tools store the offset of the
        # end-of-archive zero-block marker as a sentinel instead of the
        # file size (which _load_index already handles).
        while self._len > 0:
            last = int(self.offsets[self._len - 1])
            with open(self.data_path, 'rb') as f:
                f.seek(last)
                buf = f.read(512)
            if len(buf) < 512 or buf == b'\0' * 512:
                self._len -= 1
            else:
                break

    def _check_offset_is_tar_header(self, offset: int, label: str = ""):
        with open(self.data_path, 'rb') as f:
            f.seek(offset)
            buf = f.read(512)
        if len(buf) < 512:
            raise ValueError(
                f"Tar index for {self.data_path}: {label} offset {offset} "
                f"is too close to EOF (file size {self._data_size})."
            )
        if buf == b'\0' * 512:
            raise ValueError(
                f"Tar index for {self.data_path}: {label} offset {offset} "
                f"points to a zero block (end-of-archive marker), not a tar header. "
                f"The .idx file may have been created by an incompatible tool "
                f"or for a different file."
            )
        try:
            tarfile.TarInfo.frombuf(buf, tarfile.ENCODING, "surrogateescape")
        except tarfile.TarError as e:
            raise ValueError(
                f"Tar index for {self.data_path}: {label} offset {offset} "
                f"does not point to a valid tar header: {e}. "
                f"The .idx file may have been created by an incompatible tool "
                f"(e.g. has a binary header or stores per-member offsets) "
                f"or for a different file."
            ) from e

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        idx = _resolve_idx(idx, self._len)
        offset = int(self.offsets[idx])
        with open(self.data_path, 'rb') as f:
            f.seek(offset)
            try:
                name_a, bytes_a = _read_tar_member(f)
            except (EOFError, tarfile.TarError) as e:
                raise type(e)(
                    f"{e} — reading first member of sample {idx}/{self._len} "
                    f"at offset {offset} in {self.data_path} "
                    f"(file size {self._data_size})"
                ) from e
            try:
                name_b, bytes_b = _read_tar_member(f)
            except (EOFError, tarfile.TarError) as e:
                raise type(e)(
                    f"{e} — reading second member of sample {idx}/{self._len} "
                    f"(first member was '{name_a}', {len(bytes_a)} bytes) "
                    f"at offset {offset} in {self.data_path} "
                    f"(file size {self._data_size})"
                ) from e
        return _split_json_audio_pair(name_a, bytes_a, name_b, bytes_b)

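A reading sketch (not part of the diff; shard names are placeholders), assuming a matching .idx sits next to the shard, e.g. produced by create_tar_index below:

reader = IndexedTarSampleReader("shard-000000.tar", "shard-000000.tar.idx")
json_data, audio_bytes, audio_name = reader[0]   # the triple returned by _split_json_audio_pair
print(audio_name, len(audio_bytes), list(json_data))
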
def _read_tar_member(f):
Collaborator
what features are needed here outside the python native tar library? if memory serves it involves a lot of this under the hood anyhow
| """Read the next regular-file tar member, skipping non-regular entries | ||
| (PAX headers, GNU long-name headers, directory entries, etc.).""" | ||
| while True: | ||
| header_buf = f.read(512) | ||
| if len(header_buf) < 512 or header_buf == b'\0' * 512: | ||
| raise EOFError("End of tar archive or unexpected EOF") | ||
| info = tarfile.TarInfo.frombuf(header_buf, tarfile.ENCODING, "surrogateescape") | ||
| data = f.read(info.size) | ||
| if len(data) < info.size: | ||
| raise EOFError("Unexpected end of tar file while reading data") | ||
| remainder = info.size % 512 | ||
| if remainder: | ||
| f.seek(512 - remainder, 1) | ||
| if info.type not in (tarfile.REGTYPE, tarfile.AREGTYPE): | ||
| continue | ||
| return info.name, data | ||
|
|
||
|
|
||
def create_index(jsonl_path, idx_path):
    """
    Creates a raw binary index file compatible with Megatron-Energon (CrudeJsonlDataset).

    Format: sequence of little-endian uint64 values
    ``[Offset_0, Offset_1, ..., Offset_N, File_Size]``
    """
    with open(jsonl_path, 'rb') as f_in, open(idx_path, 'wb') as f_out:
        current_offset = 0
        write_buffer = bytearray()
        write_buffer.extend(struct.pack('<Q', current_offset))
        for line in f_in:
            current_offset += len(line)
            write_buffer.extend(struct.pack('<Q', current_offset))
            if len(write_buffer) > 8 * 1024 * 1024:
Collaborator
very nitpicky, but just write out the full multiplication as a var above with a comment. no need to do the extra ops for every line.
                f_out.write(write_buffer)
                write_buffer.clear()
        if write_buffer:
            f_out.write(write_buffer)

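A quick check of the on-disk format described in the docstring (not part of the diff; file names are placeholders): one uint64 per line plus the trailing file-size sentinel.

create_index("manifest.jsonl", "manifest.jsonl.idx")
idx = np.fromfile("manifest.jsonl.idx", dtype="<u8")
with open("manifest.jsonl", "rb") as f:
    n_lines = sum(1 for _ in f)
assert idx.shape[0] == n_lines + 1                    # one offset per line + sentinel
assert idx[-1] == os.path.getsize("manifest.jsonl")   # sentinel equals the file size
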
def create_tar_index(tar_path, idx_path):
    """
    Creates a raw binary index file for a WebDataset tar archive.
    Stores the byte offset of the first member of each sample (grouped by basename),
    followed by a sentinel equal to the tar file size.
    Format is identical to :func:`create_index`.
    """
    offsets = []
    prev_stem = None
    with tarfile.open(tar_path, 'r:') as tar:
        for member in tar:
            if not member.isreg():
                continue
            stem = Path(member.name).stem
            if stem != prev_stem:
                offsets.append(member.offset)
                prev_stem = stem
    with open(idx_path, 'wb') as f:
        buf = bytearray()
        for off in offsets:
            buf.extend(struct.pack('<Q', off))
        buf.extend(struct.pack('<Q', os.path.getsize(tar_path)))
        f.write(buf)

Collaborator
add a quick comment for this flag (when going through dataloaders, we have a bit of a depth issue where the purpose of flags can get sidetracked.)
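A one-off indexing pass might look like the following (not part of the diff; the directory path is a placeholder): write a sibling .idx next to every shard so IndexedTarSampleReader can open them.

for tar_path in sorted(Path("/data/sharegpt_webdataset").glob("*.tar")):
    create_tar_index(str(tar_path), f"{tar_path}.idx")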