20 changes: 20 additions & 0 deletions nemo/collections/common/data/lhotse/cutset.py
@@ -47,6 +47,7 @@
NeMoMultimodalConversation,
NeMoMultimodalConversationJsonlAdapter,
NeMoMultimodalConversationShareGPTJsonlAdapter,
NeMoMultimodalConversationShareGPTWebdatasetAdapter,
NeMoSFTJsonlAdapter,
TextTurn,
)
@@ -324,6 +325,7 @@ def read_share_gpt_as_conversation(config) -> tuple[CutSet, bool]:
tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
audio_locator_tag=config.audio_locator_tag,
audio_placeholders=config.audio_placeholders,
audio_root=config.get("audio_root"),
token_equivalent_duration=config.get("token_equivalent_duration"),
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
@@ -335,6 +337,24 @@ def read_share_gpt_as_conversation(config) -> tuple[CutSet, bool]:
return cuts, True


@data_type_parser(["share_gpt_webdataset"])
def read_share_gpt_webdataset_as_conversation(config) -> tuple[CutSet, bool]:
"""Read ShareGPT conversations from WebDataset tar archives."""
cuts = CutSet(
NeMoMultimodalConversationShareGPTWebdatasetAdapter(
data_dir=config.data_dir,
audio_locator_tag=config.audio_locator_tag,
audio_placeholders=config.get("audio_placeholders"),
token_equivalent_duration=config.get("token_equivalent_duration"),
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
)
)
if not config.get("force_finite", False):
Collaborator: Add a quick comment for this flag (when going through dataloaders we have a bit of a depth issue where the purpose of flags can get sidetracked).
cuts = cuts.repeat(preserve_id=True)
return cuts, True
Collaborator: Where do we need this bool for compatibility? Curious whether the same thing could be achieved by just checking the config.

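For orientation, here is a hedged sketch of the config keys this parser reads, taken from the function body above; the "type" selector key and the illustrative values are assumptions, not part of this diff:

# Hypothetical entry consumed by read_share_gpt_webdataset_as_conversation.
example_share_gpt_webdataset_entry = {
    "type": "share_gpt_webdataset",     # assumed selector for the @data_type_parser registration
    "data_dir": "/path/to/wds_shards",  # hypothetical path to the WebDataset tar shards
    "audio_locator_tag": "<audio>",
    "audio_placeholders": None,         # optional, read via config.get(...)
    "token_equivalent_duration": None,  # optional, read via config.get(...)
    "shuffle": True,
    "shard_seed": 0,
    "force_finite": False,              # False -> cuts.repeat(preserve_id=True), i.e. an effectively infinite stream
}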

def _resolve_shar_inputs(path: Union[str, Path], only_metadata: bool) -> dict:
if only_metadata:
return dict(fields={"cuts": sorted(Path(path).glob("cuts.*"))})
287 changes: 287 additions & 0 deletions nemo/collections/common/data/lhotse/indexed_adapters.py
@@ -0,0 +1,287 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import struct
import tarfile
from pathlib import Path

import numpy as np


class LazyShuffledRange:
"""
Generates a permutation of ``range(n)`` lazily using a Feistel cipher,
without materializing the full index list. Each element is computed on
the fly in O(1) time and the object itself uses O(1) memory regardless
of ``n``.

The technique is known as *cycle-walking* format-preserving encryption:
a Feistel network is a bijection on ``[0, 2^k)``, and repeatedly applying
it until the output falls within ``[0, n)`` restricts it to a bijection
on the desired domain.

Args:
n: Size of the range to permute.
rng: A ``random.Random`` instance used to derive round keys.
"""

def __init__(self, n: int, rng: random.Random):
self.n = n
Collaborator: Rename self.n to something like self.width or another semantically meaningful name.

if n <= 1:
return
bits = (n - 1).bit_length()
if bits < 2:
bits = 2
if bits % 2:
bits += 1
self._half = bits // 2
self._mask = (1 << self._half) - 1
self._rounds = 6
Collaborator: Make this an argument in __init__.
self._keys = [rng.getrandbits(64) for _ in range(self._rounds)]

def _permute_one(self, x: int) -> int:
Collaborator: Are there no numpy methods you can crib for this? Or is the native CPython bitshift efficient enough?

left = (x >> self._half) & self._mask
right = x & self._mask
for key in self._keys:
left, right = right, left ^ (((right * 2654435761) ^ key) >> 32 & self._mask)
Collaborator: Make this constant a global variable at the top of the file.

return (left << self._half) | right

def __len__(self) -> int:
return self.n

def __iter__(self):
n = self.n
if n <= 0:
return
if n == 1:
yield 0
return
for i in range(n):
x = i
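# Cycle-walk: keep re-applying the Feistel permutation until the output falls in [0, n).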
while True:
x = self._permute_one(x)
if x < n:
yield x
break

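As a quick illustrative check (not part of the diff): iterating a LazyShuffledRange visits each index in [0, n) exactly once, in an order determined only by the supplied RNG, while storing nothing but the round keys.

import random

rng = random.Random(1234)
order = list(LazyShuffledRange(10, rng))
assert sorted(order) == list(range(10))   # yields a permutation of range(10)
assert len(LazyShuffledRange(10, rng)) == 10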

def _load_index(data_path: str, idx_path: str | None = None):
"""
Load a memmap'd offset index for *data_path*.
Returns ``(offsets_memmap, num_samples, data_file_size)``.
Handles the optional sentinel (last entry == file size).
"""
if idx_path is None:
idx_path = data_path + '.idx'
assert os.path.exists(data_path), f"Data file not found: {data_path}"
Collaborator: Pretty sure both of these assertions would be raised by the following calls anyway, no?

assert os.path.exists(idx_path), f"Index file not found: {idx_path}"
offsets = np.memmap(idx_path, dtype=np.dtype('<u8'), mode='r')
data_size = os.path.getsize(data_path)
length = offsets.shape[0] - 1 if offsets[-1] == data_size else offsets.shape[0]
Collaborator: Strikes me as clunky keeping an extra arg just for a potential off-by-one error. Where is this critical, and can we move it outside this function?

return offsets, length, data_size


def _resolve_idx(idx: int, length: int) -> int:
if idx < 0:
idx += length
if idx < 0 or idx >= length:
raise IndexError("Index out of bounds")
return idx


class IndexedJSONLReader:
def __init__(self, jsonl_path: Path | str, idx_path: Path | str | None = None):
self.data_path = str(jsonl_path)
self.offsets, self._len, self._data_size = _load_index(self.data_path, str(idx_path) if idx_path else None)
Collaborator: Oh yeah, just make the length handling part of the init. That makes _load_index more conceptually clear and won't add clutter to the two inits.


def __len__(self):
return self._len

def __getitem__(self, idx):
idx = _resolve_idx(idx, self._len)
start = int(self.offsets[idx])
end = int(self.offsets[idx + 1]) if idx + 1 < self.offsets.shape[0] else self._data_size
Collaborator: Isn't this just reversing the logic for the _load_index length handling?

with open(self.data_path, 'rb') as f:
f.seek(start)
data = f.read(end - start)
return json.loads(data.decode('utf-8'))

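A minimal usage sketch (file names hypothetical): build the offset index once with create_index (defined further down in this file), then read arbitrary records by seeking instead of scanning.

create_index("train.jsonl", "train.jsonl.idx")   # one-time index build

reader = IndexedJSONLReader("train.jsonl")       # picks up train.jsonl.idx by default
print(len(reader))                               # number of JSON lines
record = reader[42]                              # dict parsed from the record at index 42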

def _split_json_audio_pair(name_a, bytes_a, name_b, bytes_b):
"""Classify two tar members into (json_data_dict, audio_bytes, audio_name) regardless of order."""
if name_a.endswith('.json'):
Collaborator: Any strong thoughts on just making a dataclass to manage this tuple? (I'm imagining the awkwardness of accidentally indexing audio_bytes when you just wanted the name.)

return json.loads(bytes_a), bytes_b, name_b
if name_b.endswith('.json'):
return json.loads(bytes_b), bytes_a, name_a
raise ValueError(f"Expected one .json member in tar sample pair, got: {name_a}, {name_b}")


class IndexedTarSampleReader:
"""
Random access to WebDataset tar samples (``N.json`` + ``N.<audio>``) via an index file.
Index format is identical to ``IndexedJSONLReader``: little-endian uint64 offsets,
optionally followed by a sentinel equal to the tar file size.
"""

def __init__(self, tar_path: str | Path, idx_path: str | Path | None = None):
self.data_path = str(tar_path)
self.offsets, self._len, self._data_size = _load_index(self.data_path, str(idx_path) if idx_path else None)
Collaborator: Just noting, as above, that you can move the length logic from _load_index to here to make that function more semantically shallow.

self._validate_index()

def _validate_index(self):
if self._len == 0:
return
max_offset = int(max(self.offsets[i] for i in range(self._len)))
if max_offset >= self._data_size:
Collaborator: All of this information comes from _load_index. Shouldn't the check be there as well? Remove from the class a burden that should be managed by the function.

raise ValueError(
f"Tar index for {self.data_path} contains offset {max_offset} "
f"beyond file size {self._data_size}. "
f"The .idx file may have been created by an incompatible tool "
f"or for a different file."
)
# Validate first offset is a valid tar header.
Collaborator: Likewise, I think this part of the check could come from _load_index. But it also makes sense to only add the tar check in a tar-related class, so I'm more ambivalent here.

self._check_offset_is_tar_header(int(self.offsets[0]), label="first")
# Strip trailing sentinels: some tools store the offset of the
# end-of-archive zero-block marker as a sentinel instead of the
# file size (which _load_index already handles).
while self._len > 0:
last = int(self.offsets[self._len - 1])
with open(self.data_path, 'rb') as f:
f.seek(last)
buf = f.read(512)
if len(buf) < 512 or buf == b'\0' * 512:
self._len -= 1
else:
break

def _check_offset_is_tar_header(self, offset: int, label: str = ""):
with open(self.data_path, 'rb') as f:
f.seek(offset)
buf = f.read(512)
if len(buf) < 512:
raise ValueError(
f"Tar index for {self.data_path}: {label} offset {offset} "
f"is too close to EOF (file size {self._data_size})."
)
if buf == b'\0' * 512:
raise ValueError(
f"Tar index for {self.data_path}: {label} offset {offset} "
f"points to a zero block (end-of-archive marker), not a tar header. "
f"The .idx file may have been created by an incompatible tool "
f"or for a different file."
)
try:
tarfile.TarInfo.frombuf(buf, tarfile.ENCODING, "surrogateescape")
except tarfile.TarError as e:
raise ValueError(
f"Tar index for {self.data_path}: {label} offset {offset} "
f"does not point to a valid tar header: {e}. "
f"The .idx file may have been created by an incompatible tool "
f"(e.g. has a binary header or stores per-member offsets) "
f"or for a different file."
) from e

def __len__(self):
return self._len

def __getitem__(self, idx):
idx = _resolve_idx(idx, self._len)
offset = int(self.offsets[idx])
with open(self.data_path, 'rb') as f:
f.seek(offset)
try:
name_a, bytes_a = _read_tar_member(f)
except (EOFError, tarfile.TarError) as e:
raise type(e)(
f"{e} — reading first member of sample {idx}/{self._len} "
f"at offset {offset} in {self.data_path} "
f"(file size {self._data_size})"
) from e
try:
name_b, bytes_b = _read_tar_member(f)
except (EOFError, tarfile.TarError) as e:
raise type(e)(
f"{e} — reading second member of sample {idx}/{self._len} "
f"(first member was '{name_a}', {len(bytes_a)} bytes) "
f"at offset {offset} in {self.data_path} "
f"(file size {self._data_size})"
) from e
return _split_json_audio_pair(name_a, bytes_a, name_b, bytes_b)

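Similarly, a hedged end-to-end sketch for the tar reader (shard name hypothetical): index the shard with create_tar_index (defined below), then fetch a sample as a (json_dict, audio_bytes, audio_name) triple.

create_tar_index("shard-000000.tar", "shard-000000.tar.idx")

reader = IndexedTarSampleReader("shard-000000.tar")   # picks up shard-000000.tar.idx by default
meta, audio_bytes, audio_name = reader[0]             # JSON metadata dict, raw audio payload, audio member name
print(audio_name, len(audio_bytes), sorted(meta))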

def _read_tar_member(f):
Collaborator: What features are needed here beyond Python's native tarfile library? If memory serves, it does a lot of this under the hood anyhow.

"""Read the next regular-file tar member, skipping non-regular entries
(PAX headers, GNU long-name headers, directory entries, etc.)."""
while True:
header_buf = f.read(512)
if len(header_buf) < 512 or header_buf == b'\0' * 512:
raise EOFError("End of tar archive or unexpected EOF")
info = tarfile.TarInfo.frombuf(header_buf, tarfile.ENCODING, "surrogateescape")
data = f.read(info.size)
if len(data) < info.size:
raise EOFError("Unexpected end of tar file while reading data")
remainder = info.size % 512
if remainder:
f.seek(512 - remainder, 1)
if info.type not in (tarfile.REGTYPE, tarfile.AREGTYPE):
continue
return info.name, data


def create_index(jsonl_path, idx_path):
"""
Creates a raw binary index file compatible with Megatron-Energon (CrudeJsonlDataset).

Format: sequence of little-endian uint64 values
``[Offset_0, Offset_1, ..., Offset_N, File_Size]``
"""
with open(jsonl_path, 'rb') as f_in, open(idx_path, 'wb') as f_out:
current_offset = 0
write_buffer = bytearray()
write_buffer.extend(struct.pack('<Q', current_offset))
for line in f_in:
current_offset += len(line)
write_buffer.extend(struct.pack('<Q', current_offset))
if len(write_buffer) > 8 * 1024 * 1024:
Collaborator: Very nitpicky, but just write the full multiplication out as a variable above with a comment. No need to do the extra ops for every line.

f_out.write(write_buffer)
write_buffer.clear()
if write_buffer:
f_out.write(write_buffer)

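As a concrete illustration of the layout (byte counts illustrative): for a JSONL file whose three lines are 17, 25 and 9 bytes long, create_index writes four little-endian uint64 values, [0, 17, 42, 51], where the final 51 is the file-size sentinel. Reading the index back with numpy recovers the line boundaries:

import numpy as np

offsets = np.memmap("train.jsonl.idx", dtype=np.dtype('<u8'), mode='r')  # hypothetical path
line_starts = offsets[:-1]        # last entry is the file-size sentinel
line_lengths = np.diff(offsets)   # per-line byte lengths, e.g. [17, 25, 9]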

def create_tar_index(tar_path, idx_path):
"""
Creates a raw binary index file for a WebDataset tar archive.
Stores the byte offset of the first member of each sample (grouped by basename),
followed by a sentinel equal to the tar file size.
Format is identical to :func:`create_index`.
"""
offsets = []
prev_stem = None
with tarfile.open(tar_path, 'r:') as tar:
for member in tar:
if not member.isreg():
continue
stem = Path(member.name).stem
if stem != prev_stem:
offsets.append(member.offset)
prev_stem = stem
with open(idx_path, 'wb') as f:
buf = bytearray()
for off in offsets:
buf.extend(struct.pack('<Q', off))
buf.extend(struct.pack('<Q', os.path.getsize(tar_path)))
f.write(buf)