diff --git a/code-of-conduct.md b/code-of-conduct.md new file mode 100644 index 00000000..a5a920fc --- /dev/null +++ b/code-of-conduct.md @@ -0,0 +1,3 @@ +# Foundation Model Stack Community Code of Conduct + +Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md). diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 1d072958..5c44cc04 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -26,6 +26,18 @@ class train_config: strip_tokens: str = "" logical_shards: int = 1024 num_workers: int = 1 + doc_cutoff: int = 1_000_000 + doc_breakpoint: int = 65_536 + filter_exp: int = 2 + target_doclen: int = 8192 + slice_rate: float = 0.0 + + # FIM training + psm_rate: float = 0.0 + spm_rate: float = 0.0 + fim_pre: int = 1 + fim_mid: int = 2 + fim_suf: int = 3 # fsdp policies sharding_strategy: str = "hsdp" @@ -72,3 +84,9 @@ class train_config: stage2_prompt_length: int = 64 stage2_batch_size: int = 96 stage2_seq_length: int = 256 + + # context parallel + cp: bool = False + cp_mamba_impl: str = "allgather" # "allgather" or "serial" + cp_attn_impl: str = "zigzag" # "zigzag" or "ring" + cp_over_world: bool = False diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index f4e7628c..85d4253a 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -126,7 +126,6 @@ def get_model_config(model_variant): nlayers=24, hidden_grow_factor=8 / 3, max_expected_seq_len=4096, - rope_theta=500000.0, ) elif model_variant == "llama3_70b": model_config = LLaMAConfig( @@ -175,7 +174,7 @@ def get_model_config(model_variant): "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False, - "rotary_emb_dim": 64, + "rotary_emb_dim": 0, }, "rms_norm": True, "residual_in_fp32": True, diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 4b811d6d..67cead33 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -5,6 +5,8 @@ AutoHandler, BufferDataset, CheckpointDataset, + DocSliceDataset, + FIMDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -12,6 +14,7 @@ ScalableShardDataset, StreamingDocDataset, ) +from math import ceil _handler_map = { @@ -57,9 +60,9 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): +def get_data_loader(cfg, rank, world_size, dp_degree): """ - Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. + Pytorch dataloader for stateful, distributed, and rescalable language model training. Assumes underlying data is sequences of integer values. ... Args @@ -70,12 +73,21 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): Rank of current distributed worker. Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. - postprocess : List[Callable] - Any task-specific postprocessing to apply before handing over data. Steps will apply in - the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - datasets, weights = parse_data_args(cfg.datasets, cfg.weights) + do_cp = False + if dp_degree != world_size: + do_cp = True + cp_worldsize = world_size // dp_degree + cp_rank = rank % cp_worldsize + world_size = dp_degree + rank = rank // cp_worldsize + + fim_training = cfg.psm_rate + cfg.spm_rate > 0 + if fim_training: + assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?" + + datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. # Implements dataset sampling and rescalability. @@ -87,9 +99,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": - filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) + filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols, cfg.doc_cutoff) else: - filehandler = _handler_map[cfg.file_type] + filehandler = _handler_map[cfg.file_type](cols) # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -99,8 +111,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.eos_token, bos_token=cfg.bos_token, strip_tokens=set(droplist), - min_length=3, + min_length=cfg.target_doclen, seed=cfg.seed, + filter_exp=cfg.filter_exp, + max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024), ) # Add rescaling/resharding data = ScalableShardDataset( @@ -120,18 +134,40 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): # Wrap above dataset in packing logic to form constant-length lines. data = BufferDataset( data, - cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, + cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, ) # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. - data = PreloadBufferDataset(data, 10000) - - # Apply desired postprocessing steps in sequence + data = PreloadBufferDataset(data, 1000) + # Slice and rearrange docs to force long-context retrieval + if cfg.slice_rate > 0: + data = DocSliceDataset( + data, + cfg.eos_token, + slice_rate=cfg.slice_rate, + ) + # Apply FIM transformation if needed + if fim_training: + data = FIMDataset( + data, + cfg.eos_token, + cfg.psm_rate, + cfg.spm_rate, + pre_token=cfg.fim_pre, + mid_token=cfg.fim_mid, + suf_token=cfg.fim_suf, + ) + # Transform to tensors data = PreprocessDataset(data, torch.IntTensor) - for p in postprocess: - data = PreprocessDataset(data, p) + # Apply CLM transformation + data = PreprocessDataset(data, causal_lm) + # Apply CP chunking if using CP + if do_cp: + def chunk(x): + return x[(cp_rank*x.size(0))//cp_worldsize : ((cp_rank+1)*x.size(0))//cp_worldsize] + data = PreprocessDataset(data, lambda x: (chunk(x[0]), chunk(x[1]))) # Enable auto-saving data = CheckpointDataset( @@ -146,7 +182,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): ) -def parse_data_args(datas, weights): +def parse_data_args(datas, weights, cols): # Convert csv inputs into corresponding lists of values def splitstrip(x): if isinstance(x, str): @@ -160,4 +196,5 @@ def splitstrip(x): datas = splitstrip(datas) weights = [float(x) for x in splitstrip(weights)] - return datas, weights + cols = splitstrip(cols) + return datas, weights, cols diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index aedc5862..9ba14958 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -180,16 +180,20 @@ def load_state_dict(self, state_dicts, sharded_input=False): self.load_worldsize = len(state_dicts) state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize) if self.load_worldsize == self.worldsize: - [ - setattr(self, flag, state_dicts[0][self.statename(flag)]) - for flag in self.state_params + self.reshard_params - ] + for flag in self.state_params + self.reshard_params: + if self.statename(flag) in state_dicts[0]: + setattr(self, flag, state_dicts[0][self.statename(flag)]) + elif self.rank == 0: + logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!") else: for flag in self.reshard_params: - reshard = self._reshard( - [sd[self.statename(flag)] for sd in state_dicts] - ) - setattr(self, flag, reshard) + if self.statename(flag) in state_dicts[0]: + reshard = self._reshard( + [sd[self.statename(flag)] for sd in state_dicts] + ) + setattr(self, flag, reshard) + elif self.rank == 0: + logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!") return state_dicts def load_from_path(self, path: str): @@ -343,8 +347,8 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. """ - def __init__(self, col_name: str = "tokens"): - self.col_name = col_name + def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): + self.col_names = col_names def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,7 +360,18 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - doc = reader.get_batch(index)[self.col_name] + assert ( + index < reader.num_record_batches + ), f"Illegal index {index} in set of {reader.num_record_batches} documents" + frame = reader.get_batch(index) + doc = None + for name in self.col_names: + if name in frame.column_names: + doc = frame[name] + break + assert ( + doc is not None + ), f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -371,28 +386,37 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List: class ParquetHandler(_ShardFileHandler): """ Reader for indexable parquet shard files, common in HF datasets. - Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens), + Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters, as we rely on parquet/pandas for efficient file reading, and tokenize entire documents before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_name: str = "text"): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"], max_doclen: int = 1_000_000): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.col_name = col_name + self.col_names = col_names + self.max_doclen = max_doclen def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] def open(self, path: str): - return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[ - self.col_name - ] + names = pq.read_metadata(path).schema.names + match = None + for name in self.col_names: + if name in names: + match = name + break + assert match is not None, f"None of column names {self.col_names} found in file headers {names}" + return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): - doc = self.tokenizer(str(reader[index]))["input_ids"] + assert ( + index < reader.length() + ), f"Illegal index {index} in set of {reader.length()} documents" + doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"] if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] @@ -405,9 +429,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_name: str = "text"): - self.PHandler = ParquetHandler(tokenizer_path, col_name) - self.AHandler = ArrowHandler() + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"], max_doclen: int = 1_000_000): + self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen) + self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() def is_legal(self, filepath: str): @@ -651,13 +675,19 @@ def __init__(self, dataset: _StatefulDataset, window_size: int): def __iter__(self): dataset = iter(self.dataset) + # Pad out buffer if needed + self._pad_buffer() + first_draw = next(dataset) while True: - # Pad out buffer if needed - self._pad_buffer() + # If buffer entries have wrong length, reset buffer + if len(first_draw) != len(self.buffer[0]): + self.buffer = [] + self.buffer_size = 0 + self._pad_buffer() # If buffer is undersized, add a datapoint if self.buffer_size < self.window_size: - self.buffer[self.buffer_size] = next(dataset) + self.buffer[self.buffer_size] = next(dataset) if self.buffer_size > 0 else first_draw self.buffer_size += 1 # Swap out randomly sampled value from buffer. @@ -673,10 +703,10 @@ def __iter__(self): yield out def _pad_buffer(self): - if self.buffer_size < self.window_size: + if len(self.buffer) < self.window_size: self.buffer += [ [], - ] * (self.window_size - self.buffer_size) + ] * (self.window_size - len(self.buffer)) def state_dict(self): # Write generator state manually @@ -684,6 +714,8 @@ def state_dict(self): # Prune buffer so it can be resharded in future self.buffer = self.buffer[: self.buffer_size] out = super().state_dict() + # Pad buffer back out again + self._pad_buffer() return out def load_state_dict(self, state_dicts, sharded_input=False): @@ -694,6 +726,224 @@ def load_state_dict(self, state_dicts, sharded_input=False): # Manually set buffer size self.buffer_size = len(self.buffer) return sharded_dicts + + +class DocSliceDataset(_WrapperDataset): + """ + Wrapper for a StatefulDataset that implements document slicing. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + delimiter_token : int + Value used for delimiter + slice_rate : float + Proportion of documents to slice + overlap : int + Number of tokens to overlap for slice retrieval + """ + + def __init__(self, dataset: _StatefulDataset, delimiter_token: int, slice_rate: float = 0.5, overlap: int = 3): + super().__init__(dataset) + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.state_params = ["g_state"] + self.delimiter = delimiter_token + self.slicerate = slice_rate + self.overlap = overlap + + def __iter__(self): + dataset = iter(self.dataset) + while True: + inp = next(dataset) + inplen = len(inp) + doclist = [] + last_delim = 0 + if self.delimiter not in inp: + yield inp + else: + for i in range(len(inp)): + if inp[i] == self.delimiter: + doclist.append(inp[last_delim:i]) + last_delim = i+1 + doclist.append(inp[last_delim:]) + # Pull out any short caps + begin = [] + end = [] + if len(doclist[0])//3 <= self.overlap: + begin = doclist[0] + doclist = doclist[1:] + if len(doclist[-1])//3 <= self.overlap: + end = doclist[-1] + doclist = doclist[:-1] + # Figure out which docs to slice + slice = [] + unslice = [] + for doc in doclist: + if torch.rand(1, generator=self.generator) < self.slicerate and len(doc)//3 > self.overlap: + slice.append(doc) + else: + unslice.append(doc) + if len(slice) <= 1: + yield inp + else: + # Perform slicing + sliced = [] + for doc in slice: + i = torch.randint(0, len(doc)//3, [1], generator=self.generator).item() + len(doc)//3 + sliced.append([doc[:i], doc[i-self.overlap:]]) + slice = sliced + doclist = [slice[0][0], slice[1][0], slice[0][1], slice[1][1]] + for docpair in slice[2:]: + inds = torch.randperm(len(doclist)+1, generator=self.generator)[:2].tolist() + inds.sort() + inds[1] += 1 + doclist = doclist[:inds[0]] + [docpair[0]] + doclist[inds[0]:inds[1]-1] + [docpair[1]] + doclist[inds[1]-1:] + for doc in unslice: + i = torch.randint(0, len(doclist)+1, [1], generator=self.generator).item() + doclist = doclist[:i] + [doc] + doclist[i:] + out = begin + [self.delimiter] + for doc in doclist: + out = out + doc + out.append(self.delimiter) + out = out + end + yield out[:inplen] + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + out = super().state_dict() + return out + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + return sharded_dicts + + +class FIMDataset(_WrapperDataset): + """ + Wrapper for a StatefulDataset that implements Fill-In-the-Middle training + (https://arxiv.org/pdf/2207.14255). + Input should be a packed sequence (i.e. call BufferDataset before FIMDataset). + Breaks sequence apart into component document spans, and for each document span + of sufficient length, transforms with specified probability into: + PSM mode:
(prefix)(suffix) (middle) + SPM mode: (suffix) (prefix) (middle) + The new delimiter tokens can be omitted by passing in None. + Any extra tokens after transformation are dropped from the end of the sequence. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + delimiter_token : any + Token used to indicate document boundaries + psm_rate : float + Chance to transform into PSM. Cannot exceed 1. + spm_rate : float + Chance to transform into SPM. Cannot exceed 1. + min_len : int + Minimum document length to perform FIM transformation + pre_token : any | none + Token used to indicate prefix section of the document + mid_token : any | none + Token used to indicate middle infill section of the document + suf_token : any | none + Token used to indicate suffix section of the document + """ + + def __init__( + self, + dataset: _StatefulDataset, + delimiter_token: Any, + psm_rate: float = 0.0, + spm_rate: float = 0.0, + min_len: int = 10, + pre_token=None, + mid_token=None, + suf_token=None, + ): + super().__init__(dataset) + assert ( + psm_rate + spm_rate > 0 + ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate." + assert ( + psm_rate + spm_rate <= 1 + ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1." + self.psm = psm_rate + self.spm = spm_rate + self.delimiter = delimiter_token + self.min_len = min_len + self.pref = pre_token + self.suff = suf_token + self.midd = mid_token + + self.g_state = None + self.generator = torch.Generator().manual_seed(self.rank) + self.state_params = ["g_state"] + + def __iter__(self): + dataset = iter(self.dataset) + while True: + inp = next(dataset) + len_ = len(inp) + i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_] + docs = [ + inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1) + ] # list[list[any]] + out = [] + for i in range(len(docs)): + doc = docs[i] + if len(docs[i]) >= self.min_len: + # decide psm, spm, or nothing + thresh = torch.rand([1], generator=self.generator).item() + if thresh < self.psm + self.spm: + # Split doc + doc = [] + if self.pref: + doc = [self.pref] + splits = torch.randint( + 0, len(docs[i]), [2], generator=self.generator + ).tolist() + pre = docs[i][: min(splits)] + mid = docs[i][min(splits) : max(splits)] + suf = docs[i][max(splits) :] + + if thresh < self.psm: + # PSM transformation + doc += pre + if self.suff: + doc.append(self.suff) + doc += suf + if self.midd: + doc.append(self.midd) + doc += mid + else: + # SPM transformation + if self.suff: + doc.append(self.suff) + doc += suf + if self.midd: + doc.append(self.midd) + doc += pre + mid + out += doc + [self.delimiter] + yield out[:len_] + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state() + return super().state_dict() + + def load_state_dict(self, state_dicts, sharded_input=False): + sharded_dicts = super().load_state_dict(state_dicts, sharded_input) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(self.g_state) + return sharded_dicts class BufferDataset(_WrapperDataset): @@ -841,10 +1091,10 @@ class StreamingDocDataset(_StatefulDataset): Documents below this length are skipped max_chunksize : int Maximum sequence length to return. Break long docs into chunks of this size or shorter. + max_consecutive_chunks : int + Number of doc chunks to emit before manually inserting EOS and resuming later. verbose : bool Track setup progress? - shuffle : bool - Shuffle shard file and document orders? (Disable for simple testing) """ def __init__( @@ -859,7 +1109,9 @@ def __init__( seed: int = 42, min_length: int = 1, max_chunksize: int = 1024, + max_consecutive_chunks: int = 256, verbose: bool = False, + filter_exp: int = 2, ): super().__init__(datapath, rank, worldsize) self.seed = seed @@ -871,7 +1123,9 @@ def __init__( self.eos = delimiter_token self.bos = bos_token self.drop = strip_tokens + self.max_consec = max_consecutive_chunks self.verbose = verbose + self.filter_exp = filter_exp self.docset: List[ Any ] = [] # map of doc indices to (shardid, min docid, max docid) @@ -885,6 +1139,8 @@ def __init__( self.tokens_seen = 0 self.docs_seen = 0 self.percent_seen = 0 + self.has_yielded = False + self.consec = 0 self.state_params = [ "dataset", @@ -895,6 +1151,8 @@ def __init__( "docs_seen", "percent_seen", "lcg_state", + "g_state", + "consec", ] # Setup flags @@ -902,6 +1160,9 @@ def __init__( self._len = 0 self.dataset = "" self.lcg_state = 0 + self.g_state = None + + self.g = None def setup(self): """ @@ -922,22 +1183,15 @@ def setup(self): # listdir, assemble shardfraglist (ind -> shard, frag) shards = [ os.path.join(root, name)[len(datapath) + 1 :] - for root, dirs, files in os.walk(datapath, topdown=False) + for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True) for name in files if self.filehandler.is_legal(os.path.join(root, name)) + and os.path.getsize(os.path.join(root, name)) > 1_000_000 + # 1mb minimum file size to prevent empty files ] shards.sort() # Ensure consistent sharding across machines - start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize - end_frag = ( - (self.rank + 1) * self.worldsize * len(shards) - ) // self.worldsize - shardfrags = [ - (shards[i // self.worldsize], i % self.worldsize) - for i in range(start_frag, end_frag) - ] - - # Assemble length of each owned shard file + # Find metadata file countfiles = [] if os.path.exists(os.path.join(pardir, "meta")): countfiles = [ @@ -945,55 +1199,78 @@ def setup(self): for x in os.listdir(os.path.join(pardir, "meta")) if "counts" in x and "csv" in x ] - doc_counts = {} if len(countfiles) > 0: # Count file exists, use it countpath = os.path.join(pardir, "meta", countfiles[0]) + else: + countpath = "" + + # Use shard file sizes to perform partitioning + # Create shardlist of form shardid -> [start%, end%] + if len(countfiles) > 0: + sizes = {} with open(countpath, "r") as csvfile: reader = csv.DictReader(csvfile) for row in reader: fullpath = row["dataset/filename"] - prefix = fullpath.find("/" + dataset) + 1 - if prefix > 0: + prefix = fullpath.find(dataset + "/") + if prefix >= 0: + key = fullpath[prefix + len(dataset) + 1 :] + sizes[key] = int(row["size"]) + shard_sizes = [sizes[shard] for shard in shards] + else: + shard_sizes = [ + os.path.getsize(os.path.join(datapath, shard)) for shard in shards + ] + shard_sizes = [s / sum(shard_sizes) for s in shard_sizes] + start = self.rank / self.worldsize + end = (self.rank + 1) / self.worldsize + shardset = {} + tally = 0 + for i in range(len(shards)): + if tally <= end and tally + shard_sizes[i] >= start: + shardset[shards[i]] = [ + min(max((start - tally) / shard_sizes[i], 0), 1), + min(max((end - tally) / shard_sizes[i], 0), 1), + ] + tally += shard_sizes[i] + + # Assemble length of each owned shard file + doc_counts = {} + if len(countfiles) > 0: + # Count file exists, use it + with open(countpath, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find(dataset) + if prefix >= 0: key = fullpath[prefix + len(dataset) + 1 :] doc_counts[key] = int(row["documents"]) else: # Count file does not exist, touch every owned file for length - unique_shardfiles = set(shard for shard, frag in shardfrags) doc_counts = { shard: self.filehandler.length(os.path.join(datapath, shard)) - for shard in unique_shardfiles + for shard in shardset } - # Read shardfrags, assemble doc list for each file shard (aggregating over fragments): - ndocs = -1 - docset = {} # shardid -> (min docid, max docid) - for i, (shard, frag) in enumerate(shardfrags): - ndocs = doc_counts[shard] - doc_start = (ndocs * frag) // self.worldsize - doc_end = ( - ndocs * frag + ndocs - ) // self.worldsize - 1 # Inclusive upper bound - if shard not in docset: - docset[shard] = [doc_start, doc_end] - min_d, max_d = docset[shard] - if doc_start < min_d: - docset[shard][0] = doc_start - if doc_end > max_d: - docset[shard][1] = doc_end - - # Add shard entries to self.docset + # Assemble doc list for each file shard + # Create docset of form [shardid, min docid, max docid] doccount = 0 - for shardid in docset: - min_d = docset[shardid][0] - max_d = docset[shardid][1] - self.docset.append((shardid, min_d, max_d)) - doccount += max_d - min_d + 1 + for shard in shardset: + ndocs = doc_counts[shard] + if ndocs > 0: + doc_start = int(ndocs * shardset[shard][0]) + doc_end = max( + doc_start, int(ndocs * shardset[shard][1]) - 1 + ) # inclusive upper bound + self.docset.append([shard, doc_start, doc_end]) + doccount += doc_end - doc_start + 1 self._len = doccount if self.verbose: logging.info( - f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}" + f" Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}" ) # Shuffle shard files - guaranteed inconsistent across workers @@ -1002,6 +1279,7 @@ def setup(self): random.shuffle(self.docset) # Setup doc shuffle - same guarantee self.lcg_state = seed + self.g = torch.Generator().manual_seed(self.rank) def _get_docid(self, i): """ @@ -1048,8 +1326,11 @@ def _construct_chunk(self, j, doc, n_chunks): # Add bos/eos tokens if needed if self.bos is not None and j == 0: chunk = [self.bos] + chunk - if j == n_chunks - 1: + if j == n_chunks - 1 or self.consec == self.max_consec: chunk = chunk + [self.eos] + self.consec = 0 + else: + self.consec += 1 return chunk def _random_map_docid(self, size): @@ -1094,10 +1375,9 @@ def __iter__(self): doclcg = self._random_map_docid(docrange) docid = doclcg + mindoc doc = self.filehandler.get(reader, docid, self.drop) - if len(doc) == 0: - continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + keep_chance = (doclen/self.min_length)**self.filter_exp + if len(doc) > 0 and torch.rand(1, generator=self.g).item() < keep_chance: n_chunks = math.ceil(doclen / self.chunksize) for j in range(n_chunks): if i == 0 and j < residual_chunks: @@ -1110,6 +1390,7 @@ def __iter__(self): self.percent_seen = ( self.docs_seen * 100 / (self._len + 1e-9) ) + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) # Advance RNG state @@ -1123,15 +1404,24 @@ def __iter__(self): newpath = os.path.join(self.datapath, shardid) path, reader = self._get_reader(path, newpath, reader) doc = self.filehandler.get(reader, docid, self.drop) - if len(doc) == 0: - continue doclen = len(doc) + 1 if self.bos is None else len(doc) + 2 - if doclen >= self.min_length: + keep_chance = (doclen/self.min_length)**self.filter_exp + if len(doc) > 0 and torch.rand(1, generator=self.g).item() < keep_chance: n_chunks = math.ceil(doclen / self.chunksize) for j in range(residual_chunks): self.chunk_index = j + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) + # Check that epoch was non-empty + assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}" + + def state_dict(self): + # Write generator state manually + self.g_state = self.g.get_state() + out = super().state_dict() + return out + def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( @@ -1142,6 +1432,9 @@ def load_state_dict(self, state_dicts, sharded_input=False): assert ( d == self.dataset ), f"Dataset mismatch: checkpoint contains {self.dataset}, expected {d}" + # Manually set generator state if it exists + if self.g_state is not None: + self.g.set_state(self.g_state) return out @@ -1203,12 +1496,12 @@ def setup(self): if not self.is_setup: _StatefulDataset.setup(self) n_logical_shards = self.total_shards + assert ( + n_logical_shards % self.worldsize == 0 + ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" logicals = list(range(n_logical_shards)) self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) self.n_logicals = n_logical_shards // self.worldsize - assert ( - len(self.logicals_owned) == self.n_logicals - ), "(world size * num workers) does not divide logical shards evenly" # Build logical shards for i in range(self.n_logicals): @@ -1225,6 +1518,9 @@ def setup(self): ) [d.setup() for d in self.data] self.n_docs_remaining = [d._len for d in self.data] + assert ( + sum(self.n_docs_remaining) > 0 + ), f"No documents detected in shard {self.rank} of {self.datapath}" self.generator = torch.Generator().manual_seed(self.rank) @@ -1232,14 +1528,16 @@ def __iter__(self): self.setup() # Grab one doc at a time in random order data = [iter(d) for d in self.data] + # Reset if we're rescaling into a prematurely finished epoch + # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] ) + if sum(self.n_docs_remaining) == 0: + self.n_docs_remaining = [d._len for d in self.data] + self.generator.manual_seed(self.rank) while True: # Sample logical shard (or load from ckp) if self.current_reader is not None: ind = self.current_reader else: - assert ( - sum(self.n_docs_remaining) > 0 - ), f"No documents detected in {self.datapath}" ind = torch.multinomial( torch.tensor(self.n_docs_remaining, dtype=torch.float), 1, @@ -1331,6 +1629,10 @@ def __init__( ] ) assert len(self.datasets) > 0, "You must specify at least one dataset" + for d in datasets: + assert os.path.exists( + os.path.join(datapath, d) + ), f"Invalid subdataset path: {os.path.join(datapath, d)}" if weights is not None: assert len(weights) == len( diff --git a/fms_fsdp/utils/train_utils.py b/fms_fsdp/utils/train_utils.py index ef421f6f..ba05636d 100644 --- a/fms_fsdp/utils/train_utils.py +++ b/fms_fsdp/utils/train_utils.py @@ -2,7 +2,6 @@ from dataclasses import asdict from functools import partial - try: import packaging.version except ImportError: @@ -30,6 +29,7 @@ def train( checkpointer, start_step, tokens_seen, + cp_degree: int = 1, ): if cfg.tracker: if cfg.tracker not in ["wandb", "aim"]: @@ -44,19 +44,22 @@ def train( except ImportError: raise ImportError("tracker is set to wandb but wandb is not installed.") if rank == 0: - print(f"--> wandb is enabled!") + print("--> Started initializing wandb", flush=True) try: wandb.init( project=project_name, dir=tracker_dir, resume="allow", id=run_id, + # mode='offline', + settings=wandb.Settings(init_timeout=3600), ) except wandb.errors.UsageError: raise ValueError( "wandb failed to init, did you pass your wandb api key via WANDB_API_KEY?" ) wandb.config = asdict(cfg) + print(f"--> wandb is enabled!", flush=True) if cfg.tracker == "aim": try: @@ -64,7 +67,7 @@ def train( except ImportError: raise ImportError("tracker is set to aim but aim is not installed.") if rank == 0: - print(f"--> aim is enabled!") + print("--> aim is enabled!") run = Run( experiment=project_name, repo=tracker_dir, @@ -89,8 +92,9 @@ def train( output = output.logits if hasattr(output, "logits") else output ce_loss = torch.nn.CrossEntropyLoss() loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) - + loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean() loss.backward() + ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() optimizer.step() scheduler.step() @@ -101,14 +105,18 @@ def train( if profiler: profiler.step() - if batch_idx % cfg.report_interval == 0: + if batch_idx % cfg.report_interval == 0 or batch_idx == start_step + 1: dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) train_loss = ddp_stats[0] / ddp_stats[2] g_norm = ddp_stats[1] / ddp_stats[2] elapsed_time = time.time() - loop_start world_size = int(os.environ["WORLD_SIZE"]) new_tokens_seen = ( - (batch_idx - start_step) * world_size * cfg.batch_size * cfg.seq_length + (batch_idx - start_step) + * world_size + * cfg.batch_size + * cfg.seq_length + // cp_degree ) if rank == 0: total_tokens_seen = tokens_seen + new_tokens_seen @@ -118,10 +126,10 @@ def train( current_step_time = (time.time() - start) / cfg.report_interval overall_step_time = elapsed_time / (batch_idx - start_step) current_throughput = int( - cfg.batch_size * cfg.seq_length / current_step_time + cfg.batch_size * cfg.seq_length / cp_degree / current_step_time ) overall_throughput = int( - cfg.batch_size * cfg.seq_length / overall_step_time + cfg.batch_size * cfg.seq_length / cp_degree / overall_step_time ) reserved_mem = torch.cuda.max_memory_reserved( device=torch.cuda.current_device() @@ -145,7 +153,8 @@ def train( "overall token per day:", int(new_tokens_seen / elapsed_time * 3600 * 24), ) - if cfg.tracker: + print(f"Total tok/step: {world_size * cfg.batch_size * cfg.seq_length}") + if cfg.tracker and batch_idx > start_step + 1: vals_to_track = { "learning rate": current_lr, "loss": current_loss, @@ -201,11 +210,11 @@ def get_mixed_precision_policy(cfg, rank): if bf16_ready: mixed_precision_policy = bfSixteen if rank == 0: - print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + print("bFloat16 enabled for mixed precision - using bfSixteen policy") else: mixed_precision_policy = fpSixteen if rank == 0: - print(f"FP16 enabled") + print("FP16 enabled") else: mixed_precision_policy = None diff --git a/fms_to_hf_mamba_transformers.py b/fms_to_hf_mamba_transformers.py new file mode 100644 index 00000000..57b3ec80 --- /dev/null +++ b/fms_to_hf_mamba_transformers.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Modified from src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py +""" + +"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed.""" + +import argparse +import json +import os +import re +from os import path +from typing import Dict, Optional, Union + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file + +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from transformers.models.bamba import BambaConfig + + +def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]: + state_dict = {} + + for orig_k, param in original_sd.items(): + k = orig_k.replace("backbone", "model") + + # for embeddings + k = k.replace("embedding", "embed_tokens") + + # for mixer + k = k.replace("mixer", "mamba") + + # for final layernorm + k = k.replace("norm_f", "final_layernorm") + + # for block layernorm + k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k) + k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k) + + # for mlp + k = k.replace("mlp.fc2", "feed_forward.down_proj") + + if "mlp.fc1" in k: + param, param2 = torch.chunk(param, 2, dim=0) + k2 = k.replace("mlp.fc1", "feed_forward.gate_proj") + state_dict[k2] = param2 + k = k.replace("mlp.fc1", "feed_forward.up_proj") + + if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or ( + "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd + ): + # then this must be a mamba + pass + else: + # for attn + # - because mixer was replaced to mamba above + k = k.replace("mamba.out_proj", "self_attn.o_proj") + if "mamba.in_proj" in k: + m, n = param.shape + d = (m - n) // 2 + param, param2, param3 = torch.split(param, [n, d, d], dim=0) + k2 = k.replace("mamba.in_proj", "self_attn.k_proj") + state_dict[k2] = param2 + k2 = k.replace("mamba.in_proj", "self_attn.v_proj") + state_dict[k2] = param3 + k = k.replace("mamba.in_proj", "self_attn.q_proj") + + state_dict[k] = param + + return state_dict + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_ssm_config_to_hf_config( + config_ssm: Dict, + **kwargs, +) -> BambaConfig: + """Convert a config from mamba_ssm to a BambaConfig from here.""" + hf_config: BambaConfig = BambaConfig(**kwargs) + + hf_config.architectures = ["BambaForCausalLM"] + + # Set important values from config and recalculate other resulting entries + hf_config.hidden_size = config_ssm["d_model"] + hf_config.intermediate_size = config_ssm["d_intermediate"] + hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head + hf_config.num_hidden_layers = config_ssm["n_layer"] + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] + + # currently this script assumes config_ssm belongs to v2 + if config_ssm["ssm_cfg"].get("layer") != "Mamba2": + raise ValueError("Conversion script only supports Mamba2") + + # Set attention values + attn_cfg = config_ssm.get("attn_cfg") + if attn_cfg: + assert attn_cfg["causal"], "Only support non-causal attention." + assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias." + assert not attn_cfg["out_proj_bias"], "Only support no out bias." + hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"] + hf_config.num_attention_heads = attn_cfg["num_heads"] + hf_config.num_key_value_heads = attn_cfg["num_heads_kv"] + hf_config.rope_theta = attn_cfg["rotary_emb_base"] + + attention_layer_indices = config_ssm.get("attn_layer_idx") + if attention_layer_indices: + hf_config.attn_layer_indices = attention_layer_indices + + # Padded vocab size, mostly of 16 but 32 is also very common in different models + vocab_size = config_ssm["vocab_size"] + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] + if (vocab_size % pad_vocab_size_multiple) != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + hf_config.vocab_size = vocab_size + + return hf_config + + +def save_single_safetensor( + state_dict: Dict, + save_directory: str, + metadata: Dict, +): + save_file( + state_dict, + os.path.join(save_directory, SAFE_WEIGHTS_NAME), + metadata, + ) + + +def save_sharded_safetensors( + state_dict: Dict, + save_directory: str, + metadata: Dict, + max_shard_size: Union[int, str] = "5GB", +): + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + mamba_ssm_checkpoint_path: str, + precision: str, + output_dir: str, + tokenizer_path: Optional[str] = None, + save_model: Union[bool, str] = True, +) -> None: + # load tokenizer if provided, this will be used to set the + # token_ids in the config file + token_ids = {} + if tokenizer_path: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + for key in [ + "bos_token_id", + "eos_token_id", + "pad_token_id", + ]: + id = getattr(tokenizer, key, None) + if id: + token_ids[key] = id + + # there are some configs unsettable by mamba_ssn config, so + # if there are changes from the defaults, have to pass them into + # the function + unsettables = { + "mamba_d_head": 64, + "mamba_d_state": 128, + "mamba_n_groups": 1, + "rms_norm_eps": 1e-5, + } + + # Load and save config based on name + config_path = path.join(mamba_ssm_checkpoint_path, "config.json") + with open(config_path, "r", encoding="utf-8") as json_file: + config = json.load(json_file) + + # convert the config + hf_config = convert_ssm_config_to_hf_config( + config_ssm=config, + **token_ids, + **unsettables, + ) + hf_config.save_pretrained(output_dir) + + # Load state dict of the original model and transfer to hf model + state_dict = torch.load( + path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"), + map_location="cpu", + weights_only=True, + ) + # FIXME: allow other parameters to pass in + state_dict = convert_state_dict_from_mamba_ssm(state_dict) + + # Save new model to pytorch_dump_path + dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) + + save_file_fn = None + if isinstance(save_model, bool) and save_model: + save_file_fn = save_single_safetensor + elif isinstance(save_model, str) and save_model == "sharded": + save_file_fn = save_sharded_safetensors + + if save_file_fn: + save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba_ssm_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="fp16", + required=True, + choices=("fp32", "fp16", "bf16"), + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." + ) + parser.add_argument( + "-t", + "--tokenizer_model_path", + type=str, + default=None, + required=False, + help="Path to a the tokenizer file.", + ) + args = parser.parse_args() + + convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + args.mamba_ssm_checkpoint_directory, + args.precision, + args.output_dir, + args.tokenizer_model_path, + ) + diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..88fb9bbc 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -4,12 +4,15 @@ import fire import torch +import torch.nn as nn import torch.optim as optim from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from mamba_ssm.modules.block import Block from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import CustomPolicy from torch.optim.lr_scheduler import LambdaLR from fms_fsdp import config @@ -51,20 +54,82 @@ def main(**kwargs): Path.home(), ".triton", "cache", str(local_rank) ) - # get policy + # get policy. NOTE: @goon - overriding {wrapping_policy, param_init_fn} below block = Block ( mixed_precision_policy, - wrapping_policy, + _, sharding_strategy_policy, apply_selective_ac, - param_init_fn, + _, # NOTE: @goon - We'll override param_init_fn for mamba below ) = get_policies(cfg, rank, block) + if cfg.low_cpu_fsdp: + # NOTE: @goon - the params will be junk after using this. Only intended to be used in + # conjunction with loading proper weights from a checkpoint. + def param_init_fn(module): + module.to_empty(device=torch.cuda.current_device()) + else: + param_init_fn = None + + # Meshes for FSDP and CP. NOTE: @goon - Getting hangs and/or OOMs if I don't explicitly specify + # the FSDP mesh when using 4+ nodes with HSDP + in-node-CP. + def get_1D_world_mesh(world_size: int) -> DeviceMesh: + mesh = dist.device_mesh.init_device_mesh("cuda", (world_size,)) + return mesh + + def get_2D_world_mesh(world_size: int) -> DeviceMesh: + num_gpu_per_node = torch.cuda.device_count() + assert world_size % num_gpu_per_node == 0 + mesh = dist.device_mesh.init_device_mesh( + "cuda", + (world_size // num_gpu_per_node, num_gpu_per_node), + mesh_dim_names=("inter_node", "intra_node"), + ) + return mesh + + requires_2d_mesh = (cfg.sharding_strategy == "hsdp") or ( + cfg.cp and not cfg.cp_over_world + ) + if requires_2d_mesh: + mesh = get_2D_world_mesh(world_size) + fsdp_mesh = mesh + cp_mesh = mesh["intra_node"] if cfg.cp else None + else: + mesh = get_1D_world_mesh(world_size) + fsdp_mesh = mesh + cp_mesh = mesh if cfg.cp else None + + if cfg.cp: + cp_degree = world_size if cfg.cp_over_world else torch.cuda.device_count() + else: + cp_degree = 1 + + dp_degree = world_size // cp_degree # get model config_data = get_model_config(cfg.model_variant) mamba_config = MambaConfig(**config_data) - model = MambaLMHeadModel(mamba_config) + + if cfg.low_cpu_fsdp: + with torch.device("meta"): + model = MambaLMHeadModel( + mamba_config, + cp_mesh=cp_mesh if cfg.cp else None, + cp_mamba_impl=cfg.cp_mamba_impl if cfg.cp else None, + cp_attn_impl=cfg.cp_attn_impl if cfg.cp else None, + ) + else: + model = MambaLMHeadModel( + mamba_config, + cp_mesh=cp_mesh if cfg.cp else None, + cp_mamba_impl=cfg.cp_mamba_impl if cfg.cp else None, + cp_attn_impl=cfg.cp_attn_impl if cfg.cp else None, + ) + + def lambda_fn(module: nn.Module): + return isinstance(module, (Block, nn.Embedding)) or module is model.lm_head + + wrapping_policy = CustomPolicy(lambda_fn) if rank == 0: total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -74,7 +139,7 @@ def main(**kwargs): if rank == 0: print("Constructing datasets...") if not cfg.use_dummy_dataset: - train_loader = get_data_loader(cfg, rank, world_size) + train_loader = get_data_loader(cfg, rank, world_size, dp_degree) else: train_loader = get_dummy_loader(cfg, rank, world_size) if rank == 0: @@ -83,32 +148,60 @@ def main(**kwargs): # FSDP model = FSDP( model, + device_mesh=fsdp_mesh, auto_wrap_policy=wrapping_policy, mixed_precision=mixed_precision_policy, sharding_strategy=sharding_strategy_policy, - use_orig_params=cfg.use_torch_compile, + use_orig_params=True, device_id=torch.cuda.current_device(), limit_all_gathers=True, param_init_fn=param_init_fn, ) + if rank == 0: + print(model) # fsdp activation checkpointing if cfg.fsdp_activation_checkpointing: if rank == 0: - print(f"--> applying FSDP activation checkpointing...") + print("--> applying FSDP activation checkpointing...") apply_selective_ac(model, p=cfg.selective_checkpointing) # torch compile if cfg.use_torch_compile: if rank == 0: - print(f"--> enabling torch compile...") + print("--> enabling torch compile...") # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here torch._dynamo.config.accumulated_cache_size_limit = 128 model = torch.compile(model) # Optimizer + # optimizer = optim.AdamW( + # model.parameters(), + # lr=cfg.learning_rate, + # betas=(0.9, 0.95), + # weight_decay=0.1, + # ) + params_with_decay = [] + params_without_decay = [] + for name, param in model.named_parameters(): + suff = name.split('.')[-1] + if 'A_log' in suff or 'D' in suff or 'dt_bias' in suff: + params_without_decay.append(param) + else: + params_with_decay.append(param) optimizer = optim.AdamW( - model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 + [ + { + "params": params_with_decay, + "weight_decay": 0.1, + }, + { + "params": params_without_decay, + "weight_decay": 0., + }, + ], + betas = (0.9, 0.95), + lr = cfg.learning_rate, ) # optionally load from checkpoint (when continue pretraining) @@ -131,19 +224,36 @@ def main(**kwargs): g["initial_lr"] = cfg.learning_rate # LR schedule + warmup_interval = min(2000, cfg.num_steps // 20) + warmup = lambda x: 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2 # linear decay for annealing if cfg.training_stage == "annealing": - schedule = lambda x: 1 - x / cfg.num_steps - else: + warmup_interval = 1000 + schedule = lambda x: x / warmup_interval if x < warmup_interval else 1 - (x - warmup_interval) / (cfg.num_steps - warmup_interval) + elif cfg.training_stage == "cosine": # cosine decay - warmup_interval = min(2000, cfg.num_steps // 20) schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + warmup(x), 0.1 + 0.5 * (1 - 0.1) * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) + elif cfg.training_stage == "constant": + warmup_interval = 2000 + schedule = lambda x: (min(x, warmup_interval) / warmup_interval) + elif cfg.training_stage == "linear_to_constant": + linear_steps = 25000 + start_lr = 2e-4 + end_lr = 2e-4 + schedule = lambda x: (start_lr + (end_lr - start_lr) * min(x - start_step, linear_steps) / linear_steps) / cfg.learning_rate + elif cfg.training_stage == "annealing_with_specified_decay_steps": + warmup_interval = 2000 + total_decay_steps = 25000 + schedule = lambda x: (x - start_step) / warmup_interval if x - start_step < warmup_interval else max(0.0, 1 - (x - start_step - warmup_interval) / total_decay_steps) + else: + schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 + scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) @@ -165,6 +275,7 @@ def main(**kwargs): checkpointer, start_step, tokens_seen, + cp_degree, ) checkpointer.save_single_file(cfg.num_steps, model)