Reincorporate manifest creation. (#122)
* Reincorporate manifest creation.

* Formatting.

* Bugfix.

* Bugfix + add manifest to tests.

* Autodownload s3 creds for tests.

* Formatting.

* Formatting.

* Improve manifest writing.

* Bugfix for s3 test.

* Formatting.

---------

Co-authored-by: George Smyrnis <[email protected]>
GeorgiosSmyrnis authored Dec 21, 2023
1 parent 685bc51 commit 383ed27
Showing 2 changed files with 52 additions and 2 deletions.
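For orientation: this commit makes tokenize_shuffle write a manifest.jsonl next to the tokenized shards, with one JSON record per shard recording the shard name and its sequence count (see write_manifest in the diff below). A minimal sketch of consuming such a manifest, assuming a local output directory laid out the way this code path writes it; the directory name here is hypothetical:

import json
import os

output_dir = "my_tokenized_data"  # hypothetical output directory

# One JSON object per line, e.g. {"shard": "00000001", "num_sequences": ...}
records = []
with open(os.path.join(output_dir, "manifest.jsonl")) as f:
    for line in f:
        if line.strip():
            records.append(json.loads(line))

total_sequences = sum(r["num_sequences"] for r in records)
print(f"{len(records)} shards, {total_sequences} sequences in total")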
39 changes: 37 additions & 2 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -41,6 +41,8 @@
from tqdm import tqdm
from transformers import GPTNeoXTokenizerFast

import logging


import enum
import yaml
@@ -293,7 +295,10 @@ def map_write_wds(batch, batch_size, folder, counter):
bio.seek(0)
token_count = ray.get(counter.increment_token_count.remote(token_count))
write_to_location(folder, tar_name, bio)
return batch

return_dict = {"shard": [tar_name.split(".")[0]], "num_sequences": [len(batch["tokens"])]}

return return_dict


def write_to_location(folder, tar_name, bio):
@@ -388,6 +393,29 @@ def get_token_counter(self):
return self.token_count


def write_manifest(jsonl_lines, args):
"Write manifest to provided output path."

output_path = os.path.join(args.output.strip("/"), "manifest.jsonl")

if output_path.startswith("s3://"):
# Use boto3 for S3 paths
s3_client = boto3.client("s3")
jsonl_content = "\n".join(json.dumps(record) for record in jsonl_lines) + "\n" # Add a newline at the end
bucket_name, s3_key = output_path[5:].split("/", 1)
response = s3_client.put_object(Bucket=bucket_name, Key=s3_key, Body=jsonl_content)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
logging.warning(
"Failed to write manifest. Please manually include manifest by running "
"open_lm.utils.make_manifest on the tokenized data."
)
else:
with open(output_path, "w") as f:
for item in jsonl_lines:
json.dump(item, f)
f.write("\n")


def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input path", type=str, required=True)
@@ -484,7 +512,14 @@ def main(args):
"counter": counter,
},
batch_format="pandas",
).count()
)

# Sort by shard name
ds = ds.repartition(1)
ds = ds.sort(key="shard")
jsonl_lines = ds.take_all()
write_manifest(jsonl_lines, args)

end_time = time.time()
duration = end_time - start_time
final_token_count = ray.get(counter.increment_token_count.remote(0))
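Note on the main() change above: instead of just counting rows with .count(), the per-shard records returned by map_write_wds are repartitioned into a single block, sorted by shard name, and materialized with take_all() to feed write_manifest. A self-contained sketch of that Ray Data pattern with made-up records (the real pipeline builds the dataset from the tokenization step, not from_items):

import ray

ray.init(ignore_reinit_error=True)

# Hypothetical per-shard records shaped like map_write_wds's return value.
items = [
    {"shard": "00000002", "num_sequences": 1024},
    {"shard": "00000001", "num_sequences": 1024},
]

ds = ray.data.from_items(items)
ds = ds.repartition(1)       # single block so the sort yields one ordered list
ds = ds.sort(key="shard")    # order manifest entries by shard name
jsonl_lines = ds.take_all()  # list of dicts, one per shard
print(jsonl_lines)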
15 changes: 15 additions & 0 deletions tests/test_tokenize_shuffle.py
@@ -1,3 +1,4 @@
import json
import os
import pytest
import webdataset as wds
@@ -24,6 +25,13 @@ def test_tokenize_shuffle_simple():
total += len(x["json.gz"])
assert total == NUM_TOKENS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)


@pytest.mark.parametrize("content_key,NUM_TOKENS", [("npy", 4860228), ("txt", 24588), ("json:duration", 8196)])
def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
@@ -75,3 +83,10 @@ def test_tokenize_shuffle_s3_write():
total += len(x["json.gz"])
assert total == NUM_TOKENS
assert exit_value == 0

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)
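Both tests read back test_output/manifest.jsonl and check the first entry's shard name and sequence count. A further check one could imagine, not part of this commit, is cross-referencing a manifest entry against the matching shard tar; this sketch assumes shards land in test_output/ named like 00000001.tar and that each tokenized sequence is stored as one webdataset sample:

import json
import webdataset as wds

with open("test_output/manifest.jsonl") as f:
    entry = json.loads(f.readline())

# Assumed naming: the manifest's "shard" value plus ".tar" (hypothetical).
shard_path = f"test_output/{entry['shard']}.tar"
sample_count = sum(1 for _ in wds.WebDataset(shard_path))
assert sample_count == entry["num_sequences"]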
