Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reincorporate manifest creation. #122

Merged
merged 12 commits into from
Dec 21, 2023
39 changes: 37 additions & 2 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from tqdm import tqdm
from transformers import GPTNeoXTokenizerFast

import logging


import enum
import yaml
Expand Down Expand Up @@ -293,7 +295,10 @@ def map_write_wds(batch, batch_size, folder, counter):
bio.seek(0)
token_count = ray.get(counter.increment_token_count.remote(token_count))
write_to_location(folder, tar_name, bio)
return batch

return_dict = {"shard": [tar_name.split(".")[0]], "num_sequences": [len(batch["tokens"])]}

return return_dict


def write_to_location(folder, tar_name, bio):
Expand Down Expand Up @@ -388,6 +393,29 @@ def get_token_counter(self):
return self.token_count


def write_manifest(jsonl_lines, args):
    """Write a manifest (one JSON record per line) to ``<args.output>/manifest.jsonl``.

    Supports both local filesystem paths and ``s3://`` URIs (the latter via
    boto3). A failed S3 upload only logs a warning instead of raising, so the
    tokenization run itself is not aborted.

    Args:
        jsonl_lines: Iterable of JSON-serializable dicts (one per shard),
            e.g. ``{"shard": ..., "num_sequences": ...}``.
        args: Parsed arguments; only ``args.output`` (output directory) is read.
    """
    # rstrip (NOT strip) so a leading "/" on an absolute local path is kept;
    # we only want to drop a trailing slash before joining.
    output_path = os.path.join(args.output.rstrip("/"), "manifest.jsonl")

    if output_path.startswith("s3://"):
        # Use boto3 for S3 paths.
        s3_client = boto3.client("s3")
        # Trailing newline so the object ends with a complete JSONL record.
        jsonl_content = "\n".join(json.dumps(record) for record in jsonl_lines) + "\n"
        bucket_name, s3_key = output_path[5:].split("/", 1)
        response = s3_client.put_object(Bucket=bucket_name, Key=s3_key, Body=jsonl_content)
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            logging.warning(
                "Failed to write manifest. Please manually include manifest by running "
                "open_lm.utils.make_manifest on the tokenized data."
            )
    else:
        with open(output_path, "w") as f:
            for item in jsonl_lines:
                json.dump(item, f)
                f.write("\n")


def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input path", type=str, required=True)
Expand Down Expand Up @@ -484,7 +512,14 @@ def main(args):
"counter": counter,
},
batch_format="pandas",
).count()
)
GeorgiosSmyrnis marked this conversation as resolved.
Show resolved Hide resolved

# Sort by shard name
ds = ds.repartition(1)
ds = ds.sort(key="shard")
jsonl_lines = ds.take_all()
write_manifest(jsonl_lines, args)

end_time = time.time()
duration = end_time - start_time
final_token_count = ray.get(counter.increment_token_count.remote(0))
Expand Down
15 changes: 15 additions & 0 deletions tests/test_tokenize_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import pytest
import webdataset as wds
Expand All @@ -24,6 +25,13 @@ def test_tokenize_shuffle_simple():
total += len(x["json.gz"])
assert total == NUM_TOKENS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)


@pytest.mark.parametrize("content_key,NUM_TOKENS", [("npy", 4860228), ("txt", 24588), ("json:duration", 8196)])
def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
Expand Down Expand Up @@ -75,3 +83,10 @@ def test_tokenize_shuffle_s3_write():
total += len(x["json.gz"])
assert total == NUM_TOKENS
assert exit_value == 0

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)
Loading