diff --git a/examples/emu/conf/compress/compress_emu3.yaml b/examples/emu/conf/compress/compress_emu3.yaml index 38dc4cc73..13e18481a 100644 --- a/examples/emu/conf/compress/compress_emu3.yaml +++ b/examples/emu/conf/compress/compress_emu3.yaml @@ -3,20 +3,16 @@ defaults: - _self_ data: - data_path: /share/project/lms/emu3_testdata/3_text_document - num_calibration_steps: 16 - max_seq_length: 9216 + data_path: null + max_calib_data: null + max_seq_len: null tokenzier_args: - tokenizer_path: /share/project/lms/Emu3-Gen/ - special_tokens_file: /share/project/lms/Emu3-Gen/emu3_vision_tokens.txt + tokenizer_path: BAAI/Emu3-Gen/ + special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt trust_remote_code: true compress_args: quantization: - - algo: - smoothquant: - smoothing_strength: 0.5 - ignore: ["lm_head"] - algo: targets: ["Linear"] ignore: ["lm_head"] diff --git a/examples/emu/conf/compress/compress_emu3_w4a16.yaml b/examples/emu/conf/compress/compress_emu3_w4a16.yaml index 1c431216c..534727f67 100644 --- a/examples/emu/conf/compress/compress_emu3_w4a16.yaml +++ b/examples/emu/conf/compress/compress_emu3_w4a16.yaml @@ -3,20 +3,16 @@ defaults: - _self_ data: - data_path: /share/project/lms/emu3_testdata/3_text_document - num_calibration_steps: 16 + data_path: + num_calibration_samples: 16 max_seq_length: 9216 tokenzier_args: - tokenizer_path: /share/project/lms/Emu3-Gen/ - special_tokens_file: /share/project/lms/Emu3-Gen/emu3_vision_tokens.txt + tokenizer_path: BAAI/Emu3-Gen/ + special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt trust_remote_code: true compress_args: quantization: - - algo: - smoothquant: - smoothing_strength: 0.5 - ignore: ["lm_head"] - algo: gptq: blocksize: 128 diff --git a/examples/emu/conf/compress/emu3_model.yaml b/examples/emu/conf/compress/emu3_model.yaml index f9485560f..16751891d 100644 --- a/examples/emu/conf/compress/emu3_model.yaml +++ b/examples/emu/conf/compress/emu3_model.yaml @@ -8,11 +8,12 @@ system: tensorboard_log_interval: 1 wandb_project: "compress-emu3-7B" wandb_exp_name: "compress-test-7B" - save_dir: outputs/emu3/inference_model + save_dir: model: model_cls: AutoModelForCausalLM - model_path: /share/project/lms/Emu3-Gen/ + model_path: BAAI/Emu3-Gen/ device_map: cuda:0 trust_remote_code: true + torch_dtype: bfloat16 diff --git a/flagscale/compress/adapter.py b/flagscale/compress/adapter.py index 45c8c9650..e055aec99 100644 --- a/flagscale/compress/adapter.py +++ b/flagscale/compress/adapter.py @@ -22,7 +22,6 @@ QuantizationScheme, disable_quantization, enable_quantization, - is_attention_module, ) from llmcompressor.modifiers.quantization.calibration import ( apply_calibration_status, @@ -40,7 +39,6 @@ from flagscale.compress.blockwise_compressor import BlockCompressor from flagscale.runner.runner_utils import logger -import pdb __all__ = ["LLMCompressorAdapter"] @@ -55,8 +53,6 @@ class LLMCompressorAdapter: def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dataset=None, num_calibration_steps=384): self.model = model - # print("model: ", model) - # modify_save_pretrained(self.model) if algo is not None: assert len(algo) == 1 for k, v in algo.items(): @@ -91,7 +87,6 @@ def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dat self.wrapper_cls = RTNWrapper self.compress_granularity = LayerCompressor quant_config = self.init_quant_config() - print(quant_config) if quant_config is not None: ### find ignore and target to quant, initialize module for quant @@ -101,7 +96,6 @@ def __init__(self, 
model, scheme=None, targets=None, algo=None, ignore=None, dat self.init_compressor() if self.require_calib: - # self.insert_observer() if model.training == False: ### Post Training assert self.dataset is not None, f"The algorithm {self.algo} you selected requires a calibration process, please provide the calibration data" self.run_blockwise_calib_forward() @@ -112,11 +106,11 @@ def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dat self.layer_compressors_[0].clear_early_stop() for idx, layer_compressor in enumerate(self.layer_compressors_): layer_compressor.pre_compress() - # import pdb;pdb.set_trace() layer_compressor.compress() layer_compressor.post_compress() layer_compressor.revert_layer_wrappers() + def init_quant_config(self): if self.scheme is not None: # takes precedence over config_groups @@ -182,7 +176,6 @@ def run_blockwise_calib_forward(self): for idx, layer_compressor in enumerate(self.layer_compressors_): logger.info(f"start calibration layer {layer_compressor.name}") layer_compressor.pre_compress() - # print("idx: ", idx, intermediates) unquantized_outputs = layer_compressor.calibrate_layer(intermediates) layer_compressor.compress() layer_compressor.post_compress() @@ -192,4 +185,3 @@ def run_blockwise_calib_forward(self): logger.info(f"Mean output error from quantization: {error:.3f}") intermediates = quantized_outputs self.model.apply(enable_quantization) - diff --git a/flagscale/compress/algo/algo_base.py b/flagscale/compress/algo/algo_base.py deleted file mode 100644 index 582f76238..000000000 --- a/flagscale/compress/algo/algo_base.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABC, abstractmethod -from torch.nn import Module - -class BaseWrapper(Module, ABC): - def __init__(self, name, layer): - super(BaseWrapper, self).__init__(name, layer) - self.name = name - ### disable _enable_compress means only observer - self._enable_compress = False - - def add_batch(self): - raise NotImplementedError - - def compress(self): - raise NotImplementedError - - @setattr - def enable_compress(self): - self._enable_compress = True - - @setattr - def disable_compress(self): - self._enable_compress = False \ No newline at end of file diff --git a/flagscale/compress/compressor.py b/flagscale/compress/compressor.py index 6c6a881cc..1a2182dda 100644 --- a/flagscale/compress/compressor.py +++ b/flagscale/compress/compressor.py @@ -75,7 +75,6 @@ def compress(self): if self.model is None: model_cls = eval(self.cfg.model.pop("model_cls")) self.model = model_cls.from_pretrained(self.model_path, **self.cfg.model) - # import pdb; pdb.set_trace() assert isinstance(self.model, torch.nn.Module), f"model type {type(self.model)} error, please check it" compress_args = self.cfg.compress_args recipes = prepare_compress_methods(compress_args) @@ -115,4 +114,4 @@ def convert(self, model): args = parser.parse_args() cfg = prepare_config(args.config_path) - Compressor(cfg) + Compressor(cfg) \ No newline at end of file diff --git a/flagscale/compress/compressor_emu3.py b/flagscale/compress/compressor_emu3.py index 6fa128ac0..11c7076f1 100644 --- a/flagscale/compress/compressor_emu3.py +++ b/flagscale/compress/compressor_emu3.py @@ -61,16 +61,4 @@ def prepare_dataset(cfg): dataset = prepare_dataset(cfg) cmp = Compressor(cfg, dataset=dataset) cmp.compress() - model = cmp.convert(cmp.model) - ### test code - with torch.no_grad(): - from llmcompressor.pytorch.utils import tensors_to_device - model_device = next(model.parameters()).device - for idx, data in enumerate(dataset): - 
data = tensors_to_device(data, model_device) - if idx < 2: - model(**data) - else: - break - diff --git a/flagscale/compress/compressor_llava_ov.py b/flagscale/compress/compressor_llava_ov.py index 3a9c21e30..497d1860f 100644 --- a/flagscale/compress/compressor_llava_ov.py +++ b/flagscale/compress/compressor_llava_ov.py @@ -100,8 +100,6 @@ def prepare_dataset(cfg, model, tokenizer): elif isinstance(data_args.image_grid_pinpoints, str): data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) dataset = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) - import pdb - pdb.set_trace() ds = CusDataset(dataset["train_dataset"]) return ds diff --git a/flagscale/train/__init__.py b/flagscale/train/__init__.py index a31eb1491..3ce739be5 100644 --- a/flagscale/train/__init__.py +++ b/flagscale/train/__init__.py @@ -4,4 +4,6 @@ from .global_vars import set_extra_input_tensor from .global_vars import get_parallel_context from .global_vars import set_parallel_context +from .global_vars import get_spiky_loss_detector +from .global_vars import set_get_spiky_loss_detector from .arguments import FSTrainArguments diff --git a/flagscale/train/global_vars.py b/flagscale/train/global_vars.py index d25d30650..1a2b54e2f 100644 --- a/flagscale/train/global_vars.py +++ b/flagscale/train/global_vars.py @@ -1,11 +1,12 @@ import torch from flagscale.train.hetero.parallel_context import ParallelContext +from flagscale.train.spiky_loss import SpikyLossDetector _GLOBAL_EXTRA_VALID_DATASETS = None _GLOBAL_EXATRA_INPUT_TENSOR = None _GLOBAL_PARALLEL_CONTEXT = None - +_GLOBAL_SPIKY_LOSS_DETECTOR = None def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" @@ -49,3 +50,15 @@ def set_parallel_context(args): global _GLOBAL_PARALLEL_CONTEXT _ensure_var_is_not_initialized(_GLOBAL_PARALLEL_CONTEXT, 'parallel context') _GLOBAL_PARALLEL_CONTEXT = ParallelContext(args) + +def get_spiky_loss_detector(): + """Return spiky loss detector.""" + _ensure_var_is_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector') + return _GLOBAL_SPIKY_LOSS_DETECTOR + + +def set_get_spiky_loss_detector(args): + """Initialize spiky loss detector.""" + global _GLOBAL_SPIKY_LOSS_DETECTOR + _ensure_var_is_not_initialized(_GLOBAL_SPIKY_LOSS_DETECTOR, 'spiky loss detector') + _GLOBAL_SPIKY_LOSS_DETECTOR = SpikyLossDetector(args.spiky_loss_threshold) \ No newline at end of file diff --git a/flagscale/train/spiky_loss.py b/flagscale/train/spiky_loss.py new file mode 100644 index 000000000..cf943cd8c --- /dev/null +++ b/flagscale/train/spiky_loss.py @@ -0,0 +1,52 @@ +import math +import torch + +class SpikyLossDetector: + '''This class represents a Spiky Loss Detector. + It is used to detect spikes in loss values during training. + ''' + def __init__(self, threshold=0.2, loss = None): + self.last_loss = loss + self.threshold = threshold + + def reduce_losses(self, losses_reduced): + loss_reduced = {} + from megatron.core import mpu + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + for key in losses_reduced[0].keys(): + numerator = 0 + denominator = 0 + for x in losses_reduced: + val = x[key] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. + if isinstance(val, tuple) or isinstance(val, list): + numerator += val[0] + denominator += val[1] + else: + # legacy behavior. we average over the number of microbatches, + # and so the denominator is 1. 
+ numerator += val + denominator += 1 + loss_reduced[key] = numerator / denominator + return loss_reduced.get('lm loss') + + def is_spkiy_loss(self, loss): + if loss is None: + return False + if self.last_loss is not None: + if math.isnan(loss) or math.isnan(self.last_loss): + self.last_loss = loss + elif math.isinf(loss) or math.isinf(self.last_loss): + return True + else: + result = (loss - self.last_loss) / self.last_loss >= self.threshold + if result: + return True + else: + self.last_loss = loss + else: + self.last_loss = loss + return False + diff --git a/flagscale/train/train.py b/flagscale/train/train.py index 109373299..1769197a6 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -97,7 +97,7 @@ from flagscale.train.extra_valid import extra_evaluate_and_print_results from flagscale.train.extra_valid import build_extra_valid_data_iterators from flagscale.train.stablelm2_scheduler import StableLM2SchedulerConfig -from flagscale.train.global_vars import get_parallel_context +from flagscale.train.global_vars import get_parallel_context, get_spiky_loss_detector from flagscale.train.hetero.p2p_communication import get_device_type_for_comm stimer = StragglerDetector() @@ -832,6 +832,18 @@ def train_step(forward_step_func, data_iterator, if should_exit: return {}, True, should_checkpoint, should_exit, exit_code, None, None + ########## FlagScale Begin ########## + if args.auto_skip_spiky_loss and (args.consumed_train_samples > args.lr_warmup_samples and args.curr_iteration > args.lr_warmup_iters): + spiky_loss_detector = get_spiky_loss_detector() + loss_ = spiky_loss_detector.reduce_losses(losses_reduced) + is_spiky_loss = spiky_loss_detector.is_spkiy_loss(loss_) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + if is_spiky_loss > 0: + return {}, True, should_checkpoint, should_exit, exit_code, None, None + ########## FlagScale Begin ########## + # Empty unused memory. if args.empty_unused_memory_level >= 1: torch.cuda.empty_cache() @@ -1573,6 +1585,35 @@ def get_e2e_base_metrics(): # Run training step. 
args.curr_iteration = iteration + + ########## FlagScale Begin ########## + if args.skip_samples_range or args.skip_iters_range: + current_global_batch_size = get_current_global_batch_size() + start_skip_iteration = 0 + end_skip_iteration = 0 + if args.skip_samples_range: + if args.consumed_train_samples + current_global_batch_size > args.skip_samples_range[0] and args.consumed_train_samples < args.skip_samples_range[1]: + num_skipped_iters = (args.skip_samples_range[1] - args.consumed_train_samples + current_global_batch_size - 1) // current_global_batch_size + args.skip_samples_range[1] = args.consumed_train_samples + num_skipped_iters * current_global_batch_size + start_skip_iteration = iteration + end_skip_iteration = iteration + num_skipped_iters + else: + if iteration >= args.skip_iters_range[0] and iteration < args.skip_iters_range[1]: + start_skip_iteration = iteration + end_skip_iteration = args.skip_iters_range[1] + while iteration >= start_skip_iteration and iteration < end_skip_iteration: + if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): + for _ in range(get_num_microbatches()): + _ = next(train_data_iterator) + args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True) + iteration += 1 + + args.curr_iteration = iteration + ########## FlagScale Begin ########## + loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \ train_step(forward_step_func, train_data_iterator, diff --git a/megatron/megatron/core/transformer/moe/moe_utils.py b/megatron/megatron/core/transformer/moe/moe_utils.py index ac3357ed1..49e96bcd0 100644 --- a/megatron/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/megatron/core/transformer/moe/moe_utils.py @@ -94,6 +94,20 @@ def sequence_load_balancing_loss_func( return seq_aux_loss +def score_function( + input: torch.Tensor, + score_function_type: str = "softmax", + target_dtype: torch.dtype = None, +): + if score_function_type == "softmax": + scores = torch.softmax(input, dim=-1, dtype=torch.float32) + if target_dtype: + scores = scores.type(target_dtype) + elif score_function_type == "sigmoid": + scores = input.sigmoid() + else: + raise ValueError(f"Unsupported MoE routing score function type: {score_function_type}") + return scores def z_loss_func(logits, z_loss_coeff): """Encourages the router's logits to remain small to enhance stability. @@ -323,6 +337,7 @@ def topk_softmax_with_capacity( moe_router_topk_limited_devices: int = None, moe_router_topk_scaling_factor: float = None, deterministic_mode: bool = False, + score_function_type: str = "softmax", ): """Apply capacity and padding to the top-k selection. 
Args: @@ -355,7 +370,7 @@ def topk_softmax_with_capacity( num_experts = logits.shape[1] if use_pre_softmax: # Pre softmax - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + scores = score_function(logits, score_function_type, logits.dtype) if moe_router_topk_limited_devices: probs, top_indices = device_limited_topk( @@ -382,7 +397,11 @@ def topk_softmax_with_capacity( ) else: scores, top_indices = torch.topk(logits, k=topk, dim=1) - probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + probs = score_function(scores, score_function_type, logits.dtype) + + if score_function_type == "sigmoid": + tmp = probs.sum(dim=-1, keepdim=True) + probs = probs / tmp # TODO Try using element-wise operations instead of scatter? topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) diff --git a/megatron/megatron/core/transformer/moe/router.py b/megatron/megatron/core/transformer/moe/router.py index 82d1029a5..a628e34a5 100644 --- a/megatron/megatron/core/transformer/moe/router.py +++ b/megatron/megatron/core/transformer/moe/router.py @@ -17,6 +17,7 @@ switch_load_balancing_loss_func, topk_softmax_with_capacity, z_loss_func, + score_function, ) from megatron.core.transformer.transformer_config import TransformerConfig @@ -102,6 +103,7 @@ def __init__(self, config: TransformerConfig) -> None: super().__init__(config=config) self.topk = self.config.moe_router_topk self.routing_type = self.config.moe_router_load_balancing_type + self.score_function_type = self.config.moe_router_score_function_type self.input_jitter = None def sinkhorn_load_balancing(self, logits: torch.Tensor): @@ -157,11 +159,15 @@ def aux_loss_load_balancing(self, logits: torch.Tensor): moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices, moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, + score_function_type=self.score_function_type, ) if self.training: # Apply load balancing loss - scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + scores = score_function(logits, self.score_function_type) + if self.score_function_type == "sigmoid": + tmp = scores.sum(dim=-1, keepdim=True) + scores = scores / tmp aux_loss_func = partial( switch_load_balancing_loss_func, probs=scores, @@ -186,10 +192,14 @@ def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices, moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, + score_function_type=self.score_function_type, ) if self.training: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + scores = score_function(logits, self.score_function_type) + if self.score_function_type == "sigmoid": + tmp = scores.sum(dim=-1, keepdim=True) + scores = scores / tmp aux_loss_func = partial( sequence_load_balancing_loss_func, probs=scores, diff --git a/megatron/megatron/core/transformer/transformer_config.py b/megatron/megatron/core/transformer/transformer_config.py index 7f1457eb0..206f89b96 100644 --- a/megatron/megatron/core/transformer/transformer_config.py +++ b/megatron/megatron/core/transformer/transformer_config.py @@ -283,6 +283,9 @@ class TransformerConfig(ModelParallelConfig): which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. 
The default is "aux_loss".""" + moe_router_score_function_type: str = "softmax" + """Determines the score function type for the router, currently support two load balancing type: "aux_loss" and "seq_aux_loss".""" + moe_router_topk: int = 2 """Number of experts to route to for each token.""" diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index 5692f5581..cc2ed104e 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -62,6 +62,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): parser = _add_rerun_machine_args(parser) parser = _add_hetero_args(parser) parser = _add_auto_tuner_args(parser) + parser = _add_auto_skip_spiky_loss(parser) # Custom arguments. if extra_args_provider is not None: @@ -1405,6 +1406,10 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs. Note that either train-iters or ' 'train-samples should be provided.') + group.add_argument('--skip-samples-range', nargs='+', type=int, default=None, + help='Range of samples to skip during training.') + group.add_argument('--skip-iters-range', nargs='+', type=int, default=None, + help='Range of iterations to skip during training.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, @@ -2230,6 +2235,10 @@ def _add_moe_args(parser): choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'], default='aux_loss', help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".') + group.add_argument('--moe-router-score-function-type', type=str, + choices=['softmax', 'sigmoid'], + default='softmax', + help='Determines the score function type for the router, currently support two load balancing type: "aux_loss" and "seq_aux_loss".') group.add_argument('--moe-router-topk', type=int, default=2, help='Number of experts to route to for each token. 
The default is 2.') group.add_argument('--moe-router-pre-softmax', action='store_true', @@ -2356,3 +2365,13 @@ def _add_auto_tuner_args(parser): help='use auto tuner') return parser + + +def _add_auto_skip_spiky_loss(parser): + group = parser.add_argument_group(title='auto skip spiky loss') + + group.add_argument('--auto-skip-spiky-loss', action='store_true', + help='Automatically skip spiky loss iterations.') + group.add_argument('--spiky-loss-threshold', type=float, default=0.2, + help='Threshold for skipping spiky loss iterations.') + return parser diff --git a/megatron/megatron/training/initialize.py b/megatron/megatron/training/initialize.py index f8bc22df1..784dc713a 100644 --- a/megatron/megatron/training/initialize.py +++ b/megatron/megatron/training/initialize.py @@ -28,7 +28,7 @@ from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version from flagscale.train import FSTrainArguments -from flagscale.train import set_parallel_context +from flagscale.train import set_parallel_context, set_get_spiky_loss_detector logger = logging.getLogger(__name__) @@ -106,6 +106,9 @@ def state_restore_func(state_dict): error_injection_type=RerunDiagnostic(args.error_injection_type), ), ) + + if args.auto_skip_spiky_loss: + set_get_spiky_loss_detector(args=args) # torch.distributed initialization def finish_mpu_init(): diff --git a/megatron/megatron/training/theoretical_memory_usage.py b/megatron/megatron/training/theoretical_memory_usage.py index f9b75031a..bfbef7b5b 100644 --- a/megatron/megatron/training/theoretical_memory_usage.py +++ b/megatron/megatron/training/theoretical_memory_usage.py @@ -8,6 +8,56 @@ NUM_BYTES_IN_MEGABYTE = 1024 * 1024 +def compute_activated_weight_number(args, verbose=False): + if args.num_experts is None: + return + # Attention projection size. + query_projection_size = args.kv_channels * args.num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size + # Group Query Attention. + if not args.group_query_attention: + args.num_query_groups = args.num_attention_heads + # MoE. + # NOTE(zhaoyingli): We only compute the number of activated parameters by topk routing. + num_experts = args.moe_router_topk + gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + num_parameters_in_transformer_layers = ( + 2 + * args.num_layers + * args.hidden_size + * args.hidden_size + * ( + # Attention. + ( + (1 + (args.num_query_groups / args.num_attention_heads)) + * query_projection_to_hidden_size_ratio + ) + # MLP. + + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier) + # Transformer layernorms. + + (2 / args.hidden_size) + # Final layernorm. 
+ + (1 / (args.num_layers * args.hidden_size)) + ) + ) + embedding_size = args.hidden_size * args.padded_vocab_size + if args.untie_embeddings_and_output_weights: + num_parameters_in_embedding_layers = 2 * embedding_size + else: + num_parameters_in_embedding_layers = embedding_size + num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers + if verbose: + print( + f"Number of activated parameters in transformer layers in billions: " + f"{num_parameters_in_transformer_layers / 10**9: .2f}" + ) + print( + f"Number of activated parameters in embedding layers in billions: " + f"{num_parameters_in_embedding_layers / 10**9:.2f}" + ) + print(f"Total number of activated parameters in billions: {num_total_parameters / 10**9:.2f}") + + def compute_weight_and_optimizer_memory(args, verbose=False): # Attention projection size. query_projection_size = args.kv_channels * args.num_attention_heads @@ -164,6 +214,8 @@ def compute_activation_memory(args, num_microbatches, verbose=False): def report_theoretical_memory(args, num_microbatches=None, verbose=False): + compute_activated_weight_number(args, verbose=verbose) + weight_and_optimizer_memory = ( compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE ) diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml index d4cd13df8..4db8f58b0 100644 --- a/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml +++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/tp2pp1_tp4pp1_tp2pp1.yaml @@ -71,6 +71,8 @@ model: micro_batch_size: 4 global_batch_size: 1024 seed: 42 + auto_skip_spiky_loss: true + spiky_loss_threshold: 0.25 optimizer: weight_decay: 0.1 diff --git a/tests/unit_tests/test_sigmoid_aux_loss.py b/tests/unit_tests/test_sigmoid_aux_loss.py new file mode 100644 index 000000000..c3111be6c --- /dev/null +++ b/tests/unit_tests/test_sigmoid_aux_loss.py @@ -0,0 +1,253 @@ +import pytest +import torch + +from megatron.core import parallel_state +from megatron.training.initialize import _set_random_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker +from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer + + +class AuxlossTestContainer(MoEModelTestContainer): + def __init__( + self, + tp_size, + ep_size, + pp_size, + cp_size=1, + moe_tp_size=None, + data_parallel_random_init=False, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_expert_capacity_factor=None, + moe_pad_expert_input_to_capacity=False, + moe_aux_loss_coeff=0.1, + **kwargs, + ): + self.num_local_experts = num_moe_experts // ep_size + if moe_tp_size is None: + moe_tp_size = tp_size + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + ) + _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) + local_expert_indices_offset = ( + parallel_state.get_expert_model_parallel_rank() * self.num_local_experts + ) + self.local_expert_indices = [ + local_expert_indices_offset + i for i in 
range(self.num_local_experts) + ] + self.config = TransformerConfig( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + pipeline_model_parallel_size=pp_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + moe_router_topk=moe_router_topk, + num_moe_experts=num_moe_experts, + moe_router_load_balancing_type=moe_router_load_balancing_type, + moe_token_dispatcher_type=moe_token_dispatcher_type, + moe_expert_capacity_factor=moe_expert_capacity_factor, + moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, + moe_aux_loss_coeff=moe_aux_loss_coeff, + num_layers=1, + moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), + hidden_size=kwargs.get("hidden_size", 16), + num_attention_heads=kwargs.get("num_attention_heads", 8), + use_cpu_initialization=kwargs.get("use_cpu_initialization", True), + sequence_parallel=tp_size > 1, + add_bias_linear=kwargs.get("add_bias_linear", False), + moe_router_score_function_type=kwargs.get("moe_router_score_function_type", "softmax"), + ) + + # init moe layer + self.moe_layer = self.new_moe_layer() + + def partition_input(self, input): + partitioned_input = input.chunk( + parallel_state.get_tensor_and_context_parallel_world_size(), dim=1 + )[parallel_state.get_tensor_and_context_parallel_rank()] + output = partitioned_input.clone().detach() + output.requires_grad = True + return output + + @pytest.mark.internal + def aux_loss_test(self, input, baseline_grad): + partitioned_input = self.partition_input(input) + moe_layer = self.moe_layer + probs, indices = moe_layer.router(partitioned_input) + probs.sum().mul_(0).backward() + aux_loss_grad = partitioned_input.grad + torch.distributed.barrier() + ans = self.partition_input(baseline_grad) + assert torch.allclose(aux_loss_grad, ans), f"Diff: {(aux_loss_grad/ans).mean()}" + loss = parallel_state.get_moe_layer_wise_logging_tracker()['load_balancing_loss'] + clear_aux_losses_tracker() + + +class TestSigmoidAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) + def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="allgather", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + 
@pytest.mark.internal + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + +class TestSigmoidSeqAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize("tp_size,ep_size,cp_size", [(1, 8, 1)]) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="sigmoid", + ) + container.aux_loss_test(self.input, self.baseline_grad) + + +class TestSoftmaxSeqAuxLoss: + def setup_method(self, method): + baseline_container = AuxlossTestContainer( + tp_size=1, + ep_size=1, + pp_size=1, + cp_size=1, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="softmax", + ) + moe_layer = baseline_container.moe_layer + self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() + self.input.requires_grad = True + probs, indices = moe_layer.router(self.input) + probs.sum().mul_(0).backward() # zero out the main gradients + self.baseline_grad = self.input.grad + self.input.grad = None + clear_aux_losses_tracker() + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + @pytest.mark.parametrize("tp_size,ep_size,cp_size", [(1, 8, 1)]) + def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): + container = AuxlossTestContainer( + tp_size=tp_size, + ep_size=ep_size, + pp_size=1, + cp_size=cp_size, + num_moe_experts=8, + moe_router_topk=2, + moe_router_load_balancing_type="seq_aux_loss", + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.1, + moe_router_score_function_type="softmax", + ) + container.aux_loss_test(self.input, self.baseline_grad) \ No newline at end of file diff --git a/tests/unit_tests/test_spiky_loss_detector.py b/tests/unit_tests/test_spiky_loss_detector.py new file 
mode 100644 index 000000000..3bc776f8f --- /dev/null +++ b/tests/unit_tests/test_spiky_loss_detector.py @@ -0,0 +1,35 @@ +import torch + +from flagscale.train.spiky_loss import SpikyLossDetector +from tests.unit_tests.test_utilities import Utils + +def test_spiky_loss_detector(pp_size=2, threshold=0.2): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=1, + context_parallel_size=1, + expert_tensor_parallel_size=1, + ) + + detector = SpikyLossDetector(threshold=threshold, loss=10.0) + + # test case 1: loss is not spiky + losses = [{"lm loss": 10.23}, {"lm loss": 10.32}, {"lm loss": 10.30}] + reduced_loss = detector.reduce_losses(losses) + is_spiky_loss = detector.is_spkiy_loss(reduced_loss) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + assert is_spiky_loss == 0, f"Expected 0, got {is_spiky_loss}" + + # test case 2: loss is spiky + losses = [{"lm loss": 14.23}, {"lm loss": 14.32}, {"lm loss": 14.30}] + reduced_loss = detector.reduce_losses(losses) + is_spiky_loss = detector.is_spkiy_loss(reduced_loss) + is_spiky_loss_tensor = torch.tensor(is_spiky_loss, dtype=torch.int, device="cuda") + torch.distributed.all_reduce(is_spiky_loss_tensor, op=torch.distributed.ReduceOp.MAX) + is_spiky_loss = is_spiky_loss_tensor.item() + assert is_spiky_loss == 1, f"Expected 1, got {is_spiky_loss}" + + Utils.destroy_model_parallel() \ No newline at end of file
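
Note on the spiky-loss skip introduced in flagscale/train/spiky_loss.py and wired into train_step: the detector keeps the last accepted "lm loss" and flags an iteration whenever the relative increase over that value reaches the threshold (default 0.2, i.e. a 20% jump, configurable via --spiky-loss-threshold and enabled with --auto-skip-spiky-loss); the per-rank flag is then all-reduced with MAX so every rank skips the same iteration. Below is a minimal, self-contained sketch of that check for illustration only; the class name RelativeSpikeDetector and the example numbers are invented here and are not part of the patch.

import math

class RelativeSpikeDetector:
    """Illustrative re-implementation of the patch's relative-increase spike check."""

    def __init__(self, threshold=0.2, last_loss=None):
        self.threshold = threshold
        self.last_loss = last_loss

    def update(self, loss):
        """Return True if `loss` jumped by at least `threshold` relative to the last accepted loss."""
        if loss is None:
            return False
        if self.last_loss is None or math.isnan(loss) or math.isnan(self.last_loss):
            self.last_loss = loss
            return False
        if math.isinf(loss) or math.isinf(self.last_loss):
            return True
        if (loss - self.last_loss) / self.last_loss >= self.threshold:
            # Spike: last_loss is left unchanged, so a plateau at the higher level keeps triggering.
            return True
        self.last_loss = loss
        return False

detector = RelativeSpikeDetector(threshold=0.2, last_loss=10.0)
assert detector.update(10.3) is False   # +3%: accepted, becomes the new reference
assert detector.update(14.3) is True    # +39% over 10.3: flagged as a spike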
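
Note on --skip-samples-range / --skip-iters-range in flagscale/train/train.py: a sample range is first rounded up to a whole number of global batches (ceiling division by the current global batch size), and the training loop then fast-forwards the data iterator and consumed_train_samples across the resulting iterations without calling train_step. A worked example of the rounding, using made-up numbers that are not from the patch:

# Hypothetical numbers, only to illustrate the ceiling-division rounding in the patch.
consumed_train_samples = 10_000          # samples seen before the skip window
global_batch_size = 1_024
skip_samples_range = [10_500, 14_000]    # --skip-samples-range 10500 14000

# The window is entered because the next global batch crosses its start.
assert consumed_train_samples + global_batch_size > skip_samples_range[0]
assert consumed_train_samples < skip_samples_range[1]

# Round the end of the window up to a whole number of global batches.
num_skipped_iters = (skip_samples_range[1] - consumed_train_samples
                     + global_batch_size - 1) // global_batch_size
skipped_until = consumed_train_samples + num_skipped_iters * global_batch_size

print(num_skipped_iters)   # 4 iterations are fast-forwarded without a train step
print(skipped_until)       # 14096 samples consumed when training resumes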
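
Note on the MoE router change in moe_utils.py and router.py: with --moe-router-score-function-type sigmoid, expert scores are element-wise sigmoids rather than a softmax, so the selected top-k scores no longer sum to 1 and the patch renormalizes them by their sum before they are used as combine weights (and likewise before the aux-loss computation). A toy sketch of the intended math follows; the function route_topk and its tensor shapes are invented here for illustration and do not mirror the exact control flow of topk_softmax_with_capacity.

import torch

def route_topk(logits: torch.Tensor, topk: int, score_fn: str = "softmax"):
    """Toy top-k routing showing the softmax vs. sigmoid scoring paths."""
    if score_fn == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
        probs, top_indices = torch.topk(scores, k=topk, dim=-1)
    elif score_fn == "sigmoid":
        scores = logits.sigmoid()
        probs, top_indices = torch.topk(scores, k=topk, dim=-1)
        # Sigmoid scores are independent per expert, so renormalize the selected
        # k scores to sum to 1 before using them as combine weights.
        probs = probs / probs.sum(dim=-1, keepdim=True)
    else:
        raise ValueError(f"unknown score function {score_fn}")
    return probs, top_indices

logits = torch.randn(4, 8)                       # 4 tokens, 8 experts
p, idx = route_topk(logits, topk=2, score_fn="sigmoid")
assert torch.allclose(p.sum(dim=-1), torch.ones(4))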
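
Note on compute_activated_weight_number in theoretical_memory_usage.py: it reuses the dense parameter-count formula but substitutes moe_router_topk for the expert count, so it reports the parameters a token actually activates rather than the full expert pool. A rough back-of-the-envelope comparison with hypothetical sizes (none of these numbers come from the patch):

# Hypothetical MoE shapes, only to show the activated-vs-total distinction.
hidden_size = 4096
ffn_hidden_size = 14336
num_layers = 32
num_experts = 64
topk = 2
gated_linear_multiplier = 3 / 2          # SwiGLU: 3 projection matrices instead of 2

def expert_mlp_params(n_experts):
    # Per layer: up/down (and gate) projections for n_experts experts.
    return 2 * hidden_size * ffn_hidden_size * gated_linear_multiplier * n_experts

total_mlp = num_layers * expert_mlp_params(num_experts)
activated_mlp = num_layers * expert_mlp_params(topk)

print(f"total expert MLP params:     {total_mlp / 1e9:.1f} B")    # ~360.8 B
print(f"activated expert MLP params: {activated_mlp / 1e9:.1f} B")  # ~11.3 B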