
Commit

tokenize_shuffle.py: tar files as input format (#129)
* tokenize_shuffle.py: tar files as input format

* add docs + diff solution for pre-tokenized

* fix black

* add tarfile tests

* try various content_keys

* change key

* change path

* update token counts

---------

Co-authored-by: Maciej Kilian <[email protected]>
5 people authored Dec 8, 2023
1 parent 1f985a5 commit 89e0824
Showing 2 changed files with 55 additions and 1 deletion.
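
For orientation before the diff, here is a hedged sketch of how the new tar input path could be invoked after this change (run from the repository root; the bucket paths are placeholders, and only the flags that appear in the diff and tests below are used):

import os

cmd = (
    "python open_lm/datapreprocess/ray/tokenize_shuffle.py "
    "--input s3://<your-bucket>/tar_shards/ "   # folder containing *.tar shards (placeholder path)
    "--content_key json:text "                  # or "txt"; use "npy" plus --vocab_size for pre-tokenized arrays
    "--seqlen 2048 "
    "--output s3://<your-bucket>/tokenized/"    # placeholder output location
)
assert os.system(cmd) == 0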
38 changes: 37 additions & 1 deletion open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -84,6 +84,7 @@ class RawFileType(enum.Enum):
    JSONL = 1
    ZSTD_JSONL_COMPRESSED = 2
    GZIP_JSONL_COMPRESSED = 3
    TAR = 4
    UNKNOWN = -1


@@ -108,13 +109,44 @@ def gzip_compressed_reader(fh: BinaryIO, content_key: str):
            yield item[content_key]


def tar_reader(fh: BinaryIO, content_key: str):
    """
    content_key: where in the tarfile to find the text/tokens. Options:
    "txt" - read text file as string
    "json:key" - read json[key] as string
    "npy" - read numpy array as tokens
    """
    content_ext = content_key.split(":")[0]
    buffer = io.BytesIO(fh.read())
    with tarfile.open(fileobj=buffer, mode="r") as tar:
        samples = []
        for member in tar.getmembers():
            if member.isfile() and member.name.endswith(f".{content_ext}"):
                with tar.extractfile(member) as fileobj:
                    if fileobj:  # Ensure fileobj is not None
                        if content_ext == "txt":
                            content = fileobj.read().decode("utf-8")
                        elif content_ext == "json":
                            json_dict, json_key = json.load(fileobj), content_key.split(":")[1]
                            content = json_dict[json_key]
                        elif content_ext == "npy":
                            token_array = np.load(io.BytesIO(fileobj.read()), allow_pickle=True)
                            content = token_array.reshape(-1).tolist()
                        else:
                            raise ValueError(f"Unsupported content key extension: {content_key}")

                        yield content


def get_reader(file_type, content_key: str):
    if file_type == RawFileType.JSONL:
        return lambda x: jsonl_file_reader(x, content_key=content_key)
    if file_type == RawFileType.ZSTD_JSONL_COMPRESSED:
        return lambda x: zstd_compressed_reader(x, content_key=content_key)
    if file_type == RawFileType.GZIP_JSONL_COMPRESSED:
        return lambda x: gzip_compressed_reader(x, content_key=content_key)
    if file_type == RawFileType.TAR:
        return lambda x: tar_reader(x, content_key=content_key)
    else:
        raise Exception("Unsupported filetype")

@@ -126,6 +158,8 @@ def get_raw_filetype(key: str):
        return RawFileType.ZSTD_JSONL_COMPRESSED
    elif key.endswith(".jsonl.gz") or key.endswith(".json.gz"):
        return RawFileType.GZIP_JSONL_COMPRESSED
    elif key.endswith(".tar"):
        return RawFileType.TAR
    else:
        logger.warning(f"Unknown filetype: {key}")
        return RawFileType.UNKNOWN
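
Taken together, the filetype detection and reader dispatch above work as in this short sketch (again assuming the module is importable; the key name is hypothetical):

from open_lm.datapreprocess.ray.tokenize_shuffle import RawFileType, get_raw_filetype, get_reader

key = "shard-000000.tar"              # hypothetical shard name
file_type = get_raw_filetype(key)     # ".tar" suffix -> RawFileType.TAR
assert file_type == RawFileType.TAR
reader = get_reader(file_type, content_key="txt")
# `reader` takes an open binary file handle and yields one text string
# (or token list) per matching member inside the tar.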
@@ -367,6 +401,7 @@ def main(args):
parser.add_argument("--content_key", type=str, default="text")
parser.add_argument("--seqlen", type=int, default=2048)
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
parser.add_argument("--vocab_size", type=int, default=None) # for pre-tokenized data, don't load tokenizer
parser.add_argument("--wds_chunk_size", type=int, default=8192)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--subset", type=int, default=None)
@@ -396,6 +431,7 @@ def main(args):
    for inp_folder in input_folders:
        input_paths += glob_files(inp_folder, suffix=".jsonl")
        input_paths += glob_files(inp_folder, suffix=".zst")
        input_paths += glob_files(inp_folder, suffix=".tar")
    if args.subset is not None:
        input_paths = input_paths[: args.subset]
    rng = random.Random(args.seed)
@@ -416,7 +452,7 @@ def main(args):
    ctx.execution_options.resource_limits.object_store_memory = float("inf")
    ray.data.DataContext.get_current().execution_options.verbose_progress = True
    start_time = time.time()
-    tokenizer = load_tokenizer(args.tokenizer)
+    tokenizer = load_tokenizer(args.tokenizer) if args.vocab_size is None else (lambda x: x, args.vocab_size)
    logger.info(f"Total number of keys = {len(input_paths)}")
    df = pd.DataFrame(input_paths, columns=["path"])
    ds = ray.data.from_pandas(pd.DataFrame(input_paths, columns=["path"])).repartition(parallelism)
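
The --vocab_size path added above amounts to swapping the Hugging Face tokenizer for an identity function, so pre-tokenized .npy shards pass through untouched. A minimal sketch of what that tuple means (the ids and vocab size below are made up; the tuple shape mirrors the tokenizer assignment above):

# Illustrative sketch, not part of the diff. With --vocab_size set, no
# tokenizer is loaded; token ids pass through unchanged, and the declared
# vocab size is carried alongside, presumably matching the shape of
# load_tokenizer's return value.
vocab_size = 16384  # hypothetical value; the tests below use the same number
tokenizer = (lambda x: x, vocab_size)
tokenize_fn, vocab = tokenizer
assert tokenize_fn([101, 2009, 17]) == [101, 2009, 17]  # ids are left as-is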
18 changes: 18 additions & 0 deletions tests/test_tokenize_shuffle.py
@@ -24,6 +24,24 @@ def test_tokenize_shuffle_simple():
total += len(x["json.gz"])
assert total == NUM_TOKENS

@pytest.mark.parametrize("content_key,NUM_TOKENS", [("npy", 4860228), ("txt", 24588), ("json:duration", 8196)])
def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
    content_len = 2048

    params = f"--content_key {content_key}"
    if content_key == "npy":
        params += " --vocab_size 16384"

    exit_value = os.system(
        f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/webvid_tiny/ {params} --output test_output/ --seqlen {content_len}"
    )
    assert exit_value == 0
    ds = wds.WebDataset("test_output/00000001.tar").decode()
    total = 0
    for x in ds:
        assert len(x["json.gz"]) == content_len + 1
        total += len(x["json.gz"])
    assert total == NUM_TOKENS

def test_tokenize_shuffle_simple_do_sample():
    content_len = 2048
