-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Labels
Labels: bug (Something isn't working) · help wanted (Extra attention is needed) · waiting on author (Waiting for user input or feedback)
Description
🐛 Bug
Segmentation fault happens when streaming TokensLoaders with zstd compression. I'd like to know what I should check in order to find a fix for this.
I suspect this might be related to past issues like #459 ?
To Reproduce
I build the DCLM 1.0 dataset with this:
import json
import time
from functools import partial
from pathlib import Path
import zstandard as zstd
from litdata import TokensLoader, optimize
from litgpt.tokenizer import Tokenizer
from litgpt.utils import CLI, extend_checkpoint_dir
def get_files(input_dir, val_set):
    """Collect every ``*.zst`` shard file under *input_dir*.

    Args:
        input_dir: Root directory containing the
            ``global-shard_*_of_10/local-shard_*_of_10`` tree.
        val_set: If True, return only the reserved validation shard
            (``global-shard_10_of_10/local-shard_9_of_10``); otherwise
            return every other shard.

    Returns:
        List of shard file paths as strings.
    """
    candidate_dirs = []
    # Note that global shard uses 01-10, local shard uses 0-9.
    # Val set use the global-shard_10_of_10/local-shard_9_of_10/**
    if val_set:
        candidate_dirs.append(
            Path(input_dir) / "global-shard_10_of_10" / "local-shard_9_of_10"
        )
    else:
        for global_shard in range(1, 11):
            for local_shard in range(10):
                # The (10, 9) pair is reserved for the validation set.
                if global_shard == 10 and local_shard == 9:
                    continue
                candidate_dirs.append(
                    Path(input_dir)
                    / f"global-shard_{global_shard:02}_of_10"
                    / f"local-shard_{local_shard}_of_10"
                )
    files = []
    for d in candidate_dirs:
        # d is already a Path; rglob yields lazily, so no need to re-wrap
        # in Path() or materialize an intermediate list before extend().
        files.extend(d.rglob("*.zst"))
    return [str(file) for file in files]
def process_items(filepath, tokenizer):
    """Yield token-id sequences for each JSONL document in a zstd file.

    Each line of the (zstd-compressed) file is a JSON object whose
    ``"text"`` field is tokenized with EOS appended and no BOS.
    """
    with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as stream:
        for line in stream:
            document = json.loads(line)["text"]
            yield tokenizer.encode(string=document, bos=False, eos=True)
def prepare(
    input_dir: Path = None,
    output_dir: Path = None,
    tokenizer_path: Path = None,
    # presumably (block_size 2049) * (16384 sequences) tokens per chunk — TODO confirm
    chunk_size: int = (2049 * 16384),
    fast_dev_run: bool = False,
    split: str = "train",
    workers: int = 1,
) -> None:
    """Tokenize the DCLM shards and write a litdata-optimized dataset.

    Args:
        input_dir: Root of the ``global-shard_*/local-shard_*`` tree.
        output_dir: Destination directory for the optimized chunks.
        tokenizer_path: litgpt tokenizer checkpoint directory.
        chunk_size: Tokens per optimized chunk, forwarded to ``optimize``.
        fast_dev_run: Forwarded to ``optimize`` for a quick smoke run.
        split: ``"train"`` or ``"val"``; selects which shards to process.
        workers: Number of ``optimize`` worker processes.
    """
    tokenizer_path = extend_checkpoint_dir(tokenizer_path)
    tokenizer = Tokenizer(tokenizer_path)
    # split == "val" selects only the reserved validation shard.
    filelist = get_files(input_dir, val_set=split == "val")
    start_time = time.time()
    optimize(
        partial(process_items, tokenizer=tokenizer),
        inputs=filelist,
        output_dir=output_dir,
        chunk_size=chunk_size,
        num_workers=workers,
        fast_dev_run=fast_dev_run,
        item_loader=TokensLoader(),
        # NOTE(review): "zstd" chunk compression together with TokensLoader is
        # the exact combination reported to segfault when streaming back.
        compression="zstd",
        keep_data_ordered=False,
        mode="overwrite",
    )
    elapsed_time = time.time() - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds")
if __name__ == "__main__":
    CLI(prepare)

Then stream with this:
from tqdm import tqdm
from litdata import StreamingDataset, StreamingDataLoader, TokensLoader

# Stream the optimized dataset back.  TokensLoader(1025) yields fixed-size
# blocks of 1025 tokens (presumably block_size 1024 + 1 shifted target —
# TODO confirm against the training config).
ds = StreamingDataset(
    "/gpfs/data/oermannlab/project_data/weaver/pretrain_data/dclm/train",
    item_loader=TokensLoader(1025),
)
# num_workers=0: the crash reproduces even without worker processes.
dl = StreamingDataLoader(ds, batch_size=4, num_workers=0)
for _ in tqdm(dl):
    pass

The same error happens for different combinations of block size, batch size, and worker count. Directly iterating over the dataset (without the dataloader) seems to have no problem.
Expected behavior
Normal streaming should work.
Metadata
Metadata
Assignees
Labels
Labels: bug (Something isn't working) · help wanted (Extra attention is needed) · waiting on author (Waiting for user input or feedback)