Merge branch 'main' into bootstrap_ci_pt2
sagadre authored Dec 10, 2023
2 parents dedd7a3 + 158124d commit 58330f9
Showing 10 changed files with 388 additions and 95 deletions.
31 changes: 31 additions & 0 deletions open_lm/datapreprocess/metadata/rpj_lm_data.yaml
@@ -0,0 +1,31 @@
sources:
- source: "LMDATA"
markers: ["lmdata"]
- source: "COMMON_CRAWL"
markers: ["common_crawl"]
- source: "C4"
markers: ["c4"]
- source: "GITHUB"
markers: ["github"]
- source: "WIKIPEDIA"
markers: ["wikipedia"]
- source: "BOOKS"
markers: ["book"]
- source: "ARXIV"
markers: ["arxiv"]
- source: "STACKEXCHANGE"
markers: ["stackexchange"]
- source: "UNKNOWN"
markers: [] # No specific markers for UNKNOWN

sampling_frequencies:
COMMON_CRAWL: 0.9233485194
C4: 1.037142857
GITHUB: 0.9228813559
WIKIPEDIA: 2.26875
BOOKS: 2.094230769
ARXIV: 1.080357143
STACKEXCHANGE: 1.21
LMDATA: 1.0
UNKNOWN: 0
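
Aside (not part of the commit): each entry in "markers" is matched as a substring of a shard path, and the first source whose marker matches determines the sampling frequency; paths matching nothing fall through to UNKNOWN. A minimal sketch of that lookup, assuming PyYAML and a made-up shard path:

import yaml

with open("open_lm/datapreprocess/metadata/rpj_lm_data.yaml") as f:
    cfg = yaml.safe_load(f)

def source_for(key):
    # first source whose marker appears as a substring of the shard key
    for item in cfg["sources"]:
        if any(marker in key for marker in item["markers"]):
            return item["source"]
    return "UNKNOWN"

src = source_for("rpj/wikipedia/shard_000123.jsonl.zst")  # hypothetical path -> "WIKIPEDIA"
freq = cfg["sampling_frequencies"][src]                   # -> 2.26875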

168 changes: 99 additions & 69 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -6,6 +6,7 @@
import io
import json
import os
import sys
import random
import resource
import tarfile
@@ -21,8 +22,6 @@
import numpy as np
import pandas as pd
import psutil
import pyarrow.fs as fs
import pyarrow.json
import ray
import webdataset as wds
import zstandard as zstd
@@ -43,71 +42,52 @@
from transformers import GPTNeoXTokenizerFast


import enum
import yaml
import pathlib

# Initialize an empty dictionary for sampling frequencies

DIR = pathlib.Path(__file__).parent.absolute()


def load_from_yaml(filename):
SAMPLING_FREQUENCIES = {}

with open(filename, "r") as file:
data = yaml.safe_load(file)

# Dynamically create the Sources enum based on YAML file
Sources = enum.Enum("Sources", {item["source"]: index for index, item in enumerate(data["sources"])})

# Add get_source and get_sampling_frequency methods to Sources
def get_source_dynamic(self, key):
for item in data["sources"]:
if any(marker in key for marker in item["markers"]):
return Sources[item["source"]]
return Sources.UNKNOWN

def get_sampling_frequency_dynamic(self, key):
return SAMPLING_FREQUENCIES[self.get_source(key)]

Sources.get_source = classmethod(get_source_dynamic)
Sources.get_sampling_frequency = classmethod(get_sampling_frequency_dynamic)

# Load sampling frequencies
for key, value in data["sampling_frequencies"].items():
source = Sources[key]
SAMPLING_FREQUENCIES[source] = value
return Sources, SAMPLING_FREQUENCIES
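
Illustrative use of the new loader (a sketch, not from the diff); the shard path is hypothetical, and the YAML path matches the default wired into main() below:

Sources, SAMPLING_FREQUENCIES = load_from_yaml(DIR.parent / "metadata" / "rpj_lm_data.yaml")

key = "s3://bucket/rpj/arxiv/chunk_07.jsonl.zst"  # hypothetical shard path
src = Sources.get_source(key)                     # -> Sources.ARXIV
freq = Sources.get_sampling_frequency(key)        # -> 1.080357143
assert SAMPLING_FREQUENCIES[src] == freq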


class RawFileType(enum.Enum):
JSONL = 1
ZSTD_JSONL_COMPRESSED = 2
GZIP_JSONL_COMPRESSED = 3
TAR = 4
UNKNOWN = -1


class Sources(enum.Enum):
COMMON_CRAWL = 0
C4 = 1
GITHUB = 2
WIKIPEDIA = 3
BOOKS = 4
ARXIV = 5
STACKEXCHANGE = 6
UNKNOWN = 7

@classmethod
def get_source(cls, key):
if "common_crawl" in key or "webtext" in key or "realnews" in key or "pile-cc" in key:
return cls.COMMON_CRAWL
elif "c4" in key:
return cls.C4
elif "github" in key or "dedup" in key:
return cls.GITHUB
elif "wikipedia" in key:
return cls.WIKIPEDIA
elif "book" in key:
return cls.BOOKS
elif "arxiv" in key or "s2orc" in key or "pubmed" or "phil" or "nih" or "math":
return cls.ARXIV
elif (
"stackexchange" in key
or "youtube"
or "ubuntu"
or "hn"
or "law" in key
or "europarl" in key
or "enron" in key
):
return cls.STACKEXCHANGE
else:
return cls.UNKNOWN

@classmethod
def get_sampling_frequency(cls, key):
return SAMPLING_FREQUENCIES[cls.get_source(key)]


# hard coded from Mitchell
# These are sampling frequencies for each source used to match
# the Mosaic training run on RPJ
# TODO load from JSON

SAMPLING_FREQUENCIES = {}
SAMPLING_FREQUENCIES[Sources.COMMON_CRAWL] = 0.9233485194
SAMPLING_FREQUENCIES[Sources.C4] = 1.037142857
SAMPLING_FREQUENCIES[Sources.GITHUB] = 0.9228813559
SAMPLING_FREQUENCIES[Sources.WIKIPEDIA] = 2.26875
SAMPLING_FREQUENCIES[Sources.BOOKS] = 2.094230769
SAMPLING_FREQUENCIES[Sources.ARXIV] = 1.080357143
SAMPLING_FREQUENCIES[Sources.STACKEXCHANGE] = 1.21
SAMPLING_FREQUENCIES[Sources.UNKNOWN] = 0


def jsonl_file_reader(fh: BinaryIO, content_key: str):
with io.TextIOWrapper(fh, encoding="utf-8") as text_reader:
with jsonlines.Reader(text_reader) as jsonl_reader:
@@ -129,13 +109,44 @@ def gzip_compressed_reader(fh: BinaryIO, content_key: str):
yield item[content_key]


def tar_reader(fh: BinaryIO, content_key: str):
"""
content_key: where in the tarfile to find the text/tokens. Options:
"txt" - read text file as string
"json:key" - read json[key] as string
"npy" - read numpy array as tokens
"""
content_ext = content_key.split(":")[0]
buffer = io.BytesIO(fh.read())
with tarfile.open(fileobj=buffer, mode="r") as tar:
samples = []
for member in tar.getmembers():
if member.isfile() and member.name.endswith(f".{content_ext}"):
with tar.extractfile(member) as fileobj:
if fileobj: # Ensure fileobj is not None
if content_ext == "txt":
content = fileobj.read().decode("utf-8")
elif content_ext == "json":
json_dict, json_key = json.load(fileobj), content_key.split(":")[1]
content = json_dict[json_key]
elif content_ext == "npy":
token_array = np.load(io.BytesIO(fileobj.read()), allow_pickle=True)
content = token_array.reshape(-1).tolist()
else:
raise ValueError(f"Unsupported content key extension: {content_key}")

yield content
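
A usage sketch for the new tar support (not part of the diff); the tar path and member layout are made up for illustration:

# hypothetical local tar whose members are 00000001.npy, 00000002.npy, ... of token ids
with open("pretokenized_shard_000.tar", "rb") as fh:
    reader = get_reader(RawFileType.TAR, content_key="npy")
    for tokens in reader(fh):
        print(len(tokens))  # each yielded item is a flat list of token ids

# text members: content_key="txt"; JSON members: content_key="json:<field>",
# e.g. "json:text" reads the "text" field of each .json member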


def get_reader(file_type, content_key: str):
if file_type == RawFileType.JSONL:
return lambda x: jsonl_file_reader(x, content_key=content_key)
if file_type == RawFileType.ZSTD_JSONL_COMPRESSED:
return lambda x: zstd_compressed_reader(x, content_key=content_key)
if file_type == RawFileType.GZIP_JSONL_COMPRESSED:
return lambda x: gzip_compressed_reader(x, content_key=content_key)
if file_type == RawFileType.TAR:
return lambda x: tar_reader(x, content_key=content_key)
else:
raise Exception("Unsupported filetype")

@@ -147,6 +158,8 @@ def get_raw_filetype(key: str):
return RawFileType.ZSTD_JSONL_COMPRESSED
elif key.endswith(".jsonl.gz") or key.endswith(".json.gz"):
return RawFileType.GZIP_JSONL_COMPRESSED
elif key.endswith(".tar"):
return RawFileType.TAR
else:
logger.warning(f"Unknown filetype: {key}")
return RawFileType.UNKNOWN
@@ -160,13 +173,15 @@ def preprocess(
seqlen: int = 8192,
tokenizer=None,
do_sample: bool = False,
sources: enum.Enum = None,
):
tokenizer_fn, vocab_size = tokenizer
rng = random.Random(hash(key) + seed)
EOT = SpecialTokens.END_OF_TEXT.value % (vocab_size + len(SpecialTokens))
PAD = SpecialTokens.PAD.value % (vocab_size + len(SpecialTokens))
if do_sample:
sample_freq = Sources.get_sampling_frequency(key)
assert sources is not None
sample_freq = sources.get_sampling_frequency(key)
buffer = []
try:
file_type = get_raw_filetype(key)
@@ -195,15 +210,15 @@
else:
yield buffer[:seqlen]
buffer = buffer[seqlen:]
if len(buffer) > 0:
yield buffer + [PAD] * (seqlen - len(buffer))
if len(buffer) > 0:
yield buffer + [PAD] * (seqlen - len(buffer))

except (IncompleteReadError, ReadTimeoutError, ResponseStreamingError) as e:
logger.error(f"There was an incomplete read error: {e} for key {key}")
return
return []
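
The tail of preprocess() above cuts the accumulated token buffer into seqlen-sized sequences and pads the final remainder with PAD. A toy illustration (the token ids and the EOT-after-each-document placement are assumptions, since the concatenation step is elided from this hunk):

# illustrative only: seqlen=8, EOT/PAD ids chosen arbitrarily
seqlen, EOT, PAD = 8, 50432, 50433
buffer = []
for doc_tokens in ([1, 2, 3], [4, 5, 6, 7, 8, 9]):  # two toy "documents"
    buffer += doc_tokens + [EOT]
chunks = []
while len(buffer) >= seqlen:
    chunks.append(buffer[:seqlen])
    buffer = buffer[seqlen:]
if buffer:
    chunks.append(buffer + [PAD] * (seqlen - len(buffer)))
# chunks == [[1, 2, 3, 50432, 4, 5, 6, 7],
#            [8, 9, 50432, 50433, 50433, 50433, 50433, 50433]]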


def process_keys(data, tokenizer, seqlen, seed, content_key):
def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None):
s3_client = boto3.client("s3")
path = data["path"]
bucket, key = parse_s3_path(path)
@@ -216,7 +231,8 @@ def process_keys(data, tokenizer, seqlen, seed, content_key):
seed=seed,
tokenizer=tokenizer,
content_key=content_key,
do_sample=False,
do_sample=do_sample,
sources=sources,
)
for token_buffer in token_buffers:
yield {"tokens": token_buffer}
@@ -372,7 +388,7 @@ def get_token_counter(self):
return self.token_count


if __name__ == "__main__":
def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input path", type=str, required=True)
parser.add_argument(
@@ -385,16 +401,22 @@ def get_token_counter(self):
parser.add_argument("--content_key", type=str, default="text")
parser.add_argument("--seqlen", type=int, default=2048)
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
parser.add_argument("--vocab_size", type=int, default=None) # for pre-tokenized data, don't load tokenizer
parser.add_argument("--wds_chunk_size", type=int, default=8192)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--subset", type=int, default=None)
parser.add_argument("--ray_address", type=str, default=None)
parser.add_argument("--block_size", type=str, default="10MB")
parser.add_argument("--force_parallelism", type=int, default=None)
parser.add_argument("--no_shuffle", action="store_true")
parser.add_argument("--do_sample", action="store_true")
parser.add_argument("--ray_spill_location", type=str, default="/tmp/ray_spill")
parser.add_argument("--default_dataset_yaml", type=str, default=(DIR.parent / "metadata" / "rpj_lm_data.yaml"))

args = parser.parse_args()
args = parser.parse_args(args)
Sources, SAMPLING_FREQUENCIES = load_from_yaml(args.default_dataset_yaml)
logger.info(f"SOURCES:\n {Sources}")
logger.info(f"SAMPLING_FREQUENCIES:\n{SAMPLING_FREQUENCIES}")
# configure remote spilling
creds = {k: v for k, v in os.environ.items() if k.startswith("AWS")}
runtime_env = {"env_vars": creds}
@@ -409,6 +431,7 @@ def get_token_counter(self):
for inp_folder in input_folders:
input_paths += glob_files(inp_folder, suffix=".jsonl")
input_paths += glob_files(inp_folder, suffix=".zst")
input_paths += glob_files(inp_folder, suffix=".tar")
if args.subset is not None:
input_paths = input_paths[: args.subset]
rng = random.Random(args.seed)
@@ -426,9 +449,10 @@ def get_token_counter(self):
parallelism = num_cores * num_nodes
ctx = DataContext.get_current()
ctx.use_push_based_shuffle = True
ctx.execution_options.resource_limits.object_store_memory = float("inf")
ray.data.DataContext.get_current().execution_options.verbose_progress = True
start_time = time.time()
tokenizer = load_tokenizer(args.tokenizer)
tokenizer = load_tokenizer(args.tokenizer) if args.vocab_size is None else (lambda x: x, args.vocab_size)
logger.info(f"Total number of keys = {len(input_paths)}")
df = pd.DataFrame(input_paths, columns=["path"])
ds = ray.data.from_pandas(pd.DataFrame(input_paths, columns=["path"])).repartition(parallelism)
@@ -439,6 +463,8 @@ def get_token_counter(self):
seqlen=seqlen,
seed=args.seed,
content_key=content_key,
do_sample=args.do_sample,
sources=Sources,
)
)
ds = ds.map(add_hash)
@@ -476,3 +502,7 @@ def get_token_counter(self):
print("Failed to retrieve memory summary")
print(traceback.format_exc())
print("")


if __name__ == "__main__":
main(sys.argv[1:])
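
For context (a sketch, not part of the commit): the refactor from a bare __main__ block to main(args) makes the pipeline callable programmatically as well as from the CLI. An example call using only flags visible in this diff; the bucket path is a placeholder, the import path assumes the usual package layout, and a real run would also need the output flag that is truncated out of this view:

from open_lm.datapreprocess.ray.tokenize_shuffle import main

main([
    "--input", "s3://my-bucket/rpj-raw/",  # placeholder input prefix
    "--content_key", "text",
    "--seqlen", "2048",
    "--do_sample",  # enable per-source sampling frequencies from the YAML
    "--default_dataset_yaml", "open_lm/datapreprocess/metadata/rpj_lm_data.yaml",
])

For pre-tokenized inputs (e.g. the .npy tar members handled by tar_reader), passing --vocab_size together with --content_key npy skips loading a tokenizer, per the identity-tokenizer branch above.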
4 changes: 2 additions & 2 deletions open_lm/file_utils.py
@@ -16,7 +16,7 @@
from typing import List, Optional
from tqdm import tqdm

from .distributed import is_master
from open_lm.distributed import is_master


def remote_sync_s3(local_dir, remote_dir):
@@ -234,7 +234,7 @@ def log_num_checkpoints(total_steps, args):

steps_done = 0
tokens_seen = 0
next_shard_per_source = [0]
next_shard_per_source = [0 for _ in range(len(args.dataset_manifest))] if args.dataset_manifest is not None else 0
checkpoints_made = 0

if is_master(args):
14 changes: 7 additions & 7 deletions open_lm/main.py
@@ -45,13 +45,13 @@

from open_lm.model import create_model
from open_lm.utils.transformers.hf_wrapper import create_wrapped_hf_model
from .data import get_data, get_wds_dataset
from .distributed import is_master, init_distributed_device, broadcast_object
from .logger import setup_logging
from .params import parse_args
from .scheduler import cosine_lr
from .train import train_one_epoch, evaluate_loop
from .file_utils import (
from open_lm.data import get_data, get_wds_dataset
from open_lm.distributed import is_master, init_distributed_device, broadcast_object
from open_lm.logger import setup_logging
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr
from open_lm.train import train_one_epoch, evaluate
from open_lm.file_utils import (
pt_load,
check_exists,
start_sync_process,
