From 2d10c865f324399e72a9ed4063ebade30296a21d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 22 Oct 2024 13:44:34 -0400 Subject: [PATCH 01/18] Remove rope_theta arg from recent external PR --- fms_fsdp/utils/config_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index da5c6b40..2eb03553 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -126,7 +126,6 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=8 / 3, max_expected_seq_len=4096, - rope_theta=500000.0, ) elif model_variant == "llama3_70b": llama_config = LLaMAConfig( From 6b78e4d0d7f0efdc986bb8b391f9f6715aafdc1d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 29 Apr 2025 15:26:12 -0400 Subject: [PATCH 02/18] Mamba cp data 500k (#12) * Add code of conduct (#132) Add Foundation Model Stack community code of conduct. Signed-off-by: Sahdev Zala * minimal cp * print model sanity check * more print tests * better print test * print test fix * use foreach = false by default * comment * remove foreach option * cp_impl * context_parallel -> cp * cp_in_node * cp_in_node default False * cp_in_node -> cp_over_world * add else case * print check * rm print check * explicit device meshes * comment * Pass dp_degree to dataloader * Apply CP chunking * add cp_{attn,mamba}_impl configs * allreduce -> allgather typo fix * Corrected dp ws * Diag save * Close brackets * Better save * Also save per-token loss * grad norm print test * fix throughput computation for cp * rm local grad norm print * another local grad norm print test * print loss, too * add local num params print * rm test code * Add 8x cfg * Upstream fhandler * Upstream fhandler / edge case fixes * Rope theta name * Add 32x * rework mesh logic * low_cpu_fsdp for mamba * wrap embedding and lm head for mamba * fms_to_hf_mamba_transformers.py * remove "sharded" in hf ckpt util * copy over hf conversion script * Add docslicedataset * Hard disable countfile (mixed dataset) * Add 500k cfg and zloss * Diag skip dslice * Diag swap back buffers * Readd loader changes * Diag dslice off * Revert * Diag print * Swap back buffers, doc fragging * Add constant sched * Up slice rate * Soft filtering @8k * Add 16x * Add filter_exp as arg * Passthrough filter_exp from cfg * Add filter_exp and target_doclen * Passthrough target_doclen --------- Signed-off-by: Sahdev Zala Co-authored-by: Sahdev Zala Co-authored-by: Garrett Goon --- code-of-conduct.md | 3 + fms_fsdp/config/training.py | 8 + fms_fsdp/utils/config_utils.py | 100 +++++++++++ fms_fsdp/utils/dataloader_utils.py | 39 +++- fms_fsdp/utils/dataset_utils.py | 154 ++++++++++++++-- fms_fsdp/utils/train_utils.py | 24 ++- fms_to_hf_mamba_transformers.py | 280 +++++++++++++++++++++++++++++ main_training_mamba.py | 101 +++++++++-- 8 files changed, 663 insertions(+), 46 deletions(-) create mode 100644 code-of-conduct.md create mode 100644 fms_to_hf_mamba_transformers.py diff --git a/code-of-conduct.md b/code-of-conduct.md new file mode 100644 index 00000000..a5a920fc --- /dev/null +++ b/code-of-conduct.md @@ -0,0 +1,3 @@ +# Foundation Model Stack Community Code of Conduct + +Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md). diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 1d072958..8bdda60c 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -26,6 +26,8 @@ class train_config: strip_tokens: str = "" logical_shards: int = 1024 num_workers: int = 1 + filter_exp: int = 2 + target_doclen: int = 8192 # fsdp policies sharding_strategy: str = "hsdp" @@ -72,3 +74,9 @@ class train_config: stage2_prompt_length: int = 64 stage2_batch_size: int = 96 stage2_seq_length: int = 256 + + # context parallel + cp: bool = False + cp_mamba_impl: str = "allgather" # "allgather" or "serial" + cp_attn_impl: str = "zigzag" # "zigzag" or "ring" + cp_over_world: bool = False diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 6770622c..00e4d786 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -182,6 +182,106 @@ def get_model_config(model_variant): "pad_vocab_size_multiple": 16, "tie_embeddings": False, } + elif model_variant == "mamba_9.8b_8x": + model_config = { + "d_model": 4096, + "d_intermediate": 14336, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + "rotary_emb_base": 80000, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_9.8b_16x": + model_config = { + "d_model": 4096, + "d_intermediate": 14336, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + "rotary_emb_base": 160000, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_9.8b_32x": + model_config = { + "d_model": 4096, + "d_intermediate": 14336, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + "rotary_emb_base": 320000, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_9.8b_500k": + model_config = { + "d_model": 4096, + "d_intermediate": 14336, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + "rotary_emb_base": 1280000, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } else: raise ValueError(f"model variant {model_variant} not supported.") diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 4b811d6d..b954be9a 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -5,6 +5,7 @@ AutoHandler, BufferDataset, CheckpointDataset, + DocSliceDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -57,7 +58,7 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): +def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): """ Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. Assumes underlying data is sequences of integer values. @@ -75,7 +76,15 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - datasets, weights = parse_data_args(cfg.datasets, cfg.weights) + do_cp = False + if dp_degree != world_size: + do_cp = True + cp_worldsize = world_size // dp_degree + cp_rank = rank % cp_worldsize + world_size = dp_degree + rank = rank // cp_worldsize + + datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. # Implements dataset sampling and rescalability. @@ -87,9 +96,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": - filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) + filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols) else: - filehandler = _handler_map[cfg.file_type] + filehandler = _handler_map[cfg.file_type](cols) # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -99,8 +108,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.eos_token, bos_token=cfg.bos_token, strip_tokens=set(droplist), - min_length=3, + min_length=cfg.target_doclen, seed=cfg.seed, + filter_exp=cfg.filter_exp, ) # Add rescaling/resharding data = ScalableShardDataset( @@ -126,13 +136,25 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): pack_hard=True, ) # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. - data = PreloadBufferDataset(data, 10000) + data = PreloadBufferDataset(data, 1000) + # Slice and rearrange docs to force long-context retrieval + data = DocSliceDataset( + data, + cfg.eos_token, + slice_rate=.75, + ) # Apply desired postprocessing steps in sequence data = PreprocessDataset(data, torch.IntTensor) for p in postprocess: data = PreprocessDataset(data, p) + # Apply CP chunking if using CP + if do_cp: + def chunk(x): + return x[(cp_rank*x.size(0))//cp_worldsize : ((cp_rank+1)*x.size(0))//cp_worldsize] + data = PreprocessDataset(data, lambda x: (chunk(x[0]), chunk(x[1]))) + # Enable auto-saving data = CheckpointDataset( data, @@ -146,7 +168,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): ) -def parse_data_args(datas, weights): +def parse_data_args(datas, weights, cols): # Convert csv inputs into corresponding lists of values def splitstrip(x): if isinstance(x, str): @@ -160,4 +182,5 @@ def splitstrip(x): datas = splitstrip(datas) weights = [float(x) for x in splitstrip(weights)] - return datas, weights + cols = splitstrip(cols) + return datas, weights, cols diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index aedc5862..ff9f5c80 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -343,7 +343,7 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. """ - def __init__(self, col_name: str = "tokens"): + def __init__(self, col_name: List[str] = ["tokens"]): self.col_name = col_name def is_legal(self, filepath: str): @@ -356,7 +356,13 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - doc = reader.get_batch(index)[self.col_name] + frame = reader.get_batch(index) + doc = None + for name in self.col_name: + if name in frame.column_names: + doc = frame[name] + break + assert doc is not None, f"None of column names {self.col_name} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -376,7 +382,7 @@ class ParquetHandler(_ShardFileHandler): before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_name: str = "text"): + def __init__(self, tokenizer_path: str, col_name: List[str] = ["text"]): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_name = col_name @@ -384,9 +390,14 @@ def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] def open(self, path: str): - return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[ - self.col_name - ] + names = pq.read_metadata(path).schema.names + match = None + for name in self.col_name: + if name in names: + match = name + break + assert match is not None, f"None of column names {self.col_name} found in file headers {names}" + return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows @@ -405,9 +416,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_name: str = "text"): + def __init__(self, tokenizer_path: str, col_name: List[str] = ["text", "contents", "tokens"]): self.PHandler = ParquetHandler(tokenizer_path, col_name) - self.AHandler = ArrowHandler() + self.AHandler = ArrowHandler(col_name) self.current = _ShardFileHandler() def is_legal(self, filepath: str): @@ -694,6 +705,87 @@ def load_state_dict(self, state_dicts, sharded_input=False): # Manually set buffer size self.buffer_size = len(self.buffer) return sharded_dicts + + +class DocSliceDataset(_WrapperDataset): + """ + Wrapper for a StatefulDataset that implements document slicing. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + delimiter_token : int + Value used for delimiter + slice_rate : float + Proportion of documents to slice + overlap : int + Number of tokens to overlap for slice retrieval + """ + + def __init__(self, dataset: _StatefulDataset, delimiter_token: int, slice_rate: float = 0.5, overlap: int = 3): + super().__init__(dataset) + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.state_params = ["g_state"] + self.delimiter = delimiter_token + self.slicerate = slice_rate + self.overlap = overlap + + def __iter__(self): + dataset = iter(self.dataset) + while True: + inp = next(dataset) + inplen = len(inp) + doclist = [] + last_delim = 0 + for i in range(len(inp)): + if inp[i] == self.delimiter: + doclist.append(inp[last_delim:i]) + last_delim = i+1 + doclist.append(inp[last_delim:]) + nslice = int((len(doclist)-2)*self.slicerate) + if len(doclist) < 3 or nslice < 2: + yield inp + else: + begin = doclist[0] + end = doclist[-1] + slice = doclist[1:1+nslice] + unslice = doclist[1+nslice:-1] + sliced = [] + for doc in slice: + assert len(doc)//3 > self.overlap, f"Doc length {len(doc)} too small for random slice with desired overlap {self.overlap}: {[len(begin), [len(x) for x in slice], [len(x) for x in unslice], len(end)]}" + i = torch.randint(0, len(doc)//3, [1], generator=self.generator).item() + len(doc)//3 + sliced.append([doc[:i], doc[i-self.overlap:]]) + slice = sliced + doclist = [slice[0][0], slice[1][0], slice[0][1], slice[1][1]] + for docpair in slice[2:]: + inds = torch.randperm(len(doclist)+1, generator=self.generator)[:2].tolist() + inds.sort() + inds[1] += 1 + doclist = doclist[:inds[0]] + [docpair[0]] + doclist[inds[0]:inds[1]-1] + [docpair[1]] + doclist[inds[1]-1:] + for doc in unslice: + i = torch.randint(0, len(doclist)+1, [1], generator=self.generator).item() + doclist = doclist[:i] + [doc] + doclist[i:] + out = begin + [self.delimiter] + for doc in doclist: + out = out + doc + out.append(self.delimiter) + out = out + end + yield out[:inplen] + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + out = super().state_dict() + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + return sharded_dicts class BufferDataset(_WrapperDataset): @@ -860,6 +952,7 @@ def __init__( min_length: int = 1, max_chunksize: int = 1024, verbose: bool = False, + filter_exp: int = 2, ): super().__init__(datapath, rank, worldsize) self.seed = seed @@ -872,6 +965,7 @@ def __init__( self.bos = bos_token self.drop = strip_tokens self.verbose = verbose + self.filter_exp = filter_exp self.docset: List[ Any ] = [] # map of doc indices to (shardid, min docid, max docid) @@ -885,6 +979,7 @@ def __init__( self.tokens_seen = 0 self.docs_seen = 0 self.percent_seen = 0 + self.has_yielded = False self.state_params = [ "dataset", @@ -895,6 +990,7 @@ def __init__( "docs_seen", "percent_seen", "lcg_state", + "g_state", ] # Setup flags @@ -902,6 +998,9 @@ def __init__( self._len = 0 self.dataset = "" self.lcg_state = 0 + self.g_state = None + + self.g = None def setup(self): """ @@ -925,6 +1024,8 @@ def setup(self): for root, dirs, files in os.walk(datapath, topdown=False) for name in files if self.filehandler.is_legal(os.path.join(root, name)) + and os.path.getsize(os.path.join(root, name)) > 1_000_000 + # 1mb minimum file size to prevent empty files ] shards.sort() # Ensure consistent sharding across machines start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize @@ -939,12 +1040,12 @@ def setup(self): # Assemble length of each owned shard file countfiles = [] - if os.path.exists(os.path.join(pardir, "meta")): - countfiles = [ - x - for x in os.listdir(os.path.join(pardir, "meta")) - if "counts" in x and "csv" in x - ] + # if os.path.exists(os.path.join(pardir, "meta")): + # countfiles = [ + # x + # for x in os.listdir(os.path.join(pardir, "meta")) + # if "counts" in x and "csv" in x + # ] doc_counts = {} if len(countfiles) > 0: # Count file exists, use it @@ -953,8 +1054,8 @@ def setup(self): reader = csv.DictReader(csvfile) for row in reader: fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: + prefix = fullpath.find(dataset) + if prefix >= 0: key = fullpath[prefix + len(dataset) + 1 :] doc_counts[key] = int(row["documents"]) else: @@ -1002,6 +1103,7 @@ def setup(self): random.shuffle(self.docset) # Setup doc shuffle - same guarantee self.lcg_state = seed + self.g = torch.Generator().manual_seed(self.rank) def _get_docid(self, i): """ @@ -1097,7 +1199,8 @@ def __iter__(self): if len(doc) == 0: continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + keep_chance = (doclen/self.min_length)**self.filter_exp + if torch.rand(1, generator=self.g).item() < keep_chance: n_chunks = math.ceil(doclen / self.chunksize) for j in range(n_chunks): if i == 0 and j < residual_chunks: @@ -1110,6 +1213,7 @@ def __iter__(self): self.percent_seen = ( self.docs_seen * 100 / (self._len + 1e-9) ) + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) # Advance RNG state @@ -1126,12 +1230,23 @@ def __iter__(self): if len(doc) == 0: continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + keep_chance = (doclen/self.min_length)**self.filter_exp + if torch.rand(1, generator=self.g).item() < keep_chance: n_chunks = math.ceil(doclen / self.chunksize) for j in range(residual_chunks): self.chunk_index = j + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) + # Check that epoch was non-empty + assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}" + + def state_dict(self): + # Write generator state manually + self.g_state = self.g.get_state() + out = super().state_dict() + return out + def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( @@ -1142,6 +1257,9 @@ def load_state_dict(self, state_dicts, sharded_input=False): assert ( d == self.dataset ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" + # Manually set generator state if it exists + if self.g_state is not None: + self.g.set_state(self.g_state) return out diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index ef421f6f..bbd8d54b 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -2,7 +2,6 @@ from dataclasses import asdict from functools import partial - try: import packaging.version except ImportError: @@ -30,6 +29,7 @@ def train( checkpointer, start_step, tokens_seen, + cp_degree: int = 1, ): if cfg.tracker: if cfg.tracker not in ["wandb", "aim"]: @@ -44,7 +44,7 @@ def train( except ImportError: raise ImportError("tracker is set to wandb but wandb is not installed.") if rank == 0: - print(f"--> wandb is enabled!") + print("--> wandb is enabled!") try: wandb.init( project=project_name, @@ -64,7 +64,7 @@ def train( except ImportError: raise ImportError("tracker is set to aim but aim is not installed.") if rank == 0: - print(f"--> aim is enabled!") + print("--> aim is enabled!") run = Run( experiment=project_name, repo=tracker_dir, @@ -89,8 +89,9 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) - + loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean() loss.backward() + ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() optimizer.step() scheduler.step() @@ -108,7 +109,11 @@ def train( elapsed_time = time.time() - loop_start world_size = int(os.environ["WORLD_SIZE"]) new_tokens_seen = ( - (batch_idx - start_step) * world_size * cfg.batch_size * cfg.seq_length + (batch_idx - start_step) + * world_size + * cfg.batch_size + * cfg.seq_length + // cp_degree ) if rank == 0: total_tokens_seen = tokens_seen + new_tokens_seen @@ -118,10 +123,10 @@ def train( current_step_time = (time.time() - start) / cfg.report_interval overall_step_time = elapsed_time / (batch_idx - start_step) current_throughput = int( - cfg.batch_size * cfg.seq_length / current_step_time + cfg.batch_size * cfg.seq_length / cp_degree / current_step_time ) overall_throughput = int( - cfg.batch_size * cfg.seq_length / overall_step_time + cfg.batch_size * cfg.seq_length / cp_degree / overall_step_time ) reserved_mem = torch.cuda.max_memory_reserved( device=torch.cuda.current_device() @@ -145,6 +150,7 @@ def train( "overall token per day:", int(new_tokens_seen / elapsed_time * 3600 * 24), ) + print(f"Total tok/step: {world_size * cfg.batch_size * cfg.seq_length}") if cfg.tracker: vals_to_track = { "learning rate": current_lr, @@ -201,11 +207,11 @@ def get_mixed_precision_policy(cfg, rank): if bf16_ready: mixed_precision_policy = bfSixteen if rank == 0: - print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + print("bFloat16 enabled for mixed precision - using bfSixteen policy") else: mixed_precision_policy = fpSixteen if rank == 0: - print(f"FP16 enabled") + print("FP16 enabled") else: mixed_precision_policy = None diff --git a/fms_to_hf_mamba_transformers.py b/fms_to_hf_mamba_transformers.py new file mode 100644 index 00000000..57b3ec80 --- /dev/null +++ b/fms_to_hf_mamba_transformers.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Modified from src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py +""" + +"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed.""" + +import argparse +import json +import os +import re +from os import path +from typing import Dict, Optional, Union + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file + +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from transformers.models.bamba import BambaConfig + + +def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]: + state_dict = {} + + for orig_k, param in original_sd.items(): + k = orig_k.replace("backbone", "model") + + # for embeddings + k = k.replace("embedding", "embed_tokens") + + # for mixer + k = k.replace("mixer", "mamba") + + # for final layernorm + k = k.replace("norm_f", "final_layernorm") + + # for block layernorm + k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k) + k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k) + + # for mlp + k = k.replace("mlp.fc2", "feed_forward.down_proj") + + if "mlp.fc1" in k: + param, param2 = torch.chunk(param, 2, dim=0) + k2 = k.replace("mlp.fc1", "feed_forward.gate_proj") + state_dict[k2] = param2 + k = k.replace("mlp.fc1", "feed_forward.up_proj") + + if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or ( + "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd + ): + # then this must be a mamba + pass + else: + # for attn + # - because mixer was replaced to mamba above + k = k.replace("mamba.out_proj", "self_attn.o_proj") + if "mamba.in_proj" in k: + m, n = param.shape + d = (m - n) // 2 + param, param2, param3 = torch.split(param, [n, d, d], dim=0) + k2 = k.replace("mamba.in_proj", "self_attn.k_proj") + state_dict[k2] = param2 + k2 = k.replace("mamba.in_proj", "self_attn.v_proj") + state_dict[k2] = param3 + k = k.replace("mamba.in_proj", "self_attn.q_proj") + + state_dict[k] = param + + return state_dict + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_ssm_config_to_hf_config( + config_ssm: Dict, + **kwargs, +) -> BambaConfig: + """Convert a config from mamba_ssm to a BambaConfig from here.""" + hf_config: BambaConfig = BambaConfig(**kwargs) + + hf_config.architectures = ["BambaForCausalLM"] + + # Set important values from config and recalculate other resulting entries + hf_config.hidden_size = config_ssm["d_model"] + hf_config.intermediate_size = config_ssm["d_intermediate"] + hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head + hf_config.num_hidden_layers = config_ssm["n_layer"] + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] + + # currently this script assumes config_ssm belongs to v2 + if config_ssm["ssm_cfg"].get("layer") != "Mamba2": + raise ValueError("Conversion script only supports Mamba2") + + # Set attention values + attn_cfg = config_ssm.get("attn_cfg") + if attn_cfg: + assert attn_cfg["causal"], "Only support non-causal attention." + assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias." + assert not attn_cfg["out_proj_bias"], "Only support no out bias." + hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"] + hf_config.num_attention_heads = attn_cfg["num_heads"] + hf_config.num_key_value_heads = attn_cfg["num_heads_kv"] + hf_config.rope_theta = attn_cfg["rotary_emb_base"] + + attention_layer_indices = config_ssm.get("attn_layer_idx") + if attention_layer_indices: + hf_config.attn_layer_indices = attention_layer_indices + + # Padded vocab size, mostly of 16 but 32 is also very common in different models + vocab_size = config_ssm["vocab_size"] + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] + if (vocab_size % pad_vocab_size_multiple) != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + hf_config.vocab_size = vocab_size + + return hf_config + + +def save_single_safetensor( + state_dict: Dict, + save_directory: str, + metadata: Dict, +): + save_file( + state_dict, + os.path.join(save_directory, SAFE_WEIGHTS_NAME), + metadata, + ) + + +def save_sharded_safetensors( + state_dict: Dict, + save_directory: str, + metadata: Dict, + max_shard_size: Union[int, str] = "5GB", +): + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + mamba_ssm_checkpoint_path: str, + precision: str, + output_dir: str, + tokenizer_path: Optional[str] = None, + save_model: Union[bool, str] = True, +) -> None: + # load tokenizer if provided, this will be used to set the + # token_ids in the config file + token_ids = {} + if tokenizer_path: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + for key in [ + "bos_token_id", + "eos_token_id", + "pad_token_id", + ]: + id = getattr(tokenizer, key, None) + if id: + token_ids[key] = id + + # there are some configs unsettable by mamba_ssn config, so + # if there are changes from the defaults, have to pass them into + # the function + unsettables = { + "mamba_d_head": 64, + "mamba_d_state": 128, + "mamba_n_groups": 1, + "rms_norm_eps": 1e-5, + } + + # Load and save config based on name + config_path = path.join(mamba_ssm_checkpoint_path, "config.json") + with open(config_path, "r", encoding="utf-8") as json_file: + config = json.load(json_file) + + # convert the config + hf_config = convert_ssm_config_to_hf_config( + config_ssm=config, + **token_ids, + **unsettables, + ) + hf_config.save_pretrained(output_dir) + + # Load state dict of the original model and transfer to hf model + state_dict = torch.load( + path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"), + map_location="cpu", + weights_only=True, + ) + # FIXME: allow other parameters to pass in + state_dict = convert_state_dict_from_mamba_ssm(state_dict) + + # Save new model to pytorch_dump_path + dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) + + save_file_fn = None + if isinstance(save_model, bool) and save_model: + save_file_fn = save_single_safetensor + elif isinstance(save_model, str) and save_model == "sharded": + save_file_fn = save_sharded_safetensors + + if save_file_fn: + save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba_ssm_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="fp16", + required=True, + choices=("fp32", "fp16", "bf16"), + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." + ) + parser.add_argument( + "-t", + "--tokenizer_model_path", + type=str, + default=None, + required=False, + help="Path to a the tokenizer file.", + ) + args = parser.parse_args() + + convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + args.mamba_ssm_checkpoint_directory, + args.precision, + args.output_dir, + args.tokenizer_model_path, + ) + diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..abab8ba0 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -4,12 +4,15 @@ import fire import torch +import torch.nn as nn import torch.optim as optim from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from mamba_ssm.modules.block import Block from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import CustomPolicy from torch.optim.lr_scheduler import LambdaLR from fms_fsdp import config @@ -51,20 +54,82 @@ def main(**kwargs): Path.home(), ".triton", "cache", str(local_rank) ) - # get policy + # get policy. NOTE: @goon - overriding {wrapping_policy, param_init_fn} below block = Block ( mixed_precision_policy, - wrapping_policy, + _, sharding_strategy_policy, apply_selective_ac, - param_init_fn, + _, # NOTE: @goon - We'll override param_init_fn for mamba below ) = get_policies(cfg, rank, block) + if cfg.low_cpu_fsdp: + # NOTE: @goon - the params will be junk after using this. Only intended to be used in + # conjunction with loading proper weights from a checkpoint. + def param_init_fn(module): + module.to_empty(device=torch.cuda.current_device()) + else: + param_init_fn = None + + # Meshes for FSDP and CP. NOTE: @goon - Getting hangs and/or OOMs if I don't explicitly specify + # the FSDP mesh when using 4+ nodes with HSDP + in-node-CP. + def get_1D_world_mesh(world_size: int) -> DeviceMesh: + mesh = dist.device_mesh.init_device_mesh("cuda", (world_size,)) + return mesh + + def get_2D_world_mesh(world_size: int) -> DeviceMesh: + num_gpu_per_node = torch.cuda.device_count() + assert world_size % num_gpu_per_node == 0 + mesh = dist.device_mesh.init_device_mesh( + "cuda", + (world_size // num_gpu_per_node, num_gpu_per_node), + mesh_dim_names=("inter_node", "intra_node"), + ) + return mesh + + requires_2d_mesh = (cfg.sharding_strategy == "hsdp") or ( + cfg.cp and not cfg.cp_over_world + ) + if requires_2d_mesh: + mesh = get_2D_world_mesh(world_size) + fsdp_mesh = mesh + cp_mesh = mesh["intra_node"] if cfg.cp else None + else: + mesh = get_1D_world_mesh(world_size) + fsdp_mesh = mesh + cp_mesh = mesh if cfg.cp else None + + if cfg.cp: + cp_degree = world_size if cfg.cp_over_world else torch.cuda.device_count() + else: + cp_degree = 1 + + dp_degree = world_size // cp_degree # get model config_data = get_model_config(cfg.model_variant) mamba_config = MambaConfig(**config_data) - model = MambaLMHeadModel(mamba_config) + + if cfg.low_cpu_fsdp: + with torch.device("meta"): + model = MambaLMHeadModel( + mamba_config, + cp_mesh=cp_mesh if cfg.cp else None, + cp_mamba_impl=cfg.cp_mamba_impl if cfg.cp else None, + cp_attn_impl=cfg.cp_attn_impl if cfg.cp else None, + ) + else: + model = MambaLMHeadModel( + mamba_config, + cp_mesh=cp_mesh if cfg.cp else None, + cp_mamba_impl=cfg.cp_mamba_impl if cfg.cp else None, + cp_attn_impl=cfg.cp_attn_impl if cfg.cp else None, + ) + + def lambda_fn(module: nn.Module): + return isinstance(module, (Block, nn.Embedding)) or module is model.lm_head + + wrapping_policy = CustomPolicy(lambda_fn) if rank == 0: total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -74,7 +139,7 @@ def main(**kwargs): if rank == 0: print("Constructing datasets...") if not cfg.use_dummy_dataset: - train_loader = get_data_loader(cfg, rank, world_size) + train_loader = get_data_loader(cfg, rank, world_size, dp_degree) else: train_loader = get_dummy_loader(cfg, rank, world_size) if rank == 0: @@ -83,6 +148,7 @@ def main(**kwargs): # FSDP model = FSDP( model, + device_mesh=fsdp_mesh, auto_wrap_policy=wrapping_policy, mixed_precision=mixed_precision_policy, sharding_strategy=sharding_strategy_policy, @@ -91,24 +157,29 @@ def main(**kwargs): limit_all_gathers=True, param_init_fn=param_init_fn, ) + if rank == 0: + print(model) # fsdp activation checkpointing if cfg.fsdp_activation_checkpointing: if rank == 0: - print(f"--> applying FSDP activation checkpointing...") + print("--> applying FSDP activation checkpointing...") apply_selective_ac(model, p=cfg.selective_checkpointing) # torch compile if cfg.use_torch_compile: if rank == 0: - print(f"--> enabling torch compile...") + print("--> enabling torch compile...") # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here torch._dynamo.config.accumulated_cache_size_limit = 128 model = torch.compile(model) # Optimizer optimizer = optim.AdamW( - model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 + model.parameters(), + lr=cfg.learning_rate, + betas=(0.9, 0.95), + weight_decay=0.1, ) # optionally load from checkpoint (when continue pretraining) @@ -131,14 +202,21 @@ def main(**kwargs): g["initial_lr"] = cfg.learning_rate # LR schedule + warmup_interval = min(2000, cfg.num_steps // 20) + warmup = lambda x: 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2 # linear decay for annealing if cfg.training_stage == "annealing": - schedule = lambda x: 1 - x / cfg.num_steps + schedule = lambda x: min( + warmup(x), + 1 - x / cfg.num_steps, + ) + elif cfg.training_stage == "constant": + # no decay for intermediate jobs + schedule = warmup else: # cosine decay - warmup_interval = min(2000, cfg.num_steps // 20) schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + warmup(x), 0.1 + 0.5 * (1 - 0.1) @@ -165,6 +243,7 @@ def main(**kwargs): checkpointer, start_step, tokens_seen, + cp_degree, ) checkpointer.save_single_file(cfg.num_steps, model) From 1408cdd916491e817f53ec2c9d6b3383eb17145a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 29 Apr 2025 16:01:53 -0400 Subject: [PATCH 03/18] Slice begin/end when long enough --- fms_fsdp/utils/dataset_utils.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index ff9f5c80..cbe99406 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -744,17 +744,27 @@ def __iter__(self): doclist.append(inp[last_delim:i]) last_delim = i+1 doclist.append(inp[last_delim:]) - nslice = int((len(doclist)-2)*self.slicerate) - if len(doclist) < 3 or nslice < 2: - yield inp - else: + # Pull out any short caps + if len(doclist[0])//3 <= self.overlap: begin = doclist[0] + doclist = doclist[1:] + if len(doclist[-1])//3 <= self.overlap: end = doclist[-1] - slice = doclist[1:1+nslice] - unslice = doclist[1+nslice:-1] + doclist = doclist[:-1] + # Figure out which docs to slice + slice = [] + unslice = [] + for doc in doclist: + if torch.rand(1, generator=self.generator) < self.slicerate and len(doc)//3 > self.overlap: + slice.append(doc) + else: + unslice.append(doc) + if len(slice) <= 1: + yield inp + else: + # Perform slicing sliced = [] for doc in slice: - assert len(doc)//3 > self.overlap, f"Doc length {len(doc)} too small for random slice with desired overlap {self.overlap}: {[len(begin), [len(x) for x in slice], [len(x) for x in unslice], len(end)]}" i = torch.randint(0, len(doc)//3, [1], generator=self.generator).item() + len(doc)//3 sliced.append([doc[:i], doc[i-self.overlap:]]) slice = sliced From 43884a825ecaa83ecb730904179e30ae001e2de6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 29 Apr 2025 19:51:26 -0400 Subject: [PATCH 04/18] Empty begin/end --- fms_fsdp/utils/dataset_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index cbe99406..8461f9f2 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -745,6 +745,8 @@ def __iter__(self): last_delim = i+1 doclist.append(inp[last_delim:]) # Pull out any short caps + begin = [] + end = [] if len(doclist[0])//3 <= self.overlap: begin = doclist[0] doclist = doclist[1:] From f699bd5c3bb8f6fa34c045883b2d3c760d43c248 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 12 Jun 2025 17:38:42 -0400 Subject: [PATCH 05/18] Pull in cfg and loader updates from mamba-tiktoken --- fms_fsdp/config/training.py | 9 + fms_fsdp/utils/config_utils.py | 102 +-------- fms_fsdp/utils/dataloader_utils.py | 36 +++- fms_fsdp/utils/dataset_utils.py | 321 ++++++++++++++++++++++------- 4 files changed, 276 insertions(+), 192 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 8bdda60c..45c4a87a 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -26,9 +26,18 @@ class train_config: strip_tokens: str = "" logical_shards: int = 1024 num_workers: int = 1 + doc_cutoff: int = 1_000_000 + doc_breakpoint: int = 65_536 filter_exp: int = 2 target_doclen: int = 8192 + # FIM training + psm_rate: float = 0.0 + spm_rate: float = 0.0 + fim_pre: int = 1 + fim_mid: int = 2 + fim_suf: int = 3 + # fsdp policies sharding_strategy: str = "hsdp" fsdp_activation_checkpointing: bool = False diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index 00e4d786..85d4253a 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -174,107 +174,7 @@ def get_model_config(model_variant): "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, - "rotary_emb_dim": 64, - }, - "rms_norm": True, - "residual_in_fp32": True, - "fused_add_norm": True, - "pad_vocab_size_multiple": 16, - "tie_embeddings": False, - } - elif model_variant == "mamba_9.8b_8x": - model_config = { - "d_model": 4096, - "d_intermediate": 14336, - "n_layer": 32, - "vocab_size": 128256, - "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27], - "attn_cfg": { - "causal": True, - "d_conv": 0, - "head_dim": 128, - "num_heads": 32, - "num_heads_kv": 8, - "out_proj_bias": False, - "qkv_proj_bias": False, - "rotary_emb_dim": 64, - "rotary_emb_base": 80000, - }, - "rms_norm": True, - "residual_in_fp32": True, - "fused_add_norm": True, - "pad_vocab_size_multiple": 16, - "tie_embeddings": False, - } - elif model_variant == "mamba_9.8b_16x": - model_config = { - "d_model": 4096, - "d_intermediate": 14336, - "n_layer": 32, - "vocab_size": 128256, - "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27], - "attn_cfg": { - "causal": True, - "d_conv": 0, - "head_dim": 128, - "num_heads": 32, - "num_heads_kv": 8, - "out_proj_bias": False, - "qkv_proj_bias": False, - "rotary_emb_dim": 64, - "rotary_emb_base": 160000, - }, - "rms_norm": True, - "residual_in_fp32": True, - "fused_add_norm": True, - "pad_vocab_size_multiple": 16, - "tie_embeddings": False, - } - elif model_variant == "mamba_9.8b_32x": - model_config = { - "d_model": 4096, - "d_intermediate": 14336, - "n_layer": 32, - "vocab_size": 128256, - "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27], - "attn_cfg": { - "causal": True, - "d_conv": 0, - "head_dim": 128, - "num_heads": 32, - "num_heads_kv": 8, - "out_proj_bias": False, - "qkv_proj_bias": False, - "rotary_emb_dim": 64, - "rotary_emb_base": 320000, - }, - "rms_norm": True, - "residual_in_fp32": True, - "fused_add_norm": True, - "pad_vocab_size_multiple": 16, - "tie_embeddings": False, - } - elif model_variant == "mamba_9.8b_500k": - model_config = { - "d_model": 4096, - "d_intermediate": 14336, - "n_layer": 32, - "vocab_size": 128256, - "ssm_cfg": {"layer": "Mamba2"}, - "attn_layer_idx": [9, 18, 27], - "attn_cfg": { - "causal": True, - "d_conv": 0, - "head_dim": 128, - "num_heads": 32, - "num_heads_kv": 8, - "out_proj_bias": False, - "qkv_proj_bias": False, - "rotary_emb_dim": 64, - "rotary_emb_base": 1280000, + "rotary_emb_dim": 0, }, "rms_norm": True, "residual_in_fp32": True, diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index b954be9a..e9088a73 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -6,6 +6,7 @@ BufferDataset, CheckpointDataset, DocSliceDataset, + FIMDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -13,6 +14,7 @@ ScalableShardDataset, StreamingDocDataset, ) +from math import ceil _handler_map = { @@ -58,9 +60,9 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): +def get_data_loader(cfg, rank, world_size, dp_degree): """ - Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. + Pytorch dataloader for stateful, distributed, and rescalable language model training. Assumes underlying data is sequences of integer values. ... Args @@ -71,9 +73,6 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): Rank of current distributed worker. Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. - postprocess : List[Callable] - Any task-specific postprocessing to apply before handing over data. Steps will apply in - the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ do_cp = False @@ -84,6 +83,10 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): world_size = dp_degree rank = rank // cp_worldsize + fim_training = cfg.psm_rate + cfg.spm_rate > 0 + if fim_training: + assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?" + datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. @@ -96,7 +99,7 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": - filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols) + filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols, cfg.doc_cutoff) else: filehandler = _handler_map[cfg.file_type](cols) # Base reader layer @@ -111,6 +114,7 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): min_length=cfg.target_doclen, seed=cfg.seed, filter_exp=cfg.filter_exp, + max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024), ) # Add rescaling/resharding data = ScalableShardDataset( @@ -130,7 +134,7 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): # Wrap above dataset in packing logic to form constant-length lines. data = BufferDataset( data, - cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, + cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, @@ -144,11 +148,21 @@ def get_data_loader(cfg, rank, world_size, dp_degree, postprocess=[causal_lm]): slice_rate=.75, ) - # Apply desired postprocessing steps in sequence + # Apply FIM transformation if needed + if fim_training: + data = FIMDataset( + data, + cfg.eos_token, + cfg.psm_rate, + cfg.spm_rate, + pre_token=cfg.fim_pre, + mid_token=cfg.fim_mid, + suf_token=cfg.fim_suf, + ) + + # Transform to tensors data = PreprocessDataset(data, torch.IntTensor) - for p in postprocess: - data = PreprocessDataset(data, p) - + # Apply CP chunking if using CP if do_cp: def chunk(x): diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 8461f9f2..73e8ae2a 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -180,16 +180,20 @@ def load_state_dict(self, state_dicts, sharded_input=False): self.load_worldsize = len(state_dicts) state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) if self.load_worldsize == self.worldsize: - [ - setattr(self, flag, state_dicts[0][self.statename(flag)]) - for flag in self.state_params + self.reshard_params - ] + for flag in self.state_params + self.reshard_params: + if self.statename(flag) in state_dicts[0]: + setattr(self, flag, state_dicts[0][self.statename(flag)]) + elif self.rank == 0: + logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!") else: for flag in self.reshard_params: - reshard = self._reshard( - [sd[self.statename(flag)] for sd in state_dicts] - ) - setattr(self, flag, reshard) + if self.statename(flag) in state_dicts[0]: + reshard = self._reshard( + [sd[self.statename(flag)] for sd in state_dicts] + ) + setattr(self, flag, reshard) + elif self.rank == 0: + logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!") return state_dicts def load_from_path(self, path: str): @@ -343,8 +347,8 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. """ - def __init__(self, col_name: List[str] = ["tokens"]): - self.col_name = col_name + def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): + self.col_names = col_names def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,13 +360,18 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): + assert ( + index < reader.num_record_batches + ), f"Illegal index {index} in set of {reader.num_record_batches} documents" frame = reader.get_batch(index) doc = None - for name in self.col_name: + for name in self.col_names: if name in frame.column_names: doc = frame[name] break - assert doc is not None, f"None of column names {self.col_name} found in file headers {frame.column_names}" + assert ( + doc is not None + ), f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -377,14 +386,15 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: class ParquetHandler(_ShardFileHandler): """ Reader for indexable parquet shard files, common in HF datasets. - Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens), + Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters, as we rely on parquet/pandas for efficient file reading, and tokenize entire documents before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_name: List[str] = ["text"]): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"], max_doclen: int = 1_000_000): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.col_name = col_name + self.col_names = col_names + self.max_doclen = max_doclen def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] @@ -392,18 +402,21 @@ def is_legal(self, filepath: str): def open(self, path: str): names = pq.read_metadata(path).schema.names match = None - for name in self.col_name: + for name in self.col_names: if name in names: match = name break - assert match is not None, f"None of column names {self.col_name} found in file headers {names}" + assert match is not None, f"None of column names {self.col_names} found in file headers {names}" return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): - doc = self.tokenizer(str(reader[index]))["input_ids"] + assert ( + index < reader.length() + ), f"Illegal index {index} in set of {reader.length()} documents" + doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"] if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] @@ -416,9 +429,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_name: List[str] = ["text", "contents", "tokens"]): - self.PHandler = ParquetHandler(tokenizer_path, col_name) - self.AHandler = ArrowHandler(col_name) + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"], max_doclen: int = 1_000_000): + self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen) + self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() def is_legal(self, filepath: str): @@ -800,6 +813,128 @@ def load_state_dict(self, state_dicts, sharded_input=False): return sharded_dicts +class FIMDataset(_WrapperDataset): + """ + Wrapper for a StatefulDataset that implements Fill-In-the-Middle training + (https://arxiv.org/pdf/2207.14255). + Input should be a packed sequence (i.e. call BufferDataset before FIMDataset). + Breaks sequence apart into component document spans, and for each document span + of sufficient length, transforms with specified probability into: + PSM mode:
 (prefix)  (suffix)  (middle) 
+    SPM mode: 
  (suffix)  (prefix) (middle) 
+    The new delimiter tokens can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | none
+        Token used to indicate prefix section of the document
+    mid_token : any | none
+        Token used to indicate middle infill section of the document
+    suf_token : any | none
+        Token used to indicate suffix section of the document
+    """
+
+    def __init__(
+        self,
+        dataset: _StatefulDataset,
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
+
+
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -945,10 +1080,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -963,6 +1098,7 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
+        max_consecutive_chunks: int = 256,
         verbose: bool = False,
         filter_exp: int = 2,
     ):
@@ -976,6 +1112,7 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
+        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
         self.filter_exp = filter_exp
         self.docset: List[
@@ -992,6 +1129,7 @@ def __init__(
         self.docs_seen = 0
         self.percent_seen = 0
         self.has_yielded = False
+        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1003,6 +1141,7 @@ def __init__(
             "percent_seen",
             "lcg_state",
             "g_state",
+            "consec",
         ]
 
         # Setup flags
@@ -1033,35 +1172,62 @@ def setup(self):
             # listdir, assemble shardfraglist (ind -> shard, frag)
             shards = [
                 os.path.join(root, name)[len(datapath) + 1 :]
-                for root, dirs, files in os.walk(datapath, topdown=False)
+                for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True)
                 for name in files
                 if self.filehandler.is_legal(os.path.join(root, name))
                 and os.path.getsize(os.path.join(root, name)) > 1_000_000
                 # 1mb minimum file size to prevent empty files
             ]
             shards.sort()  # Ensure consistent sharding across machines
-            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
-            end_frag = (
-                (self.rank + 1) * self.worldsize * len(shards)
-            ) // self.worldsize
-            shardfrags = [
-                (shards[i // self.worldsize], i % self.worldsize)
-                for i in range(start_frag, end_frag)
-            ]
-
-            # Assemble length of each owned shard file
 
+            # Find metadata file
             countfiles = []
-            # if os.path.exists(os.path.join(pardir, "meta")):
-            #     countfiles = [
-            #         x
-            #         for x in os.listdir(os.path.join(pardir, "meta"))
-            #         if "counts" in x and "csv" in x
-            #     ]
-            doc_counts = {}
+            if os.path.exists(os.path.join(pardir, "meta")):
+                countfiles = [
+                    x
+                    for x in os.listdir(os.path.join(pardir, "meta"))
+                    if "counts" in x and "csv" in x
+                ]
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
+            else:
+                countpath = ""
+
+            # Use shard file sizes to perform partitioning
+            # Create shardlist of form shardid -> [start%, end%]
+            if len(countfiles) > 0:
+                sizes = {}
+                with open(countpath, "r") as csvfile:
+                    reader = csv.DictReader(csvfile)
+                    for row in reader:
+                        fullpath = row["dataset/filename"]
+                        prefix = fullpath.find(dataset + "/")
+                        if prefix >= 0:
+                            key = fullpath[prefix + len(dataset) + 1 :]
+                            sizes[key] = int(row["size"])
+                shard_sizes = [sizes[shard] for shard in shards]
+            else:
+                shard_sizes = [
+                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
+                ]
+            shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
+            start = self.rank / self.worldsize
+            end = (self.rank + 1) / self.worldsize
+            shardset = {}
+            tally = 0
+            for i in range(len(shards)):
+                if tally <= end and tally + shard_sizes[i] >= start:
+                    shardset[shards[i]] = [
+                        min(max((start - tally) / shard_sizes[i], 0), 1),
+                        min(max((end - tally) / shard_sizes[i], 0), 1),
+                    ]
+                tally += shard_sizes[i]
+
+            # Assemble length of each owned shard file
+            doc_counts = {}
+            if len(countfiles) > 0:
+                # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)
                     for row in reader:
@@ -1072,41 +1238,28 @@ def setup(self):
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
-                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in unique_shardfiles
+                    for shard in shardset
                 }
 
-            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
-            ndocs = -1
-            docset = {}  # shardid -> (min docid, max docid)
-            for i, (shard, frag) in enumerate(shardfrags):
-                ndocs = doc_counts[shard]
-                doc_start = (ndocs * frag) // self.worldsize
-                doc_end = (
-                    ndocs * frag + ndocs
-                ) // self.worldsize - 1  # Inclusive upper bound
-                if shard not in docset:
-                    docset[shard] = [doc_start, doc_end]
-                min_d, max_d = docset[shard]
-                if doc_start < min_d:
-                    docset[shard][0] = doc_start
-                if doc_end > max_d:
-                    docset[shard][1] = doc_end
-
-            # Add shard entries to self.docset
+            # Assemble doc list for each file shard
+            # Create docset of form [shardid, min docid, max docid]
             doccount = 0
-            for shardid in docset:
-                min_d = docset[shardid][0]
-                max_d = docset[shardid][1]
-                self.docset.append((shardid, min_d, max_d))
-                doccount += max_d - min_d + 1
+            for shard in shardset:
+                ndocs = doc_counts[shard]
+                if ndocs > 0:
+                    doc_start = int(ndocs * shardset[shard][0])
+                    doc_end = max(
+                        doc_start, int(ndocs * shardset[shard][1]) - 1
+                    )  # inclusive upper bound
+                    self.docset.append([shard, doc_start, doc_end])
+                    doccount += doc_end - doc_start + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"    Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
+                    f"    Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
@@ -1162,8 +1315,11 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1:
+        if j == n_chunks - 1 or self.consec == self.max_consec:
             chunk = chunk + [self.eos]
+            self.consec = 0
+        else:
+            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
@@ -1208,11 +1364,9 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
                 keep_chance = (doclen/self.min_length)**self.filter_exp
-                if torch.rand(1, generator=self.g).item() < keep_chance:
+                if len(doc) > 0 and torch.rand(1, generator=self.g).item() < keep_chance:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1239,11 +1393,9 @@ def __iter__(self):
             newpath = os.path.join(self.datapath, shardid)
             path, reader = self._get_reader(path, newpath, reader)
             doc = self.filehandler.get(reader, docid, self.drop)
-            if len(doc) == 0:
-                continue
             doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
             keep_chance = (doclen/self.min_length)**self.filter_exp
-            if torch.rand(1, generator=self.g).item() < keep_chance:
+            if len(doc) > 0 and torch.rand(1, generator=self.g).item() < keep_chance:
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j
@@ -1333,12 +1485,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1355,6 +1507,9 @@ def setup(self):
                     )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
+            assert (
+                sum(self.n_docs_remaining) > 0
+            ), f"No documents detected in shard {self.rank} of {self.datapath}"
 
             self.generator = torch.Generator().manual_seed(self.rank)
 
@@ -1362,14 +1517,16 @@ def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
+        # Reset if we're rescaling into a prematurely finished epoch
+        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
+        if sum(self.n_docs_remaining) == 0:
+            self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
-                assert (
-                    sum(self.n_docs_remaining) > 0
-                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
@@ -1461,6 +1618,10 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(

From 362bcd440ec9a54fae8e4c5cbd11b16fa9a2b35c Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Thu, 12 Jun 2025 17:46:26 -0400
Subject: [PATCH 06/18] Readd clm

---
 fms_fsdp/utils/dataloader_utils.py | 7 +++----
 fms_fsdp/utils/dataset_utils.py    | 5 ++---
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index e9088a73..6c657ea5 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -147,8 +147,7 @@ def get_data_loader(cfg, rank, world_size, dp_degree):
         cfg.eos_token,
         slice_rate=.75,
     )
-
-        # Apply FIM transformation if needed
+    # Apply FIM transformation if needed
     if fim_training:
         data = FIMDataset(
             data,
@@ -159,10 +158,10 @@ def get_data_loader(cfg, rank, world_size, dp_degree):
             mid_token=cfg.fim_mid,
             suf_token=cfg.fim_suf,
         )
-
     # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)
     # Apply CP chunking if using CP
     if do_cp:
         def chunk(x):
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 73e8ae2a..c32ec705 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -675,10 +675,9 @@ def __init__(self, dataset: _StatefulDataset, window_size: int):
 
     def __iter__(self):
         dataset = iter(self.dataset)
+        # Pad out buffer if needed
+        self._pad_buffer()
         while True:
-            # Pad out buffer if needed
-            self._pad_buffer()
-
             # If buffer is undersized, add a datapoint
             if self.buffer_size < self.window_size:
                 self.buffer[self.buffer_size] = next(dataset)

From 8d0f4f2f8d8bb1cb1b20203870751c84424cbbd4 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 14:41:00 -0400
Subject: [PATCH 07/18] Port train_utils

---
 fms_fsdp/utils/train_utils.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index bbd8d54b..ba05636d 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -44,19 +44,22 @@ def train(
             except ImportError:
                 raise ImportError("tracker is set to wandb but wandb is not installed.")
             if rank == 0:
-                print("--> wandb is enabled!")
+                print("--> Started initializing wandb", flush=True)
                 try:
                     wandb.init(
                         project=project_name,
                         dir=tracker_dir,
                         resume="allow",
                         id=run_id,
+                        # mode='offline',
+                        settings=wandb.Settings(init_timeout=3600),
                     )
                 except wandb.errors.UsageError:
                     raise ValueError(
                         "wandb failed to init, did you pass your wandb api key via WANDB_API_KEY?"
                     )
                 wandb.config = asdict(cfg)
+                print(f"--> wandb is enabled!", flush=True)
 
         if cfg.tracker == "aim":
             try:
@@ -102,7 +105,7 @@ def train(
         if profiler:
             profiler.step()
 
-        if batch_idx % cfg.report_interval == 0:
+        if batch_idx % cfg.report_interval == 0 or batch_idx == start_step + 1:
             dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM)
             train_loss = ddp_stats[0] / ddp_stats[2]
             g_norm = ddp_stats[1] / ddp_stats[2]
@@ -151,7 +154,7 @@ def train(
                     int(new_tokens_seen / elapsed_time * 3600 * 24),
                 )
                 print(f"Total tok/step: {world_size * cfg.batch_size * cfg.seq_length}")
-                if cfg.tracker:
+                if cfg.tracker and batch_idx > start_step + 1:
                     vals_to_track = {
                         "learning rate": current_lr,
                         "loss": current_loss,

From a24f24bface72d717ad3a654a5f062b297420fa8 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 14:56:17 -0400
Subject: [PATCH 08/18] orig params

---
 main_training_mamba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index abab8ba0..5c1ff3da 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -152,7 +152,7 @@ def lambda_fn(module: nn.Module):
         auto_wrap_policy=wrapping_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=sharding_strategy_policy,
-        use_orig_params=cfg.use_torch_compile,
+        use_orig_params=True,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=True,
         param_init_fn=param_init_fn,

From 710a275c0fec42ea38921a6d1a1f426d323cc558 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 15:02:27 -0400
Subject: [PATCH 09/18] Separated weights

---
 main_training_mamba.py | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index 5c1ff3da..cbd60890 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -175,11 +175,33 @@ def lambda_fn(module: nn.Module):
         model = torch.compile(model)
 
     # Optimizer
+    # optimizer = optim.AdamW(
+    #     model.parameters(),
+    #     lr=cfg.learning_rate,
+    #     betas=(0.9, 0.95),
+    #     weight_decay=0.1,
+    # )
+    params_with_decay = []
+    params_without_decay = []
+    for name, param in model.named_parameters():
+        suff = name.split('.')[-1]
+        if 'A_log' in suff or 'D' in suff or 'dt_bias' in suff:
+            params_without_decay.append(param)
+        else:
+            params_with_decay.append(param)
     optimizer = optim.AdamW(
-        model.parameters(),
-        lr=cfg.learning_rate,
-        betas=(0.9, 0.95),
-        weight_decay=0.1,
+        [
+            {
+                "params": params_with_decay,
+                "weight_decay": 0.1,
+            },
+            {
+                "params": params_without_decay,
+                "weight_decay": 0.,
+            },
+        ],
+        betas = (0.9, 0.95),
+        lr = cfg.learning_rate,
     )
 
     # optionally load from checkpoint (when continue pretraining)

From 0a7ea94d9c03f1404265cd5257293e4a09b1bbbe Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 15:13:58 -0400
Subject: [PATCH 10/18] Full changes ported (minus conversion)

---
 main_training_mamba.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index cbd60890..88fb9bbc 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -228,14 +228,9 @@ def lambda_fn(module: nn.Module):
     warmup = lambda x: 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2
     # linear decay for annealing
     if cfg.training_stage == "annealing":
-        schedule = lambda x: min(
-            warmup(x),
-            1 - x / cfg.num_steps,
-        )
-    elif cfg.training_stage == "constant":
-        # no decay for intermediate jobs
-        schedule = warmup
-    else:
+        warmup_interval = 1000
+        schedule = lambda x: x / warmup_interval if x < warmup_interval else 1 - (x - warmup_interval) / (cfg.num_steps - warmup_interval)
+    elif cfg.training_stage == "cosine":
         # cosine decay
         schedule = lambda x: min(
             warmup(x),
@@ -244,6 +239,21 @@ def lambda_fn(module: nn.Module):
             * (1 - 0.1)
             * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
         )
+    elif cfg.training_stage == "constant":
+        warmup_interval = 2000
+        schedule = lambda x: (min(x, warmup_interval) / warmup_interval)
+    elif cfg.training_stage == "linear_to_constant":
+        linear_steps = 25000
+        start_lr = 2e-4
+        end_lr = 2e-4
+        schedule = lambda x: (start_lr + (end_lr - start_lr) * min(x - start_step, linear_steps) / linear_steps) / cfg.learning_rate
+    elif cfg.training_stage == "annealing_with_specified_decay_steps":
+        warmup_interval = 2000
+        total_decay_steps = 25000
+        schedule = lambda x: (x - start_step) / warmup_interval if x - start_step < warmup_interval else max(0.0, 1 - (x - start_step - warmup_interval) / total_decay_steps)
+    else:
+        schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75
+        
 
     scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
 

From a9e16a422437896abd0b8111754753d29a1d96b4 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 16:09:27 -0400
Subject: [PATCH 11/18] Pad buffer after saving ckpt

---
 fms_fsdp/utils/dataset_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index c32ec705..1fd36c77 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -707,6 +707,8 @@ def state_dict(self):
         # Prune buffer so it can be resharded in future
         self.buffer = self.buffer[: self.buffer_size]
         out = super().state_dict()
+        # Pad buffer back out again
+        self._pad_buffer()
         return out
 
     def load_state_dict(self, state_dicts, sharded_input=False):

From 62f7e175864e3ea5ccc858d549da4c334003dd46 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 16:37:02 -0400
Subject: [PATCH 12/18] Temp disable optim load for diag porpoises

---
 main_training_mamba.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index 88fb9bbc..fe9ce4f7 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -208,9 +208,9 @@ def lambda_fn(module: nn.Module):
     checkpointer = Checkpointer(
         cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
     )
-    model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
+    model, _, _, start_step, tokens_seen, is_resuming = checkpointer.load(
         model,
-        optimizer,
+        None,
         None,
         path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
         if not os.path.isfile(cfg.ckpt_load_path)

From 3bae2f6b81159fab50129b95bfe6793bf0301204 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 16:48:40 -0400
Subject: [PATCH 13/18] Revert optim skip

---
 main_training_mamba.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/main_training_mamba.py b/main_training_mamba.py
index fe9ce4f7..88fb9bbc 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -208,9 +208,9 @@ def lambda_fn(module: nn.Module):
     checkpointer = Checkpointer(
         cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
     )
-    model, _, _, start_step, tokens_seen, is_resuming = checkpointer.load(
+    model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
         model,
-        None,
+        optimizer,
         None,
         path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
         if not os.path.isfile(cfg.ckpt_load_path)

From b67fb6ea06c5d4eeeb5f14ec3df918a50ae5ea8d Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 16:56:36 -0400
Subject: [PATCH 14/18] Account for no delims in docslice layer

---
 fms_fsdp/utils/dataset_utils.py | 89 +++++++++++++++++----------------
 1 file changed, 46 insertions(+), 43 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 1fd36c77..2f8645f1 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -753,52 +753,55 @@ def __iter__(self):
             inplen = len(inp)
             doclist = []
             last_delim = 0
-            for i in range(len(inp)):
-                if inp[i] == self.delimiter:
-                    doclist.append(inp[last_delim:i])
-                    last_delim = i+1
-            doclist.append(inp[last_delim:])
-            # Pull out any short caps
-            begin = []
-            end = []
-            if len(doclist[0])//3 <= self.overlap:
-                begin = doclist[0]
-                doclist = doclist[1:]
-            if len(doclist[-1])//3 <= self.overlap:
-                end = doclist[-1]
-                doclist = doclist[:-1]
-            # Figure out which docs to slice
-            slice = []
-            unslice = []
-            for doc in doclist:
-                if torch.rand(1, generator=self.generator) < self.slicerate and len(doc)//3 > self.overlap:
-                    slice.append(doc)
-                else:
-                    unslice.append(doc)
-            if len(slice) <= 1:
+            if self.delimiter not in inp:
                 yield inp
             else:
-                # Perform slicing
-                sliced = []
-                for doc in slice:
-                    i = torch.randint(0, len(doc)//3, [1], generator=self.generator).item() + len(doc)//3
-                    sliced.append([doc[:i], doc[i-self.overlap:]])
-                slice = sliced
-                doclist = [slice[0][0], slice[1][0], slice[0][1], slice[1][1]]
-                for docpair in slice[2:]:
-                    inds = torch.randperm(len(doclist)+1, generator=self.generator)[:2].tolist()
-                    inds.sort()
-                    inds[1] += 1
-                    doclist = doclist[:inds[0]] + [docpair[0]] + doclist[inds[0]:inds[1]-1] + [docpair[1]] + doclist[inds[1]-1:]
-                for doc in unslice:
-                    i = torch.randint(0, len(doclist)+1, [1], generator=self.generator).item()
-                    doclist = doclist[:i] + [doc] + doclist[i:]
-                out = begin + [self.delimiter]
+                for i in range(len(inp)):
+                    if inp[i] == self.delimiter:
+                        doclist.append(inp[last_delim:i])
+                        last_delim = i+1
+                doclist.append(inp[last_delim:])
+                # Pull out any short caps
+                begin = []
+                end = []
+                if len(doclist[0])//3 <= self.overlap:
+                    begin = doclist[0]
+                    doclist = doclist[1:]
+                if len(doclist[-1])//3 <= self.overlap:
+                    end = doclist[-1]
+                    doclist = doclist[:-1]
+                # Figure out which docs to slice
+                slice = []
+                unslice = []
                 for doc in doclist:
-                    out = out + doc
-                    out.append(self.delimiter)
-                out = out + end
-                yield out[:inplen]
+                    if torch.rand(1, generator=self.generator) < self.slicerate and len(doc)//3 > self.overlap:
+                        slice.append(doc)
+                    else:
+                        unslice.append(doc)
+                if len(slice) <= 1:
+                    yield inp
+                else:
+                    # Perform slicing
+                    sliced = []
+                    for doc in slice:
+                        i = torch.randint(0, len(doc)//3, [1], generator=self.generator).item() + len(doc)//3
+                        sliced.append([doc[:i], doc[i-self.overlap:]])
+                    slice = sliced
+                    doclist = [slice[0][0], slice[1][0], slice[0][1], slice[1][1]]
+                    for docpair in slice[2:]:
+                        inds = torch.randperm(len(doclist)+1, generator=self.generator)[:2].tolist()
+                        inds.sort()
+                        inds[1] += 1
+                        doclist = doclist[:inds[0]] + [docpair[0]] + doclist[inds[0]:inds[1]-1] + [docpair[1]] + doclist[inds[1]-1:]
+                    for doc in unslice:
+                        i = torch.randint(0, len(doclist)+1, [1], generator=self.generator).item()
+                        doclist = doclist[:i] + [doc] + doclist[i:]
+                    out = begin + [self.delimiter]
+                    for doc in doclist:
+                        out = out + doc
+                        out.append(self.delimiter)
+                    out = out + end
+                    yield out[:inplen]
 
     def state_dict(self):
         # Write generator state manually

From 4839584562cfe5916180d9acd9343c97719b5ad4 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 17:55:59 -0400
Subject: [PATCH 15/18] Diag print

---
 fms_fsdp/utils/train_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index ba05636d..4dff902c 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -84,6 +84,8 @@ def train(
     for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
         if batch_idx > cfg.num_steps:
             break
+        if rank == 0:
+            print(input.shape)
         input = input.to(local_rank)
         label = label.to(local_rank)
 

From 4c1dcf9bd5b4866173b1e45f9fb5d30f241de4fb Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 18:10:37 -0400
Subject: [PATCH 16/18] Wipe cache if wrong len

---
 fms_fsdp/utils/dataset_utils.py | 13 ++++++++++---
 fms_fsdp/utils/train_utils.py   |  2 --
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 2f8645f1..f2d7823f 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -677,10 +677,17 @@ def __iter__(self):
         dataset = iter(self.dataset)
         # Pad out buffer if needed
         self._pad_buffer()
+        first_draw = next(dataset)
         while True:
+            # If buffer entries have wrong length, reset buffer
+            if len(first_draw) != len(self.buffer[0]):
+                self.buffer = []
+                self.buffer_size = 0
+                self._pad_buffer()
+
             # If buffer is undersized, add a datapoint
             if self.buffer_size < self.window_size:
-                self.buffer[self.buffer_size] = next(dataset)
+                self.buffer[self.buffer_size] = next(dataset) if self.buffer_size > 0 else first_draw
                 self.buffer_size += 1
 
             # Swap out randomly sampled value from buffer.
@@ -696,10 +703,10 @@ def __iter__(self):
             yield out
 
     def _pad_buffer(self):
-        if self.buffer_size < self.window_size:
+        if len(self.buffer) < self.window_size:
             self.buffer += [
                 [],
-            ] * (self.window_size - self.buffer_size)
+            ] * (len(self.buffer) - self.buffer_size)
 
     def state_dict(self):
         # Write generator state manually
diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py
index 4dff902c..ba05636d 100644
--- a/fms_fsdp/utils/train_utils.py
+++ b/fms_fsdp/utils/train_utils.py
@@ -84,8 +84,6 @@ def train(
     for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
         if batch_idx > cfg.num_steps:
             break
-        if rank == 0:
-            print(input.shape)
         input = input.to(local_rank)
         label = label.to(local_rank)
 

From abd5ef19aa358905891b80da5f3c0f5b927a44a5 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 18:15:46 -0400
Subject: [PATCH 17/18] Wipe cache if wrong len pt2

---
 fms_fsdp/utils/dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index f2d7823f..9ba14958 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -706,7 +706,7 @@ def _pad_buffer(self):
         if len(self.buffer) < self.window_size:
             self.buffer += [
                 [],
-            ] * (len(self.buffer) - self.buffer_size)
+            ] * (self.window_size - len(self.buffer))
 
     def state_dict(self):
         # Write generator state manually

From 922d37338e851863e4ca34c3cb75b81876c0731b Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 13 Jun 2025 18:43:30 -0400
Subject: [PATCH 18/18] Make doc slicing flaggable

---
 fms_fsdp/config/training.py        |  1 +
 fms_fsdp/utils/dataloader_utils.py | 11 ++++++-----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 45c4a87a..5c44cc04 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -30,6 +30,7 @@ class train_config:
     doc_breakpoint: int = 65_536
     filter_exp: int = 2
     target_doclen: int = 8192
+    slice_rate: float = 0.0
 
     # FIM training
     psm_rate: float = 0.0
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 6c657ea5..67cead33 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -142,11 +142,12 @@ def get_data_loader(cfg, rank, world_size, dp_degree):
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 1000)
     # Slice and rearrange docs to force long-context retrieval
-    data = DocSliceDataset(
-        data,
-        cfg.eos_token,
-        slice_rate=.75,
-    )
+    if cfg.slice_rate > 0:
+        data = DocSliceDataset(
+            data,
+            cfg.eos_token,
+            slice_rate=cfg.slice_rate,
+        )
     # Apply FIM transformation if needed
     if fim_training:
         data = FIMDataset(