Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reincorporate manifest creation. #122

Merged
merged 12 commits into from
Dec 21, 2023
47 changes: 45 additions & 2 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from tqdm import tqdm
from transformers import GPTNeoXTokenizerFast

import logging


import enum
import yaml
Expand Down Expand Up @@ -293,7 +295,10 @@ def map_write_wds(batch, batch_size, folder, counter):
bio.seek(0)
token_count = ray.get(counter.increment_token_count.remote(token_count))
write_to_location(folder, tar_name, bio)
return batch

return_dict = {"shard": [tar_name.split(".")[0]], "num_sequences": [len(batch["tokens"])]}

return return_dict


def write_to_location(folder, tar_name, bio):
Expand Down Expand Up @@ -367,6 +372,27 @@ def glob_files(path, suffix=".jsonl"):
return matching_files


def get_filesystem(environment):
GeorgiosSmyrnis marked this conversation as resolved.
Show resolved Hide resolved
"""
Create a pyarrow.fs.FileSystem based on provided AWS credentials.

:param environment: Dictionary containing AWS credentials.
:return: pyarrow.fs.S3FileSystem
"""
# Extract the AWS credentials from the environment dictionary
access_key = environment.get("AWS_ACCESS_KEY_ID")
secret_key = environment.get("AWS_SECRET_ACCESS_KEY")
session_token = environment.get("AWS_SESSION_TOKEN", None) # Session token might be optional

# Create and return the S3FileSystem
return fs.S3FileSystem(
access_key=access_key,
secret_key=secret_key,
session_token=session_token,
region="us-west-2",
)


@ray.remote
class GlobalCounter:
def __init__(self):
Expand Down Expand Up @@ -484,7 +510,24 @@ def main(args):
"counter": counter,
},
batch_format="pandas",
).count()
)
GeorgiosSmyrnis marked this conversation as resolved.
Show resolved Hide resolved

def path_creation(*a, **kw):
output_path = os.path.join(args.output.strip("/"), "manifest.jsonl")
if output_path.startswith("s3://"):
return output_path[5:]
else:
return output_path

# Sort by shard name
ds = ds.repartition(1)
ds = ds.sort(key="shard")
ds.write_json(
args.output.strip("/"),
filesystem=get_filesystem(creds) if args.output.startswith("s3") else None,
block_path_provider=path_creation,
)

end_time = time.time()
duration = end_time - start_time
final_token_count = ray.get(counter.increment_token_count.remote(0))
Expand Down
15 changes: 15 additions & 0 deletions tests/test_tokenize_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import pytest
import webdataset as wds
Expand All @@ -24,6 +25,13 @@ def test_tokenize_shuffle_simple():
total += len(x["json.gz"])
assert total == NUM_TOKENS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)


@pytest.mark.parametrize("content_key,NUM_TOKENS", [("npy", 4860228), ("txt", 24588), ("json:duration", 8196)])
def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
Expand Down Expand Up @@ -75,3 +83,10 @@ def test_tokenize_shuffle_s3_write():
total += len(x["json.gz"])
assert total == NUM_TOKENS
assert exit_value == 0

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]

assert out[0]["shard"] == "00000001"
assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)
Loading