From 655030b1153340a9b813e4049ab08b7d566e4aef Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 8 Jan 2025 10:25:17 +0800 Subject: [PATCH 1/5] [Compress] Add Compress Function for FlagScale (#308) Support the compression functionality in FlagScale by implementing the following components: - YAML-based model compression initiation - Compressor: The central entry point for compression - [ ] [TODO] Combined Algorithm: Automatic combination of compression algorithm configurations - Adapter: Adapts to different compression backends - Algo: Various compression algorithms --- examples/emu/conf/compress/compress_emu3.yaml | 22 ++ .../conf/compress/compress_emu3_w4a16.yaml | 24 ++ examples/emu/conf/compress/emu3_model.yaml | 19 ++ examples/emu/conf/config_compress.yaml | 23 ++ .../conf/compress/compress_llava_ov.yaml | 19 ++ .../compress/compress_llava_ov_w4a16.yaml | 21 ++ .../conf/compress/llava_ov_model.yaml | 19 ++ .../llava_onevision/conf/config_compress.yaml | 23 ++ flagscale/compress/__init__.py | 0 flagscale/compress/adapter.py | 150 ++++++++++ flagscale/compress/algo/__init__.py | 0 flagscale/compress/algo/algo_base.py | 13 + flagscale/compress/combined_algo.py | 8 + flagscale/compress/compressor.py | 93 ++++++ flagscale/compress/compressor_emu3.py | 62 ++++ flagscale/compress/compressor_llava_ov.py | 120 ++++++++ flagscale/runner/runner_compress.py | 280 ++++++++++++++++++ requirements/compress/requirements-dev.txt | 2 + requirements/requirements-dev.txt | 3 + run.py | 7 + tests/unit_tests/compressor/test_adapter.py | 60 ++++ .../unit_tests/compressor/test_compressor.py | 6 + tests/unit_tests/compressor/test_config.yaml | 48 +++ 23 files changed, 1022 insertions(+) create mode 100644 examples/emu/conf/compress/compress_emu3.yaml create mode 100644 examples/emu/conf/compress/compress_emu3_w4a16.yaml create mode 100644 examples/emu/conf/compress/emu3_model.yaml create mode 100644 examples/emu/conf/config_compress.yaml create mode 100644 examples/llava_onevision/conf/compress/compress_llava_ov.yaml create mode 100644 examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml create mode 100644 examples/llava_onevision/conf/compress/llava_ov_model.yaml create mode 100644 examples/llava_onevision/conf/config_compress.yaml create mode 100644 flagscale/compress/__init__.py create mode 100644 flagscale/compress/adapter.py create mode 100644 flagscale/compress/algo/__init__.py create mode 100644 flagscale/compress/algo/algo_base.py create mode 100644 flagscale/compress/combined_algo.py create mode 100644 flagscale/compress/compressor.py create mode 100644 flagscale/compress/compressor_emu3.py create mode 100644 flagscale/compress/compressor_llava_ov.py create mode 100644 flagscale/runner/runner_compress.py create mode 100644 requirements/compress/requirements-dev.txt create mode 100644 tests/unit_tests/compressor/test_adapter.py create mode 100644 tests/unit_tests/compressor/test_compressor.py create mode 100644 tests/unit_tests/compressor/test_config.yaml diff --git a/examples/emu/conf/compress/compress_emu3.yaml b/examples/emu/conf/compress/compress_emu3.yaml new file mode 100644 index 000000000..13e18481a --- /dev/null +++ b/examples/emu/conf/compress/compress_emu3.yaml @@ -0,0 +1,22 @@ +defaults: + - emu3_model + - _self_ + +data: + data_path: null + max_calib_data: null + max_seq_len: null + tokenzier_args: + tokenizer_path: BAAI/Emu3-Gen/ + special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt + trust_remote_code: true + +compress_args: + quantization: + - algo: + targets: 
["Linear"] + ignore: ["lm_head"] + scheme: FP8_DYNAMIC + + + diff --git a/examples/emu/conf/compress/compress_emu3_w4a16.yaml b/examples/emu/conf/compress/compress_emu3_w4a16.yaml new file mode 100644 index 000000000..534727f67 --- /dev/null +++ b/examples/emu/conf/compress/compress_emu3_w4a16.yaml @@ -0,0 +1,24 @@ +defaults: + - emu3_model + - _self_ + +data: + data_path: + num_calibration_samples: 16 + max_seq_length: 9216 + tokenzier_args: + tokenizer_path: BAAI/Emu3-Gen/ + special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt + trust_remote_code: true + +compress_args: + quantization: + - algo: + gptq: + blocksize: 128 + percdamp: 0.01 + ignore: ["lm_head"] + targets: ["Linear"] + scheme: W4A16 + + diff --git a/examples/emu/conf/compress/emu3_model.yaml b/examples/emu/conf/compress/emu3_model.yaml new file mode 100644 index 000000000..16751891d --- /dev/null +++ b/examples/emu/conf/compress/emu3_model.yaml @@ -0,0 +1,19 @@ +system: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + use_flash_attn: True + sequence_parallel: True + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: "compress-emu3-7B" + wandb_exp_name: "compress-test-7B" + save_dir: + + +model: + model_cls: AutoModelForCausalLM + model_path: BAAI/Emu3-Gen/ + device_map: cuda:0 + trust_remote_code: true + torch_dtype: bfloat16 diff --git a/examples/emu/conf/config_compress.yaml b/examples/emu/conf/config_compress.yaml new file mode 100644 index 000000000..c4eff9695 --- /dev/null +++ b/examples/emu/conf/config_compress.yaml @@ -0,0 +1,23 @@ +defaults: + - _self_ + - compress: compress_emu3 + +experiment: + exp_name: emu3 + exp_dir: outputs/${experiment.exp_name} + task: + type: compress + entrypoint: flagscale/compress/compressor_emu3.py + runner: + hostfile: null + cmds: + before_start: source activate flagscale + envs: + CUDA_VISIBLE_DEVICES: 0 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/examples/llava_onevision/conf/compress/compress_llava_ov.yaml b/examples/llava_onevision/conf/compress/compress_llava_ov.yaml new file mode 100644 index 000000000..9ee62de68 --- /dev/null +++ b/examples/llava_onevision/conf/compress/compress_llava_ov.yaml @@ -0,0 +1,19 @@ +defaults: + - llava_ov_model + - _self_ + +data: + data_path: null + max_calib_data: null + max_seq_len: null + tokenzier_args: null + +compress_args: + quantization: + - algo: + targets: ["Linear"] + ignore: ["re:.*vision_model*", "re:multi_modal_projector*", "re:.*lm_head"] + scheme: FP8_DYNAMIC + + + diff --git a/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml b/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml new file mode 100644 index 000000000..afdb92144 --- /dev/null +++ b/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml @@ -0,0 +1,21 @@ +defaults: + - llava_ov_model + - _self_ + +data: + data_path: + num_calibration_samples: 16 + max_seq_length: 8192 + tokenzier_args: null + +compress_args: + quantization: + - algo: + gptq: + blocksize: 128 + percdamp: 0.01 + ignore: ["re:.*vision_model*", "re:.*mm_projector*", "re:.*lm_head"] + targets: [Linear] + scheme: W4A16 + + diff --git a/examples/llava_onevision/conf/compress/llava_ov_model.yaml b/examples/llava_onevision/conf/compress/llava_ov_model.yaml new file mode 100644 index 000000000..ff1f2c1c3 --- /dev/null +++ b/examples/llava_onevision/conf/compress/llava_ov_model.yaml @@ -0,0 +1,19 @@ +system: + tensor_model_parallel_size: 1 + 
pipeline_model_parallel_size: 1 + use_flash_attn: True + sequence_parallel: True + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: "compress-llavaov-7B" + wandb_exp_name: "compress-test-7B" + save_dir: + + +model: + model_path: + device_map: cuda:0 + trust_remote_code: true + torch_dtype: bfloat16 + diff --git a/examples/llava_onevision/conf/config_compress.yaml b/examples/llava_onevision/conf/config_compress.yaml new file mode 100644 index 000000000..0cd5a76b4 --- /dev/null +++ b/examples/llava_onevision/conf/config_compress.yaml @@ -0,0 +1,23 @@ +defaults: + - _self_ + - compress: compress_llava_ov_w4a16 + +experiment: + exp_name: llava_ov + exp_dir: outputs/${experiment.exp_name} + task: + type: compress + entrypoint: flagscale/compress/compressor_llava_ov.py + runner: + hostfile: null + cmds: + before_start: source activate flagscale + envs: + CUDA_VISIBLE_DEVICES: 0 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/flagscale/compress/__init__.py b/flagscale/compress/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flagscale/compress/adapter.py b/flagscale/compress/adapter.py new file mode 100644 index 000000000..7414159b4 --- /dev/null +++ b/flagscale/compress/adapter.py @@ -0,0 +1,150 @@ +import re +import os + +import torch +from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper +from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor +from llmcompressor.utils.fsdp.context import fix_fsdp_module_name +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from compressed_tensors.quantization import ( + QuantizationScheme, + QuantizationStatus, + QuantizationConfig, + is_preset_scheme, + preset_name_to_scheme, + apply_quantization_config + ) +from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches +from llmcompressor.modifiers.quantization.gptq.utils import get_output_error +from llmcompressor.utils.helpers import DisableKVCache +from compressed_tensors.quantization import ( + QuantizationScheme, + disable_quantization, + enable_quantization, +) +from llmcompressor.modifiers.quantization.calibration import initialize_observer, update_weight_zp_scale, freeze_module_quantization +from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained + +from flagscale.runner.runner_utils import logger + +__all__ = ["LLMCompressorAdapter"] + +QUANT_MAPPING_NAMES = { + "gptq": GPTQWrapper + } + +class LLMCompressorAdapter: + def __init__(self, model, scheme, targets, algo=None, ignore=None, dataset=None, num_calibration_steps=384): + self.model = model + modify_save_pretrained(self.model) + if algo is not None: + assert len(algo) == 1 + for k, v in algo.items(): + self.algo = k + self.algo_args = v + else: + self.algo = algo + self.scheme = scheme + self.ignore = ignore + self.targets = targets + self.wrapper_cls = None + self.layer_compressors_ = [] + self.num_calibration_steps = num_calibration_steps + self.dataset = dataset + + if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in list(QUANT_MAPPING_NAMES.keys()): + self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo] if self.algo is not None else None + quant_config = self.init_quant_config() + + ### find ignore and target to quant, initialize module for quant + ### overwrite forward if quantization_enabled is Tue + apply_quantization_config(self.model, quant_config) + if 
self.wrapper_cls is None: + self.preprocess_weight() + else: + self.init_compressor() + if self.dataset is not None: + self.run_blockwise_calib_forward() + self.model.apply(freeze_module_quantization) + + + def init_quant_config(self): + if self.scheme is not None: + # takes precedence over config_groups + if isinstance(self.scheme, str) and is_preset_scheme(self.scheme): + # attach targets to scheme + self.scheme = {self.scheme: self.targets} + + self.config_groups = {} + for idx, key in enumerate(self.scheme.keys()): + if is_preset_scheme(key): + scheme = preset_name_to_scheme(key, self.scheme[key]) + else: + scheme = QuantizationScheme.model_validate( + {"targets": self.scheme[key], **self.scheme} + ) + + group_name = f"group_{idx}" + self.config_groups[group_name] = scheme + + if self.config_groups is None or len(self.config_groups) == 0: + default_quant_scheme = QuantizationScheme(targets=self.targets) + self.config_groups = {"group_0": default_quant_scheme} + logger.info( + f"No config groups were provided, using default {self.config_groups}" + ) + + return QuantizationConfig( + config_groups=self.config_groups, + kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now + quantization_status=QuantizationStatus.INITIALIZED, + ignore=self.ignore, + ) + + def init_compressor(self): + for name, layer in self.model.named_modules(): + name = fix_fsdp_module_name(name) + if name is None: + continue + try: + idx = int(name.split(".")[-1]) + except: + continue + + if matches := find_name_or_class_matches(name, layer, self.ignore): + continue + logger.info(f"prepare compressor for layer {name}") + compressor = LayerCompressor(self.wrapper_cls, self.model, layer, idx, name, self.algo_args) + self.layer_compressors_.append(compressor) + self.layer_compressors_[0].set_early_stop() + + def preprocess_weight(self): + for idx, (name, layer) in enumerate(self.model.named_modules()): + layer.apply(lambda module: initialize_observer(layer, base_name="weight")) + self.model.apply(update_weight_zp_scale) + + def add_hook(self): + pass + + @torch.no_grad() + def run_blockwise_calib_forward(self): + logger.info(f"start calibration") + self.model.apply(disable_quantization) + with DisableKVCache(self.model): + intermediates = run_calibration_forward( + self.model, self.dataset, num_calibration_steps=self.num_calibration_steps, mask_padding=False + ) + self.layer_compressors_[0].clear_early_stop() + + for idx, layer_compressor in enumerate(self.layer_compressors_): + logger.info(f"start calibration layer {layer_compressor.name}") + layer_compressor.pre_compress() + unquantized_outputs = layer_compressor.calibrate_layer(intermediates) + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + quantized_outputs = layer_compressor.calibrate_layer(intermediates) + error = get_output_error(unquantized_outputs, quantized_outputs) + logger.info(f"Mean output error from quantization: {error:.3f}") + intermediates = quantized_outputs + self.model.apply(enable_quantization) \ No newline at end of file diff --git a/flagscale/compress/algo/__init__.py b/flagscale/compress/algo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flagscale/compress/algo/algo_base.py b/flagscale/compress/algo/algo_base.py new file mode 100644 index 000000000..f4bf4607a --- /dev/null +++ b/flagscale/compress/algo/algo_base.py @@ -0,0 +1,13 @@ +import os + +class BaseALGO: + def __init__(self, name): + self.name = name + self._observer = False + 
self._compress = False + + def preprocess_weight(self): + raise NotImplementedError + + def add_batch(self): + raise NotImplementedError diff --git a/flagscale/compress/combined_algo.py b/flagscale/compress/combined_algo.py new file mode 100644 index 000000000..69246b208 --- /dev/null +++ b/flagscale/compress/combined_algo.py @@ -0,0 +1,8 @@ +import os +import sys +import copy +import pdb + +def prepare_compress_methods(compress_cfg): + recipes = copy.deepcopy(compress_cfg) + return recipes \ No newline at end of file diff --git a/flagscale/compress/compressor.py b/flagscale/compress/compressor.py new file mode 100644 index 000000000..9ef35bd0b --- /dev/null +++ b/flagscale/compress/compressor.py @@ -0,0 +1,93 @@ +import os +import sys +import argparse +import yaml +import shutil +from omegaconf import OmegaConf + +import torch +from transformers import * + +from flagscale.compress.combined_algo import prepare_compress_methods +from flagscale.compress.adapter import LLMCompressorAdapter + +_g_ignore_fields = ["experiment", "action"] + +def prepare_config(config_path): + # Open the YAML file and convert it into a dictionary + with open(config_path, "r") as f: + yaml_dict = yaml.safe_load(f) + + # Extract valid config + for key in _g_ignore_fields: + yaml_dict.pop(key) + new_yaml_dict = {} + for k, v in yaml_dict.items(): + assert isinstance( + v, dict + ), f"Expected a dictionary for key {k}, but got {v} instead" + new_yaml_dict.update(v) + config = OmegaConf.create(new_yaml_dict) + return config + +def copy_rest_file(src_path, dst_path): + from huggingface_hub import hf_hub_download + from transformers import TRANSFORMERS_CACHE + from transformers.utils import http_user_agent + + if not os.path.exists(src_path): + user_agent = http_user_agent() + config_file_path = hf_hub_download( + repo_id=src_path, + filename="config.json", + cache_dir=TRANSFORMERS_CACHE, + force_download=False, + user_agent=user_agent, + ) + src_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1]) + + dst_path_files = os.listdir(dst_path) + for filename in os.listdir(src_path): + if not filename.endswith(".safetensors") and filename not in dst_path_files: + full_file_name = os.path.join(src_path, filename) + if (not filename.endswith(".md")) and os.path.isfile(full_file_name): + shutil.copy(full_file_name, dst_path) + elif os.path.isdir(full_file_name): + shutil.copytree(full_file_name, os.path.join(dst_path, filename)) + +def compress(cfg, model=None, dataset=None): + tokenizer = None + model_path = cfg.model.pop("model_path") + if cfg.data.tokenzier_args is not None: + tokenizer = AutoTokenizer.from_pretrained(cfg.data.tokenzier_args.pop("tokenizer_path"), **cfg.data.tokenzier_args) + if model is None: + model_cls = eval(cfg.model.pop("model_cls")) + model = model_cls.from_pretrained(model_path, **cfg.model) + assert isinstance(model, torch.nn.Module), f"model type {type(model)} error, please check it" + compress_args = cfg.compress_args + recipes = prepare_compress_methods(compress_args) + for method, recipe in recipes.items(): + for algo_args in recipe: + algo_args = OmegaConf.to_container(algo_args) + algo_args["dataset"] = dataset + algo_args["num_calibration_steps"] = cfg.data.get("max_seq_length", 384) + adapter = LLMCompressorAdapter(model=model, **algo_args) + ### modify model inplace + model = adapter.model + + # oneshot(model=model, dataset=dataset, recipe=recipe, tokenizer=tokenizer, output_dir=cfg.system.save_dir, max_seq_length=cfg.data.get("max_seq_length", 384), 
num_calibration_samples=cfg.data.get("num_calibration_samples", 512), splits="calibration") + model.save_pretrained(cfg.system.save_dir, save_compressed=True) + copy_rest_file(model_path, cfg.system.save_dir) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-path", + type=str, + required=True, + help="Path to the configuration YAML file", + ) + args = parser.parse_args() + cfg = prepare_config(args.config_path) + + compress(cfg) diff --git a/flagscale/compress/compressor_emu3.py b/flagscale/compress/compressor_emu3.py new file mode 100644 index 000000000..f8b6c9c8f --- /dev/null +++ b/flagscale/compress/compressor_emu3.py @@ -0,0 +1,62 @@ +import os +import sys +import argparse +import random +import yaml +import shutil +from omegaconf import OmegaConf +import torch +from megatron.core.datasets.indexed_dataset import IndexedDataset +from torch.utils.data import Dataset + +from flagscale.compress.compressor import compress, prepare_config + +class CusDataset(Dataset): + def __init__(self, ds): + self.ds = ds + self.indices = list(range(len(ds))) + self.column_names = {"input_ids": self} + self.calibration = self + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + index = self.indices[idx] + cur_ds = self.ds[index] + slice_idx = cur_ds.tolist().index(151643) ### mask padding + return {"input_ids": torch.tensor(cur_ds[:slice_idx]).unsqueeze(0), "attention_mask": torch.ones(slice_idx, dtype=torch.int64)} + + def shuffle(self): + random.shuffle(self.indices) + return self + + def select(self, samples_idx): + indices = [] + for idx in samples_idx: + indices.append(self.indices[idx]) + self.indices = indices + return self + + +def prepare_dataset(cfg): + print(cfg) + if cfg.data.data_path is None: + return None + ds = IndexedDataset(cfg.data.data_path, mmap=True) + dataset = CusDataset(ds) + return dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-path", + type=str, + required=True, + help="Path to the configuration YAML file", + ) + args = parser.parse_args() + cfg = prepare_config(args.config_path) + dataset = prepare_dataset(cfg) + compress(cfg, dataset=dataset) diff --git a/flagscale/compress/compressor_llava_ov.py b/flagscale/compress/compressor_llava_ov.py new file mode 100644 index 000000000..5bcb6a84f --- /dev/null +++ b/flagscale/compress/compressor_llava_ov.py @@ -0,0 +1,120 @@ +import os +import sys +import argparse +import json +import random +import yaml +import shutil +import re +import ast +import warnings +from omegaconf import OmegaConf +import copy +import torch +from torch.utils.data import Dataset + +from flagscale.compress.compressor import compress, prepare_config +import transformers +from llava.model.builder import load_pretrained_model +from llava.train.train import make_supervised_data_module, DataArguments, LLaVATrainer +from adapter import LLMCompressorAdapter + +warnings.filterwarnings("ignore") + +class CusDataset(Dataset): + def __init__(self, ds): + self.ds = ds + self.indices = list(range(len(ds))) + self.column_names = self + self.calibration = self + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + index = self.indices[idx] + cur_ds = self.ds[index] + tmp_ = cur_ds.pop("image") + # cur_ds["modalities"] = [tmp_[0][2]] + cur_ds["image_sizes"] = [torch.tensor(tmp_[0][1])] + cur_ds["images"] = [tmp_[0][0].unsqueeze(0).to(torch.float16)] + del tmp_ + cur_ds.pop("id") + cur_ds["input_ids"] = 
cur_ds["input_ids"].unsqueeze(0) + cur_ds["labels"] = cur_ds["labels"].unsqueeze(0) + return cur_ds + + def shuffle(self): + random.shuffle(self.indices) + return self + + def select(self, samples_idx): + indices = [] + for idx in samples_idx: + indices.append(self.indices[idx]) + self.indices = indices + return self + + +def prepare_model(cfg): + origin_config = json.load(open(os.path.join(cfg.model["model_path"], "config.json"), "r")) + origin_vocab_size = origin_config["vocab_size"] + tokenizer, model, _, _ = load_pretrained_model(cfg.model["model_path"], None, cfg.model["model_name"], device_map=cfg.model["device_map"], attn_implementation="sdpa", multimodal=True) + model.resize_token_embeddings(origin_vocab_size) + return model, tokenizer + +def prepare_dataset(cfg, model, tokenizer): + if cfg.data.data_path is None: + return None + new_data_args = copy.deepcopy(cfg.data) + new_data_args.pop("num_calibration_samples") + new_data_args.pop("max_seq_length") + new_data_args.pop("tokenzier_args") + + parser = transformers.HfArgumentParser(DataArguments) + data_args = parser.parse_dict(new_data_args)[0] + vision_tower = model.get_vision_tower() + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + data_args.image_folder = "/" + data_args.mm_use_im_start_end = model.config.mm_use_im_start_end + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + if data_args.image_grid_pinpoints is not None: + if isinstance(data_args.image_grid_pinpoints, str) and "x" in data_args.image_grid_pinpoints: + try: + patch_size = data_args.image_processor.size[0] + except Exception as e: + patch_size = data_args.image_processor.size["shortest_edge"] + + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + data_args.image_grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + elif isinstance(data_args.image_grid_pinpoints, str): + data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) + dataset = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + ds = CusDataset(dataset["train_dataset"]) + return ds + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-path", + type=str, + required=True, + help="Path to the configuration YAML file", + ) + args = parser.parse_args() + cfg = prepare_config(args.config_path) + model, tokenizer = prepare_model(cfg) + dataset = prepare_dataset(cfg, model, tokenizer) + compress(cfg, dataset=dataset, model=model) diff --git a/flagscale/runner/runner_compress.py b/flagscale/runner/runner_compress.py new file mode 100644 index 000000000..10fa83d5b --- /dev/null +++ b/flagscale/runner/runner_compress.py @@ -0,0 +1,280 @@ +import os +import shlex +import time +from datetime import datetime + +import hydra +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf + +from flagscale.runner.runner_base import JobStatus, RunnerBase +from 
flagscale.runner.runner_utils import ( + add_decive_extra_config, + flatten_dict_to_args, + get_free_port, + get_host_name_or_ip, + get_nnodes, + get_nproc_per_node, + logger, + parse_hostfile, + run_local_command, + run_scp_command, + run_ssh_command, +) + + +def _get_args_llmcompressor(config: DictConfig): + # see the following link for more details + # https://github.com/facebookresearch/hydra/discussions/2750 + # OmegaConf.set_struct(config, False) + + hydra_config = HydraConfig.get() + output_dir = hydra_config.runtime.output_dir + output_subdir = hydra_config.output_subdir + config_path = os.path.join(output_dir, f"{output_subdir}/config.yaml") + config_path = hydra.utils.to_absolute_path(config_path) + + args = [] + args.append(f"--config-path={config_path}") + + return args + + +def _update_config_compress(config: DictConfig): + exp_dir = os.path.abspath(config.experiment.exp_dir) + if not os.path.isdir(exp_dir): + os.makedirs(exp_dir) + assert os.path.isdir(exp_dir), f"Directory {exp_dir} does not exist." + + OmegaConf.set_struct(config, False) + config = config.compress.system + + wandb_dir = ( + os.path.abspath(config.logging.wandb_save_dir) + if config.logging.get("wandb_save_dir", None) + else os.path.join(exp_dir, "wandb") + ) + tensorboard_dir = ( + os.path.abspath(config.logging.tensorboard_dir) + if config.logging.get("tensorboard_dir", None) + else os.path.join(exp_dir, "tensorboard") + ) + log_dir = ( + os.path.abspath(config.logging.log_dir) + if config.logging.get("log_dir", None) + else os.path.join(exp_dir, "logs") + ) + + log_dir = os.path.join(exp_dir, f"compress_logs") + scripts_dir = os.path.join(log_dir, "scripts") + pids_dir = os.path.join(log_dir, "pids") + + config.logging.log_dir = log_dir + config.logging.scripts_dir = scripts_dir + config.logging.pids_dir = pids_dir + config.logging.tensorboard_dir = tensorboard_dir + config.logging.wandb_save_dir = wandb_dir + + OmegaConf.set_struct(config, True) + + +def _generate_run_script_compress( + config, host, node_rank, cmd, background=True, with_test=False +): + system_config = config.compress.system + logging_config = config.compress.system.logging + + no_shared_fs = config.experiment.runner.get("no_shared_fs", False) + if no_shared_fs: + host_output_file = os.path.join(logging_config.log_dir, f"host.output") + else: + host_output_file = os.path.join( + logging_config.log_dir, f"host_{node_rank}_{host}.output" + ) + host_run_script_file = os.path.join( + logging_config.scripts_dir, f"host_{node_rank}_{host}_run.sh" + ) + host_pid_file = os.path.join( + logging_config.pids_dir, f"host_{node_rank}_{host}.pid" + ) + + os.makedirs(logging_config.scripts_dir, exist_ok=True) + + root_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + compress_dir = os.path.join(root_dir, "compress") + ### set megatron dir for dataset + megtron_dir = os.path.join(root_dir, "megatron") + cmds_config = config.experiment.get("cmds", None) + if cmds_config: + before_start = cmds_config.get("before_start", "") + else: + before_start = "" + with open(host_run_script_file, "w") as f: + f.write("#!/bin/bash\n\n") + f.write(f"{before_start}\n") + f.write(f"mkdir -p {system_config.save_dir}\n") + f.write(f"mkdir -p {system_config.logging.log_dir}\n") + f.write(f"mkdir -p {system_config.logging.pids_dir}\n") + f.write(f"mkdir -p {system_config.logging.tensorboard_dir}\n") + f.write(f"mkdir -p {system_config.logging.wandb_save_dir}\n") + f.write(f"\n") + f.write(f"cd {root_dir}\n") + f.write(f"\n") + 
f.write(f"export PYTHONPATH={compress_dir}:{megtron_dir}:{root_dir}\n") + f.write(f"\n") + f.write(f'cmd="{cmd}"\n') + f.write(f"\n") + if with_test: + f.write(f'bash -c "$cmd; sync" \n') + else: + # TODO: need a option to control whether to append or overwrite the output file + # Now, it always appends to the output file + if background: + f.write( + f'nohup bash -c "$cmd; sync" >> {host_output_file} 2>&1 & echo $! > {host_pid_file}\n' + ) + else: + f.write(f'bash -c "$cmd; sync" >> {host_output_file} 2>&1\n') + f.write("\n") + f.flush() + os.fsync(f.fileno()) + os.chmod(host_run_script_file, 0o755) + + return host_run_script_file + + +class SSHCompressRunner(RunnerBase): + def __init__(self, config: DictConfig): + super().__init__(config) + self.task_type = getattr(self.config.experiment.task, "type", None) + assert self.task_type == "compress", f"Unsupported task type: {self.task_type}" + self._prepare() + + def _prepare(self): + _update_config_compress(self.config) + self.user_args = _get_args_llmcompressor(self.config) + self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f") + self.user_envs = self.config.experiment.get("envs", {}) + self.cur_envs = None # current node envs + self.user_script = self.config.experiment.task.entrypoint + self.resources = parse_hostfile( + self.config.experiment.runner.get("hostfile", None) + ) + logger.info("\n************** configuration **************") + logger.info(f"\n{OmegaConf.to_yaml(self.config)}") + + def _run_each( + self, + host, + master_addr, + master_port, + nnodes, + node_rank, + nproc_per_node, + with_test=False, + dryrun=False, + ): + export_cmd = [] + for k, v in self.user_envs.items(): + export_cmd += [f"{k}={v}"] + + cmd = shlex.join(export_cmd + ["python"] + [self.user_script] + self.user_args) + + logging_config = self.config.compress.system.logging + host_run_script_file = _generate_run_script_compress( + self.config, host, node_rank, cmd, background=True, with_test=with_test + ) + + if host != "localhost": + ssh_port = self.config.experiment.runner.get("ssh_port", 22) + # Step 1: make sure the scripts_dir exists on the remote host + run_ssh_command( + host, f"mkdir -p {logging_config.scripts_dir}", ssh_port, dryrun + ) + + # Step 2: copy the host_run_script_file to the remote host + no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False) + if no_shared_fs: + run_scp_command( + host, + host_run_script_file, + logging_config.scripts_dir, + ssh_port, + dryrun, + ) + + # Step 3: run the host_run_script_file on the remote host + run_ssh_command(host, f"bash {host_run_script_file}", ssh_port, dryrun) + else: + run_local_command(f"bash {host_run_script_file}", dryrun) + + def run(self, with_test=False, dryrun=False): + num_visible_devices = None + visible_devices = self.user_envs.get("CUDA_VISIBLE_DEVICES", None) + if visible_devices is not None and isinstance(visible_devices, str): + visible_devices = visible_devices.split(",") + num_visible_devices = len(visible_devices) + + runner_config = self.config.experiment.runner + + # If hostfile is provided, use the resources from the hostfile + if self.resources is not None: + nnodes_from_hostfile = len(self.resources.keys()) + nnodes_from_args = runner_config.get("nnodes", None) + nnodes = get_nnodes(nnodes_from_hostfile, nnodes_from_args) + available_ip = list(self.resources.keys())[0] + available_port = get_free_port() + for node_rank, (host, resource_info) in enumerate(self.resources.items()): + if node_rank >= nnodes: + break + nproc_from_hostfile = 
resource_info["slots"] + nproc_from_args = runner_config.get("nproc_per_node", None) + nproc_per_node = get_nproc_per_node( + nproc_from_hostfile, nproc_from_args, num_visible_devices + ) + master_addr = runner_config.get("master_addr", available_ip) + master_port = runner_config.get("master_port", available_port) + self._run_each( + host, + master_addr, + master_port, + nnodes, + node_rank, + nproc_per_node, + with_test=with_test, + dryrun=dryrun, + ) + else: + # If hostfile is not provided, run the job on localhost + nproc_from_args = runner_config.get("nproc_per_node", None) + nproc_per_node = get_nproc_per_node( + None, nproc_from_args, num_visible_devices + ) + available_addr = runner_config.get("master_addr", "localhost") + available_port = runner_config.get("master_port", get_free_port()) + self._run_each( + "localhost", + available_addr, + available_port, + 1, + 0, + nproc_per_node, + with_test=with_test, + dryrun=dryrun, + ) + + def stop(self): + if self.resources is None: + self._stop_each("localhost", 0) + return + + nnodes = get_nnodes( + len(self.resources), self.config.experiment.runner.get("nnodes", None) + ) + + for node_rank, (host, _) in enumerate(self.resources.items()): + if node_rank >= nnodes: + break + self._stop_each(host, node_rank) diff --git a/requirements/compress/requirements-dev.txt b/requirements/compress/requirements-dev.txt new file mode 100644 index 000000000..3fb8eb6df --- /dev/null +++ b/requirements/compress/requirements-dev.txt @@ -0,0 +1,2 @@ +llmcompressor +compressed-tensors-nightly diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 1e35124b8..fff0156aa 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -10,3 +10,6 @@ # flagscale -r flagscale/requirements-common.txt -r flagscale/requirements-dev.txt + +# compress +-r compress/requirements-dev.txt diff --git a/run.py b/run.py index 472f072ac..aafd0f940 100644 --- a/run.py +++ b/run.py @@ -4,6 +4,7 @@ from flagscale.runner.runner_train import SSHTrainRunner, CloudTrainRunner from flagscale.runner.runner_inference import SSHInferenceRunner from flagscale.runner.runner_serve import SSHServeRunner +from flagscale.runner.runner_compress import SSHCompressRunner @hydra.main(version_base=None, config_name="config") @@ -54,6 +55,12 @@ def main(config: DictConfig) -> None: runner.run() elif config.action == "test": runner.run(with_test=True) + elif task_type == "compress": + runner = SSHCompressRunner(config) + if config.action == "run": + runner.run() + elif config.action == "dryrun": + runner.run(dryrun=True) elif config.action == "stop": runner.stop() else: diff --git a/tests/unit_tests/compressor/test_adapter.py b/tests/unit_tests/compressor/test_adapter.py new file mode 100644 index 000000000..2fc7d8ad1 --- /dev/null +++ b/tests/unit_tests/compressor/test_adapter.py @@ -0,0 +1,60 @@ +import os +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoImageProcessor +from flagscale.compress.adapter import LLMCompressorAdapter +from flagscale.inference.processing_emu3 import ( + CachedPrefixConstrainedLogitsProcessor, + Emu3Processor, +) + +def test_llmcompressor_adpter_without_dataset(): + model_path = "BAAI/Emu3-Gen" + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="cuda:0", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + quant_args = {"targets": ['Linear'], "scheme": 'FP8_DYNAMIC', "ignore":['lm_head']} + 
adapter = LLMCompressorAdapter(model=model, **quant_args) + adapter.model.save_pretrained("test_output", save_compressed=True) + os.remove("test_output") + +def test_llmcompressor_adpter_with_dataset(): + EMU_HUB = "BAAI/Emu3-Gen" + VQ_HUB = "BAAI/Emu3-VisionTokenizer" + # prepare model and processor + model = AutoModelForCausalLM.from_pretrained( + EMU_HUB, + device_map="cuda:0", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + model.eval() + tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left") + image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True) + image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval() + processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) + + # prepare input + POSITIVE_PROMPT = " masterpiece, film grained, best quality." + NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." + prompt = ["a portrait of young girl.", "a shiba inu"] + prompt = [p + POSITIVE_PROMPT for p in prompt] + kwargs = dict( + mode='G', + ratio=["1:1", "16:9"], + image_area=model.config.image_area, + return_tensors="pt", + padding="longest", + ) + pos_inputs = processor(text=prompt, **kwargs) + neg_inputs = processor(text=[NEGATIVE_PROMPT] * len(prompt), **kwargs) + quant_args = {"targets": ['Linear'], "scheme": 'W4A16', "ignore":['lm_head'], "algo": {"gptq": {"blocksize": 128, "percdamp": 0.01}}, "dataset": pos_inputs.input_ids.to("cuda:0")} + adapter = LLMCompressorAdapter(model=model, **quant_args) + adapter.model.save_pretrained("test_output", save_compressed=True) + os.remove("test_output") + diff --git a/tests/unit_tests/compressor/test_compressor.py b/tests/unit_tests/compressor/test_compressor.py new file mode 100644 index 000000000..452061a52 --- /dev/null +++ b/tests/unit_tests/compressor/test_compressor.py @@ -0,0 +1,6 @@ +from flagscale.compress.compressor import compress, prepare_config + +def test_config(): + test_config_path = "test_config.yaml" + cfg = prepare_config() + compress(cfg) diff --git a/tests/unit_tests/compressor/test_config.yaml b/tests/unit_tests/compressor/test_config.yaml new file mode 100644 index 000000000..6203854fd --- /dev/null +++ b/tests/unit_tests/compressor/test_config.yaml @@ -0,0 +1,48 @@ +experiment: + exp_name: emu3 + exp_dir: outputs/${experiment.exp_name} + task: + type: compress + entrypoint: flagscale/compress/compressor_emu3.py + runner: + hostfile: null + cmds: + before_start: source activate flagscale + envs: + CUDA_VISIBLE_DEVICES: 0 + CUDA_DEVICE_MAX_CONNECTIONS: 1 +action: run +compress: + system: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + use_flash_attn: true + sequence_parallel: true + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: compress-emu3-7B + wandb_exp_name: compress-test-7B + save_dir: null + model: + model_cls: AutoModelForCausalLM + model_path: BAAI/Emu3-Gen + device_map: cuda:0 + trust_remote_code: true + torch_dtype: bfloat16 + data: + data_path: null + max_calib_data: null + max_seq_len: null + tokenzier_args: + tokenizer_path: BAAI/Emu3-Gen/ + special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt + trust_remote_code: true + compress_args: + quantization: + - algo: null + targets: + - Linear + ignore: + - lm_head + 
scheme: FP8_DYNAMIC From 6e4a5b91151d7d6aaa62dc0fc3d0a9fddfa1e72a Mon Sep 17 00:00:00 2001 From: Xianduo Li <30922914+lxd-cumt@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:39:59 +0800 Subject: [PATCH 2/5] [Core] Support sigmoid router for aux_loss in moe (#310) Support sigmoid router in moe similar to DeepSeek V3 - add a configuration str in TransformerConfig - support using sigmoid+normalization to compute prob scores The implementation in this PR, is similar to that in DeepSeek V3, as shown in the figure below. Scores are first calculated using a sigmoid for x, and then the topk function is used to select the k largest scores, finally the k scores are averaged. Currently, sigmoid router is supported for two load balancing types: auxiliary loss and sequence auxiliary loss. We will support auxiliary loss free load balancing in the future. 20250107-142500 references to: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf --- .../core/transformer/moe/moe_utils.py | 23 +++++++++++++++++-- .../megatron/core/transformer/moe/router.py | 14 +++++++++-- .../core/transformer/transformer_config.py | 3 +++ megatron/megatron/training/arguments.py | 4 ++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/megatron/megatron/core/transformer/moe/moe_utils.py b/megatron/megatron/core/transformer/moe/moe_utils.py index ac3357ed1..49e96bcd0 100644 --- a/megatron/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/megatron/core/transformer/moe/moe_utils.py @@ -94,6 +94,20 @@ def sequence_load_balancing_loss_func( return seq_aux_loss +def score_function( + input: torch.Tensor, + score_function_type: str = "softmax", + target_dtype: torch.dtype = None, +): + if score_function_type == "softmax": + scores = torch.softmax(input, dim=-1, dtype=torch.float32) + if target_dtype: + scores = scores.type(target_dtype) + elif score_function_type == "sigmoid": + scores = input.sigmoid() + else: + raise ValueError(f"Unsupported MoE routing score function type: {score_function_type}") + return scores def z_loss_func(logits, z_loss_coeff): """Encourages the router's logits to remain small to enhance stability. @@ -323,6 +337,7 @@ def topk_softmax_with_capacity( moe_router_topk_limited_devices: int = None, moe_router_topk_scaling_factor: float = None, deterministic_mode: bool = False, + score_function_type: str = "softmax", ): """Apply capacity and padding to the top-k selection. Args: @@ -355,7 +370,7 @@ def topk_softmax_with_capacity( num_experts = logits.shape[1] if use_pre_softmax: # Pre softmax - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + scores = score_function(logits, score_function_type, logits.dtype) if moe_router_topk_limited_devices: probs, top_indices = device_limited_topk( @@ -382,7 +397,11 @@ def topk_softmax_with_capacity( ) else: scores, top_indices = torch.topk(logits, k=topk, dim=1) - probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + probs = score_function(scores, score_function_type, logits.dtype) + + if score_function_type == "sigmoid": + tmp = probs.sum(dim=-1, keepdim=True) + probs = probs / tmp # TODO Try using element-wise operations instead of scatter? 
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) diff --git a/megatron/megatron/core/transformer/moe/router.py b/megatron/megatron/core/transformer/moe/router.py index 82d1029a5..a628e34a5 100644 --- a/megatron/megatron/core/transformer/moe/router.py +++ b/megatron/megatron/core/transformer/moe/router.py @@ -17,6 +17,7 @@ switch_load_balancing_loss_func, topk_softmax_with_capacity, z_loss_func, + score_function, ) from megatron.core.transformer.transformer_config import TransformerConfig @@ -102,6 +103,7 @@ def __init__(self, config: TransformerConfig) -> None: super().__init__(config=config) self.topk = self.config.moe_router_topk self.routing_type = self.config.moe_router_load_balancing_type + self.score_function_type = self.config.moe_router_score_function_type self.input_jitter = None def sinkhorn_load_balancing(self, logits: torch.Tensor): @@ -157,11 +159,15 @@ def aux_loss_load_balancing(self, logits: torch.Tensor): moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices, moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, + score_function_type=self.score_function_type, ) if self.training: # Apply load balancing loss - scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + scores = score_function(logits, self.score_function_type) + if self.score_function_type == "sigmoid": + tmp = scores.sum(dim=-1, keepdim=True) + scores = scores / tmp aux_loss_func = partial( switch_load_balancing_loss_func, probs=scores, @@ -186,10 +192,14 @@ def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices, moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, + score_function_type=self.score_function_type, ) if self.training: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + scores = score_function(logits, self.score_function_type) + if self.score_function_type == "sigmoid": + tmp = scores.sum(dim=-1, keepdim=True) + scores = scores / tmp aux_loss_func = partial( sequence_load_balancing_loss_func, probs=scores, diff --git a/megatron/megatron/core/transformer/transformer_config.py b/megatron/megatron/core/transformer/transformer_config.py index 7f1457eb0..206f89b96 100644 --- a/megatron/megatron/core/transformer/transformer_config.py +++ b/megatron/megatron/core/transformer/transformer_config.py @@ -283,6 +283,9 @@ class TransformerConfig(ModelParallelConfig): which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".""" + moe_router_score_function_type: str = "softmax" + """Determines the score function type for the router, currently support two load balancing type: "aux_loss" and "seq_aux_loss".""" + moe_router_topk: int = 2 """Number of experts to route to for each token.""" diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index 5692f5581..b5d9dd8ed 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -2230,6 +2230,10 @@ def _add_moe_args(parser): choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'], default='aux_loss', help='Determines the load balancing strategy for the router. 
"aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".') + group.add_argument('--moe-router-score-function-type', type=str, + choices=['softmax', 'sigmoid'], + default='softmax', + help='Determines the score function type for the router, currently support two load balancing type: "aux_loss" and "seq_aux_loss".') group.add_argument('--moe-router-topk', type=int, default=2, help='Number of experts to route to for each token. The default is 2.') group.add_argument('--moe-router-pre-softmax', action='store_true', From b338029464c9fe031c34cbdd2c12d921e17da74e Mon Sep 17 00:00:00 2001 From: Xianduo Li <30922914+lxd-cumt@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:39:54 +0800 Subject: [PATCH 3/5] Add unit test for moe sigmoid router (#311) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a unit test for moe sigmoid router, including tests for two load balancing strategy: aux_loss and seq_aux_loss. ref pr #310 CI已触发 img_v3_02id_1072c06b-3da9-4951-8b6a-78d862ab651g --- tests/unit_tests/test_sigmoid_aux_loss.py | 253 ++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 tests/unit_tests/test_sigmoid_aux_loss.py diff --git a/tests/unit_tests/test_sigmoid_aux_loss.py b/tests/unit_tests/test_sigmoid_aux_loss.py new file mode 100644 index 000000000..c3111be6c --- /dev/null +++ b/tests/unit_tests/test_sigmoid_aux_loss.py @@ -0,0 +1,253 @@ +import pytest +import torch + +from megatron.core import parallel_state +from megatron.training.initialize import _set_random_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker +from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer + + +class AuxlossTestContainer(MoEModelTestContainer): + def __init__( + self, + tp_size, + ep_size, + pp_size, + cp_size=1, + moe_tp_size=None, + data_parallel_random_init=False, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_expert_capacity_factor=None, + moe_pad_expert_input_to_capacity=False, + moe_aux_loss_coeff=0.1, + **kwargs, + ): + self.num_local_experts = num_moe_experts // ep_size + if moe_tp_size is None: + moe_tp_size = tp_size + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + ) + _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) + local_expert_indices_offset = ( + parallel_state.get_expert_model_parallel_rank() * self.num_local_experts + ) + self.local_expert_indices = [ + local_expert_indices_offset + i for i in range(self.num_local_experts) + ] + self.config = TransformerConfig( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + pipeline_model_parallel_size=pp_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + moe_router_topk=moe_router_topk, + num_moe_experts=num_moe_experts, + 
moe_router_load_balancing_type=moe_router_load_balancing_type, + moe_token_dispatcher_type=moe_token_dispatcher_type, + moe_expert_capacity_factor=moe_expert_capacity_factor, + moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, + moe_aux_loss_coeff=moe_aux_loss_coeff, + num_layers=1, + moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), + hidden_size=kwargs.get("hidden_size", 16), + num_attention_heads=kwargs.get("num_attention_heads", 8), + use_cpu_initialization=kwargs.get("use_cpu_initialization", True), + sequence_parallel=tp_size > 1, + add_bias_linear=kwargs.get("add_bias_linear", False), + moe_router_score_function_type=kwargs.get("moe_router_score_function_type", "softmax"), + ) + + # init moe layer + self.moe_layer = self.new_moe_layer() + + def partition_input(self, input): + partitioned_input = input.chunk( + parallel_state.get_tensor_and_context_parallel_world_size(), dim=1 + )[parallel_state.get_tensor_and_context_parallel_rank()] + output = partitioned_input.clone().detach() + output.requires_grad = True + return output + + @pytest.mark.internal + def aux_loss_test(self, input, baseline_grad): + partitioned_input = self.partition_input(input) + moe_layer = self.moe_layer + probs, indices = moe_layer.router(partitioned_input) + probs.sum().mul_(0).backward() + aux_loss_grad = partitioned_input.grad + torch.distributed.barrier() + ans = self.partition_input(baseline_grad) + assert torch.allclose(aux_loss_grad, ans), f"Diff: {(aux_loss_grad/ans).mean()}" + loss = parallel_state.get_moe_layer_wise_logging_tracker()['load_balancing_loss'] + clear_aux_losses_tracker() + + +class TestSigmoidAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) + def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="allgather", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + 
moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + +class TestSigmoidSeqAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize("tp_size,ep_size,cp_size", [(1, 8, 1)]) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + +class TestSoftmaxSeqAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="softmax", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize("tp_size,ep_size,cp_size", [(1, 8, 1)]) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="softmax", + ) + container.aux_loss_test(self.input, self.baseline_grad) \ No newline at end of file From 020f748214c20ee0bee3fab7b409d3997e8abb38 Mon Sep 17 00:00:00 2001 From: zhaoyinglia <86812880+zhaoyinglia@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:00:24 +0800 Subject: [PATCH 4/5] [Core] Add logging for activated parameters (#313) --- .../training/theoretical_memory_usage.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/megatron/megatron/training/theoretical_memory_usage.py 
b/megatron/megatron/training/theoretical_memory_usage.py index f9b75031a..bfbef7b5b 100644 --- a/megatron/megatron/training/theoretical_memory_usage.py +++ b/megatron/megatron/training/theoretical_memory_usage.py @@ -8,6 +8,56 @@ NUM_BYTES_IN_MEGABYTE = 1024 * 1024 +def compute_activated_weight_number(args, verbose=False): + if args.num_experts is None: + return + # Attention projection size. + query_projection_size = args.kv_channels * args.num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size + # Group Query Attention. + if not args.group_query_attention: + args.num_query_groups = args.num_attention_heads + # MoE. + # NOTE(zhaoyingli): We only compute the number of activated parameters by topk routing. + num_experts = args.moe_router_topk + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + num_parameters_in_transformer_layers = ( + 2 + * args.num_layers + * args.hidden_size + * args.hidden_size + * ( + # Attention. + ( + (1 + (args.num_query_groups / args.num_attention_heads)) + * query_projection_to_hidden_size_ratio + ) + # MLP. + + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier) + # Transformer layernorms. + + (2 / args.hidden_size) + # Final layernorm. + + (1 / (args.num_layers * args.hidden_size)) + ) + ) + embedding_size = args.hidden_size * args.padded_vocab_size + if args.untie_embeddings_and_output_weights: + num_parameters_in_embedding_layers = 2 * embedding_size + else: + num_parameters_in_embedding_layers = embedding_size + num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers + if verbose: + print( + f"Number of activated parameters in transformer layers in billions: " + f"{num_parameters_in_transformer_layers / 10**9: .2f}" + ) + print( + f"Number of activated parameters in embedding layers in billions: " + f"{num_parameters_in_embedding_layers / 10**9:.2f}" + ) + print(f"Total number of activated parameters in billions: {num_total_parameters / 10**9:.2f}") + + def compute_weight_and_optimizer_memory(args, verbose=False): # Attention projection size. query_projection_size = args.kv_channels * args.num_attention_heads @@ -164,6 +214,8 @@ def compute_activation_memory(args, num_microbatches, verbose=False): def report_theoretical_memory(args, num_microbatches=None, verbose=False): + compute_activated_weight_number(args, verbose=verbose) + weight_and_optimizer_memory = ( compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE ) From c914e4ce342a66a49438d998d3ad9f5415c2c250 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 21 Jan 2025 10:03:33 +0800 Subject: [PATCH 5/5] [Core] Support to skip samples according to spiky loss (#315) Support to skip samples according to spiky loss. There are two methods to skip samples: 1. You can manually set the skipping range ``` skip_samples_range: [20, 30] ``` 2. 
Automatically detect the spiky loss and skip the samples ``` auto_skip_spiky_loss: true spiky_loss_threshold: 0.25 ``` --------- Co-authored-by: lizhiyu --- flagscale/train/__init__.py | 2 + flagscale/train/global_vars.py | 15 +++++- flagscale/train/spiky_loss.py | 52 +++++++++++++++++++ flagscale/train/train.py | 43 ++++++++++++++- megatron/megatron/training/arguments.py | 15 ++++++ megatron/megatron/training/initialize.py | 5 +- .../conf/train/tp2pp1_tp4pp1_tp2pp1.yaml | 2 + tests/unit_tests/test_spiky_loss_detector.py | 35 +++++++++++++ 8 files changed, 166 insertions(+), 3 deletions(-) create mode 100644 flagscale/train/spiky_loss.py create mode 100644 tests/unit_tests/test_spiky_loss_detector.py diff --git a/flagscale/train/__init__.py b/flagscale/train/__init__.py index a31eb1491..3ce739be5 100644 --- a/flagscale/train/__init__.py +++ b/flagscale/train/__init__.py @@ -4,4 +4,6 @@ from .global_vars import set_extra_input_tensor from .global_vars import get_parallel_context from .global_vars import set_parallel_context +from .global_vars import get_spiky_loss_detector +from .global_vars import set_get_spiky_loss_detector from .arguments import FSTrainArguments diff --git a/flagscale/train/global_vars.py b/flagscale/train/global_vars.py index d25d30650..1a2b54e2f 100644 --- a/flagscale/train/global_vars.py +++ b/flagscale/train/global_vars.py @@ -1,11 +1,12 @@ import torch from flagscale.train.hetero.parallel_context import ParallelContext +from flagscale.train.spiky_loss import SpikyLossDetector _GLOBAL_EXTRA_VALID_DATASETS = None _GLOBAL_EXATRA_INPUT_TENSOR = None _GLOBAL_PARALLEL_CONTEXT = None - +_GLOBAL_SPIKY_LOSS_DETECTOR = None def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" @@ -49,3 +50,15 @@ def set_parallel_context(args): global _GLOBAL_PARALLEL_CONTEXT _ensure_var_is_not_initialized(_GLOBAL_PARALLEL_CONTEXT, 'parallel context') _GLOBAL_PARALLEL_CONTEXT = ParallelContext(args) + +def get_spiky_loss_detector(): + """Return spiky loss detector.""" + _ensure_var_is_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector') + return _GLOBAL_SPIKY_LOSS_DETECTOR + + +def set_get_spiky_loss_detector(args): + """Initialize spiky loss detector.""" + global _GLOBAL_SPIKY_LOSS_DETECTOR + _ensure_var_is_not_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector') + _GLOBAL_SPIKY_LOSS_DETECTOR = SpikyLossDetector(args.spiky_loss_threshold) \ No newline at end of file diff --git a/flagscale/train/spiky_loss.py b/flagscale/train/spiky_loss.py new file mode 100644 index 000000000..cf943cd8c --- /dev/null +++ b/flagscale/train/spiky_loss.py @@ -0,0 +1,52 @@ +import math +import torch + +class SpikyLossDetector: + '''This class represents a Spiky Loss Detector. + It is used to detect spikes in loss values during training. + ''' + def __init__(self, threshold=0.2, loss = None): + self.last_loss = loss + self.threshold = threshold + + def reduce_losses(self, losses_reduced): + loss_reduced = {} + from megatron.core import mpu + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + for key in losses_reduced[0].keys(): + numerator = 0 + denominator = 0 + for x in losses_reduced: + val = x[key] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + if isinstance(val, tuple) or isinstance(val, list): + numerator += val[0] + denominator += val[1] + else: + # legacy behavior. 
we average over the number of microbatches, + # and so the denominator is 1. + numerator += val + denominator += 1 + loss_reduced[key] = numerator / denominator + return loss_reduced.get('lm loss') + + def is_spkiy_loss(self, loss): + if loss is None: + return False + if self.last_loss is not None: + if math.isnan(loss) or math.isnan(self.last_loss): + self.last_loss = loss + elif math.isinf(loss) or math.isinf(self.last_loss): + return True + else: + result = (loss - self.last_loss) / self.last_loss >= self.threshold + if result: + return True + else: + self.last_loss = loss + else: + self.last_loss = loss + return False + diff --git a/flagscale/train/train.py index 109373299..1769197a6 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -97,7 +97,7 @@ from flagscale.train.extra_valid import extra_evaluate_and_print_results from flagscale.train.extra_valid import build_extra_valid_data_iterators from flagscale.train.stablelm2_scheduler import StableLM2SchedulerConfig -from flagscale.train.global_vars import get_parallel_context +from flagscale.train.global_vars import get_parallel_context, get_spiky_loss_detector from flagscale.train.hetero.p2p_communication import get_device_type_for_comm stimer = StragglerDetector() @@ -832,6 +832,18 @@ def train_step(forward_step_func, data_iterator, if should_exit: return {}, True, should_checkpoint, should_exit, exit_code, None, None + ########## FlagScale Begin ########## + if args.auto_skip_spiky_loss and (args.consumed_train_samples > args.lr_warmup_samples and args.curr_iteration > args.lr_warmup_iters): + spiky_loss_detector = get_spiky_loss_detector() + loss_ = spiky_loss_detector.reduce_losses(losses_reduced) + is_spiky_loss = spiky_loss_detector.is_spkiy_loss(loss_) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + if is_spiky_loss > 0: + return {}, True, should_checkpoint, should_exit, exit_code, None, None + ########## FlagScale End ########## + # Empty unused memory. if args.empty_unused_memory_level >= 1: torch.cuda.empty_cache() @@ -1573,6 +1585,35 @@ def get_e2e_base_metrics(): # Run training step.
args.curr_iteration = iteration + + ########## FlagScale Begin ########## + if args.skip_samples_range or args.skip_iters_range: + current_global_batch_size = get_current_global_batch_size() + start_skip_iteration = 0 + end_skip_iteration = 0 + if args.skip_samples_range: + if args.consumed_train_samples + current_global_batch_size > args.skip_samples_range[0] and args.consumed_train_samples < args.skip_samples_range[1]: + num_skipped_iters = (args.skip_samples_range[1] - args.consumed_train_samples + current_global_batch_size - 1) // current_global_batch_size + args.skip_samples_range[1] = args.consumed_train_samples + num_skipped_iters * current_global_batch_size + start_skip_iteration = iteration + end_skip_iteration = iteration + num_skipped_iters + else: + if iteration >= args.skip_iters_range[0] and iteration < args.skip_iters_range[1]: + start_skip_iteration = iteration + end_skip_iteration = args.skip_iters_range[1] + while iteration >= start_skip_iteration and iteration < end_skip_iteration: + if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): + for _ in range(get_num_microbatches()): + _ = next(train_data_iterator) + args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True) + iteration += 1 + + args.curr_iteration = iteration + ########## FlagScale End ########## + loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \ train_step(forward_step_func, train_data_iterator, diff --git a/megatron/megatron/training/arguments.py index b5d9dd8ed..cc2ed104e 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -62,6 +62,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): parser = _add_rerun_machine_args(parser) parser = _add_hetero_args(parser) parser = _add_auto_tuner_args(parser) + parser = _add_auto_skip_spiky_loss(parser) # Custom arguments. if extra_args_provider is not None: @@ -1405,6 +1406,10 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs.
Note that either train-iters or ' 'train-samples should be provided.') + group.add_argument('--skip-samples-range', nargs='+', type=int, default=None, + help='Range of samples to skip during training.') + group.add_argument('--skip-iters-range', nargs='+', type=int, default=None, + help='Range of iterations to skip during training.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, @@ -2360,3 +2365,13 @@ def _add_auto_tuner_args(parser): help='use auto tuner') return parser + + +def _add_auto_skip_spiky_loss(parser): + group = parser.add_argument_group(title='auto skip spiky loss') + + group.add_argument('--auto-skip-spiky-loss', action='store_true', + help='Automatically skip spiky loss iterations.') + group.add_argument('--spiky-loss-threshold', type=float, default=0.2, + help='Threshold for skipping spiky loss iterations.') + return parser diff --git a/megatron/megatron/training/initialize.py b/megatron/megatron/training/initialize.py index f8bc22df1..784dc713a 100644 --- a/megatron/megatron/training/initialize.py +++ b/megatron/megatron/training/initialize.py @@ -28,7 +28,7 @@ from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version from flagscale.train import FSTrainArguments -from flagscale.train import set_parallel_context +from flagscale.train import set_parallel_context, set_get_spiky_loss_detector logger = logging.getLogger(__name__) @@ -106,6 +106,9 @@ def state_restore_func(state_dict): error_injection_type=RerunDiagnostic(args.error_injection_type), ), ) + + if args.auto_skip_spiky_loss: + set_get_spiky_loss_detector(args=args) # torch.distributed initialization def finish_mpu_init(): diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml index d4cd13df8..4db8f58b0 100644 --- a/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml +++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml @@ -71,6 +71,8 @@ model: micro_batch_size: 4 global_batch_size: 1024 seed: 42 + auto_skip_spiky_loss: true + spiky_loss_threshold: 0.25 optimizer: weight_decay: 0.1 diff --git a/tests/unit_tests/test_spiky_loss_detector.py b/tests/unit_tests/test_spiky_loss_detector.py new file mode 100644 index 000000000..3bc776f8f --- /dev/null +++ b/tests/unit_tests/test_spiky_loss_detector.py @@ -0,0 +1,35 @@ +import torch + +from flagscale.train.spiky_loss import SpikyLossDetector +from tests.unit_tests.test_utilities import Utils + +def test_spiky_loss_detector(pp_size=2, threshold=0.2): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=1, + context_parallel_size=1, + expert_tensor_parallel_size=1, + ) + + detector = SpikyLossDetector(threshold=threshold, loss=10.0) + + # test case 1: loss is not spiky + losses = [{"lm loss": 10.23}, {"lm loss": 10.32}, {"lm loss": 10.30}] + reduced_loss = detector.reduce_losses(losses) + is_spiky_loss = detector.is_spkiy_loss(reduced_loss) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + assert is_spiky_loss == 0, f"Expected 0, got {is_spiky_loss}" + + # test 
case 2: loss is spiky + losses = [{"lm loss": 14.23}, {"lm loss": 14.32}, {"lm loss": 14.30}] + reduced_loss = detector.reduce_losses(losses) + is_spiky_loss = detector.is_spkiy_loss(reduced_loss) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + assert is_spiky_loss == 1, f"Expected 1, got {is_spiky_loss}" + + Utils.destroy_model_parallel() \ No newline at end of file
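The unit test above exercises SpikyLossDetector under pipeline parallelism; the sketch below is a single-process illustration of the same relative-increase rule, with invented loss values and no all-reduce. In train_step, the scalar comes from reduce_losses() on the last pipeline stage (other stages receive None), and the per-rank verdicts are combined with a MAX all-reduce before the step is dropped.

```
from flagscale.train.spiky_loss import SpikyLossDetector

# threshold=0.2: flag a spike when the reduced 'lm loss' rises by 20% or more
# relative to the last accepted loss.
detector = SpikyLossDetector(threshold=0.2, loss=10.0)

# (10.3 - 10.0) / 10.0 = 0.03 -> below threshold; the baseline advances to 10.3.
assert not detector.is_spkiy_loss(10.3)

# (14.3 - 10.3) / 10.3 ~= 0.39 -> spike; the iteration is skipped and the baseline
# stays at 10.3, so the next value is still compared against the pre-spike loss.
assert detector.is_spkiy_loss(14.3)

# None (e.g. a pipeline stage that does not own the loss) never counts as a spike.
assert not detector.is_spkiy_loss(None)
```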
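A worked example, with made-up sample counts, of the ceil-rounding that the skip_samples_range branch in train.py above applies before fast-forwarding the data iterator: the upper bound is snapped up to a whole number of global batches so that consumed_train_samples stays consistent with the number of skipped iterations.

```
# Hypothetical values; only the arithmetic mirrors the training loop above.
current_global_batch_size = 1024
consumed_train_samples = 19456
skip_samples_range = [20000, 30000]   # as given via --skip-samples-range

# The skip path is entered once one more global batch would cross the lower bound.
assert consumed_train_samples + current_global_batch_size > skip_samples_range[0]
assert consumed_train_samples < skip_samples_range[1]

# Ceil-divide the remaining span into whole iterations, then snap the upper bound.
num_skipped_iters = (skip_samples_range[1] - consumed_train_samples
                     + current_global_batch_size - 1) // current_global_batch_size
skip_samples_range[1] = consumed_train_samples + num_skipped_iters * current_global_batch_size

print(num_skipped_iters)       # 11 iterations are consumed without a train_step
print(skip_samples_range[1])   # 30720: the range end rounded up to a batch boundary
```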
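The activated-parameter logging added to theoretical_memory_usage.py earlier in this series only triggers for MoE models (num_experts set) and counts the top-k routed experts per layer. Below is a standalone restatement of that arithmetic with illustrative hyperparameters; the values are not taken from any shipped config.

```
def activated_params(hidden_size, num_layers, ffn_hidden_size, num_attention_heads,
                     num_query_groups, kv_channels, moe_router_topk, swiglu,
                     padded_vocab_size, untie_embeddings_and_output_weights):
    query_projection_size = kv_channels * num_attention_heads
    query_ratio = query_projection_size / hidden_size
    gated_linear_multiplier = 3 / 2 if swiglu else 1
    transformer_params = (
        2 * num_layers * hidden_size * hidden_size
        * (
            # Attention (with grouped-query KV heads).
            (1 + num_query_groups / num_attention_heads) * query_ratio
            # MoE MLP: only the top-k routed experts are activated.
            + (ffn_hidden_size / hidden_size) * moe_router_topk * gated_linear_multiplier
            # Per-layer layernorms plus the final layernorm.
            + 2 / hidden_size
            + 1 / (num_layers * hidden_size)
        )
    )
    embedding_params = hidden_size * padded_vocab_size
    if untie_embeddings_and_output_weights:
        embedding_params *= 2
    return transformer_params + embedding_params

# Illustrative top-2 MoE configuration: roughly 12.9B activated parameters.
total = activated_params(hidden_size=4096, num_layers=32, ffn_hidden_size=14336,
                         num_attention_heads=32, num_query_groups=8, kv_channels=128,
                         moe_router_topk=2, swiglu=True, padded_vocab_size=32000,
                         untie_embeddings_and_output_weights=True)
print(f"{total / 10**9:.2f}B")
```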
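The aux-loss tests earlier in this series build their baseline gradient with probs.sum().mul_(0).backward(): zero-scaling the router output kills the main-path gradient, so whatever lands in input.grad comes only from the load-balancing loss that the router attaches through a custom autograd function. The toy below reproduces that pattern with a hand-rolled stand-in for the scaler; it is illustrative only and uses none of the Megatron MoE code.

```
import torch

class AttachAuxLoss(torch.autograd.Function):
    """Stand-in for the router's aux-loss scaler: carries an auxiliary loss through
    the forward pass and feeds it a gradient of 1.0 during backward."""

    @staticmethod
    def forward(ctx, output, aux_loss):
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        return grad_output, torch.ones_like(aux_loss)

x = torch.randn(8, 4, requires_grad=True)
main_out = 2.0 * x
aux_loss = 0.1 * (x ** 2).mean()        # stand-in for the load-balancing loss
out = AttachAuxLoss.apply(main_out, aux_loss)

# Same pattern as the tests: zero-scale the output before backward, so x.grad
# holds only the aux-loss contribution (here d(aux_loss)/dx = 0.2 * x / 32).
out.sum().mul_(0).backward()
print(x.grad)
```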