From 9b3e21647048ab6c20c589a7b0f5f80b3eca4984 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 13:57:57 -0500 Subject: [PATCH 01/22] remove pre_initialize_structure Signed-off-by: Kyle Sayers --- src/llmcompressor/__init__.py | 1 - src/llmcompressor/core/__init__.py | 2 - src/llmcompressor/core/lifecycle.py | 53 +------------------ src/llmcompressor/core/session.py | 39 -------------- src/llmcompressor/core/session_functions.py | 11 ---- src/llmcompressor/modifiers/interface.py | 19 ------- src/llmcompressor/modifiers/modifier.py | 31 ----------- .../modifiers/quantization/gptq/base.py | 18 +++---- src/llmcompressor/modifiers/stage.py | 24 +-------- .../pytorch/model_load/helpers.py | 3 +- .../transformers/finetune/runner.py | 6 --- .../transformers/finetune/session_mixin.py | 30 ++--------- .../transformers/finetune/text_generation.py | 7 +-- .../pruning/sparsegpt/test_pytorch.py | 8 +-- .../finetune/test_session_mixin.py | 1 - 15 files changed, 18 insertions(+), 235 deletions(-) diff --git a/src/llmcompressor/__init__.py b/src/llmcompressor/__init__.py index 264d434f0..044d449c9 100644 --- a/src/llmcompressor/__init__.py +++ b/src/llmcompressor/__init__.py @@ -41,6 +41,5 @@ create_session, finalize, initialize, - pre_initialize_structure, reset_session, ) diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py index 171c95395..ed4134af7 100644 --- a/src/llmcompressor/core/__init__.py +++ b/src/llmcompressor/core/__init__.py @@ -16,7 +16,6 @@ create_session, finalize, initialize, - pre_initialize_structure, reset_session, ) from llmcompressor.core.state import Data, Hardware, ModifiedState, State @@ -37,7 +36,6 @@ "create_session", "active_session", "reset_session", - "pre_initialize_structure", "initialize", "finalize", "apply", diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index 232d76b83..f41ab15df 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -38,12 +38,11 @@ class CompressionLifecycle: :type event_lifecycle: Optional[EventLifecycle] """ - state: Optional[State] = None + state: Optional[State] = field(default_factory=State) recipe_container: RecipeContainer = field(default_factory=RecipeContainer) modifiers: List[StageModifiers] = field(default_factory=list) event_lifecycle: Optional[EventLifecycle] = None - initialized_structure: bool = False initialized_: bool = False finalized: bool = False event_called: bool = False @@ -64,48 +63,9 @@ def reset(self): except Exception as e: logger.warning(f"Exception during finalizing modifier: {e}") - self.state = None - self.recipe_container = RecipeContainer() - self.modifiers = [] - self.event_lifecycle = None - - self.initialized_structure = False - self.initialized_ = False - self.finalized = False - self.event_called = False + self.__init__() logger.info("Compression lifecycle reset") - def pre_initialize_structure(self, **kwargs) -> List[Any]: - """ - Pre-initialize the structure of the compression lifecycle. 
- - :param kwargs: Additional arguments to update the state with - :return: List of data returned from pre-initialization of modifiers - :rtype: List[Any] - """ - logger.debug("Pre-initializing structure") - self._check_create_state() - extras = self.state.update(**kwargs) - extras = self.recipe_container.update(**extras) - - self._check_compile_recipe() - mod_data = [] - for mod in self.modifiers: - data = mod.pre_initialize_structure(state=self.state, **extras) - logger.debug("Pre-initialized modifier: {}", mod) - if data is not None: - mod_data.append(data) - - self.initialized_structure = True - applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied] - self.recipe_container.update_applied_stages(applied_stage_names) - logger.info( - "Compression lifecycle structure pre-initialized for {} modifiers", - len(self.modifiers), - ) - - return mod_data - def initialize(self, **kwargs) -> List[Any]: """ Initialize the compression lifecycle. @@ -115,7 +75,6 @@ def initialize(self, **kwargs) -> List[Any]: :rtype: List[Any] """ logger.debug("Initializing compression lifecycle") - self._check_create_state() extras = self.state.update(**kwargs) extras = self.recipe_container.update(**extras) @@ -229,14 +188,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: return mod_data - def _check_create_state(self): - if self.state is not None: - return - - logger.debug("Creating new State instance for compression lifecycle") - self.state = State() - logger.info("State created for compression lifecycle") - def _check_compile_recipe(self): if not self.recipe_container.check_compile_recipe(): return diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index 7c489f36f..07eb2dc57 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -65,45 +65,6 @@ def state(self) -> State: """ return self._lifecycle.state - def pre_initialize_structure( - self, - model: Any, - recipe: Union[str, List[str], Recipe, List[Recipe], None] = None, - recipe_stage: Union[str, List[str], None] = None, - recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None, - **kwargs, - ) -> ModifiedState: - """ - A method to pre-initialize the structure of the model for compression. - This will run the pre-initialize structure method for each modifier in the - session's lifecycle. This will also set the session's state to the - pre-initialized state. Takes care of cases when the model(s) structure - has been previously modified by a modifier. - - :param model: the model to pre-initialize the structure for - :param recipe: the recipe to use for the compression, can be a path to a - recipe file, a raw recipe string, a recipe object, or a list - of recipe objects. 
- :param recipe_stage: the stage to use for the compression - :param recipe_args: the args to use for overriding the recipe defaults - :return: A ModifiedState instance holding the modified model and modifier_data - after pre-initializing the structure - """ - mod_data = self._lifecycle.pre_initialize_structure( - model=model, - recipe=recipe, - recipe_stage=recipe_stage, - recipe_args=recipe_args, - **kwargs, - ) - - return ModifiedState( - model=self.state.model, - optimizer=None, - loss=None, - modifier_data=mod_data, - ) - def initialize( self, recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None, diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index 9a123a030..0dc522c7c 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -11,7 +11,6 @@ "create_session", "active_session", "reset_session", - "pre_initialize_structure", "initialize", "finalize", "apply", @@ -60,16 +59,6 @@ def reset_session(): session._lifecycle.reset() -def pre_initialize_structure(**kwargs): - """ - A method to pre-initialize the structure of the model for the active session - - :param kwargs: the kwargs to pass to the active session's pre-initialize-structure - method - """ - active_session().pre_initialize_structure(**kwargs) - - def initialize( recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None, recipe_stage: Union[str, List[str], None] = None, diff --git a/src/llmcompressor/modifiers/interface.py b/src/llmcompressor/modifiers/interface.py index e3a3786b4..f1c73c54b 100644 --- a/src/llmcompressor/modifiers/interface.py +++ b/src/llmcompressor/modifiers/interface.py @@ -11,15 +11,6 @@ class ModifierInterface(ABC): Defines the contract that all modifiers must implement """ - @property - @abstractmethod - def initialized_structure(self) -> bool: - """ - :return: True if the modifier structure has been - applied to the model - """ - raise NotImplementedError() - @property @abstractmethod def initialized(self) -> bool: @@ -58,16 +49,6 @@ def calculate_end(self) -> float: """ raise NotImplementedError() - @abstractmethod - def pre_initialize_structure(self, state: State, **kwargs): - """ - Apply the modifier structure to the model - - :param state: The current state of the model - :param kwargs: Additional arguments for the modifier - """ - raise NotImplementedError() - @abstractmethod def initialize(self, state: State, **kwargs): """ diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py index 65b4a4029..4092cc3de 100644 --- a/src/llmcompressor/modifiers/modifier.py +++ b/src/llmcompressor/modifiers/modifier.py @@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin): end: Optional[float] = None update: Optional[float] = None - initialized_structure_: bool = False initialized_: bool = False finalized_: bool = False started_: bool = False ended_: bool = False - @property - def initialized_structure(self) -> bool: - """ - :return: True if the modifier structure has been - applied to the model - """ - return self.initialized_structure_ - @property def initialized(self) -> bool: """ @@ -78,15 +69,6 @@ def calculate_end(self) -> float: """ return self.end if self.end is not None else -1 - def pre_initialize_structure(self, state: State, **kwargs): - """ - :param state: The current state of the model - :param kwargs: Additional arguments for initializing the structure - of the model in question - """ - 
self.on_initialize_structure(state, **kwargs) - self.initialized_structure_ = True - def initialize(self, state: State, **kwargs): """ Initialize the modifier for the given model and state. @@ -221,19 +203,6 @@ def should_end(self, event: Event): return self.end is not None and current >= self.end - def on_initialize_structure(self, state: State, **kwargs): - """ - on_initialize_structure is called before the model is initialized - with the modifier structure. - - TODO: Depreciate this function as part of the lifecycle - - :param state: The current state of the model - :param kwargs: Additional arguments for initializing the structure - of the model in question - """ - pass - @abstractmethod def on_initialize(self, state: State, **kwargs) -> bool: """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 65e1c90e0..b68be2c59 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -62,9 +62,8 @@ class GPTQModifier(Modifier, HooksMixin): | actorder: False Lifecycle: - - on_initialize_structure - - _build_quant_modifier - on_initialize + - _build_quant_modifier - register_hook(module, compress_module, "forward") - run_sequential / run_layer_sequential / run_basic - make_empty_hessian @@ -141,16 +140,16 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def on_initialize_structure(self, state: State, **kwargs): + def _maybe_build_quant_modifier(self, model: torch.nn.Module): """ Check the model's quantization state matches that expected by this modifier, adding a default quantization scheme if needed - TODO: Depreciate and fold into `on_initialize` + # TODO: build modifier during recipe validation :param state: session state storing input model and calibration data """ - quantization_already_active = qat_active(state.model) + quantization_already_active = qat_active(model) if isinstance(self.quantize, bool): if not self.quantize and quantization_already_active: logger.warning( @@ -191,18 +190,15 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self._quantization_modifier: - self._quantization_modifier.on_initialize_structure(state, **kwargs) - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize and run the GPTQ algorithm on the current state :param state: session state storing input model and calibration data """ - # initialize quantization modifier - if not self.initialized_structure_: - self.on_initialize_structure(state, **kwargs) + # build quantization modifier + self._maybe_build_quant_modifier(state.model) + if self._quantization_modifier: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: diff --git a/src/llmcompressor/modifiers/stage.py b/src/llmcompressor/modifiers/stage.py index 7e63245b6..fe773bcb5 100644 --- a/src/llmcompressor/modifiers/stage.py +++ b/src/llmcompressor/modifiers/stage.py @@ -19,7 +19,7 @@ class StageModifiers(ModifierInterface, BaseModel): :param index: The index of the stage, if applicable :param group: The group name of the stage, if applicable :param applied: Flag for indicating if this stage has has already been - applied to the model, through structure initialization or finalization + applied to the model through finalization """ modifiers: List["Modifier"] = Field(default_factory=list) @@ -27,14 +27,6 @@ class StageModifiers(ModifierInterface, 
BaseModel): group: Optional[str] = None applied: bool = False - @property - def initialized_structure(self) -> bool: - """ - :return: True if any of the stage modifiers have initialized structure, - False otherwise - """ - return any(mod.initialized_structure for mod in self.modifiers) - @property def initialized(self) -> bool: """ @@ -93,20 +85,6 @@ def calculate_end(self) -> float: """ return max(mod.calculate_end() for mod in self.modifiers) - def pre_initialize_structure(self, state: "State", **kwargs): - """ - Pre initialize the structure for all stage modifiers mark the stage applied - - :param state: The current state of the training - :param kwargs: Additional kwargs to pass to the modifier(s) - pre_initialize_structure method - """ - for modifier in self.modifiers: - modifier.pre_initialize_structure(state, **kwargs) - - self.applied = True - state.loggers.system.info(tag="stage", string="Model structure initialized") - def initialize(self, state: "State", **kwargs): """ Initialize all the stage modifiers diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 5ddc7ebd5..ce6b186f0 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -7,7 +7,7 @@ from safetensors import safe_open from torch.nn import Module -from llmcompressor.core import active_session, create_session, pre_initialize_structure +from llmcompressor.core import active_session, create_session from llmcompressor.typing import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" @@ -33,7 +33,6 @@ def initialize_recipe(model: Module, recipe_path: str): """ if not active_session(): create_session() - pre_initialize_structure(model=model, recipe=recipe_path) # no need to reload if no recipe was applied if recipe_path is None: diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 769b84248..1a8b01a3d 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -238,12 +238,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): "the stage name." 
) - # just load structure if stage has already applied - if stage_name in completed_stages: - self.trainer.initialize_structure(stage=stage) - self.trainer.accelerator.wait_for_everyone() - continue - # setup checkpoint dir, TODO: this should be optional self._output_dir = os.path.join( self.parent_output_dir, "stage_" + stage_name diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index a44b08b4f..9b812ff27 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -18,7 +18,6 @@ create_session, finalize, initialize, - pre_initialize_structure, ) from llmcompressor.metrics import LoggerManager from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( @@ -167,26 +166,6 @@ def initialize_session( torch.cuda.empty_cache() - def initialize_structure(self, stage: Optional[str] = None): - """ - Initialize any recipe structural changes such as quantization on the model, - return immediately if session has already been initialized - - :param stage: Optional stage of recipe to run, or None to run all stages - """ - session = active_session() - if session.lifecycle.initialized_: - return False - - pre_initialize_structure( - model=self.model, - recipe=self.recipe, - recipe_stage=stage, - recipe_args=self.recipe_args, - ) - logger.info(f"Initialized LLM Compressor structure from recipe {self.recipe}") - torch.cuda.empty_cache() - def finalize_session(self): """ Wrap up training by finalizing all modifiers initialized in the current session @@ -399,14 +378,12 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): def evaluate(self, *args, **kwargs): """ Run a sparsification evaluation cycle. - Runs initialize_structure for the sparse session before calling - super().evaluate() and finalization of the session after. :param args: positional args to pass to super().evaluate() :param kwargs: keyword args to pass to super().evaluate() :return: the output from super.evaluate() """ - self.initialize_structure() + # TODO remove output = super().evaluate(*args, **kwargs) self.finalize_session() @@ -416,14 +393,13 @@ def evaluate(self, *args, **kwargs): def predict(self, *args, **kwargs): """ Run a sparsification prediction cycle. - Runs initialize_structure for the sparse session before calling - super().predict() and finalization of the session after. 
:param args: positional args to pass to super().predict() :param kwargs: keyword args to pass to super().predict() :return: the output from super.predict() """ - self.initialize_structure() + # TODO remove + output = super().predict(*args, **kwargs) self.finalize_session() diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index d79d8cbbe..9f4f81685 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -39,10 +39,9 @@ RecipeArguments, TrainingArguments, ) -from llmcompressor.core import pre_initialize_structure, reset_session +from llmcompressor.core import reset_session from llmcompressor.pytorch.model_load.helpers import ( fallback_to_cpu, - get_session_model, initialize_recipe, parse_dtype, ) @@ -384,8 +383,6 @@ def main( if isinstance(processor, str) or processor is None: processor = initialize_processor_from_path(model_args, model, teacher) - pre_initialize_structure(model=model) - # initialize session manager initialize_recipe(model, None) @@ -403,7 +400,7 @@ def main( calib_dataset = stage_runner.get_dataset_split("calibration") trainer = Trainer( - model_init=get_session_model, + model_init=lambda: model, teacher=teacher, recipe=recipe_args.recipe, recipe_args=recipe_args.recipe_args, diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index c1f0cb425..6d50bfff6 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -71,7 +71,7 @@ def test_create_default_quant_modifier(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) + modifier._maybe_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize assert isinstance(modifier._quantization_modifier, QuantizationModifier) modifier._quantization_modifier.create_init_config() @@ -105,10 +105,6 @@ def test_set_quant_if_modifer_already_exists(self): modifier = GPTQModifier(block_size=128) assert not modifier._quantization_modifier - - modifier.on_initialize_structure(testing_harness.get_state()) - # since quantization modifier is already applied, quantization must be set in - # GPTQ assert modifier.quantize @@ -142,7 +138,7 @@ def test_set_quant_in_gptq(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) + modifier._maybe_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py index 93bd74cd1..f10c0ed51 100644 --- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py +++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py @@ -64,5 +64,4 @@ def test_mixin_session_init(mixin_trainer): mixin_trainer.initialize_session(epoch=0.0, checkpoint=None) session = active_session() - assert not session.lifecycle.initialized_structure assert session.lifecycle.initialized_ From 3bff7d1b07a460ba34a27c1b8631c7df5e785317 Mon Sep 17 00:00:00 2001 
From: Kyle Sayers Date: Mon, 17 Feb 2025 15:17:46 -0500 Subject: [PATCH 02/22] remove preinit event Signed-off-by: Kyle Sayers --- src/llmcompressor/core/events/event.py | 31 -------------------------- 1 file changed, 31 deletions(-) diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index 9d5d48d63..89eb780c8 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -27,7 +27,6 @@ class EventType(Enum): The purpose of each EventType is to trigger the corresponding modifier callback during training or post training pipelines. - :param PRE_INIT: Event type for pre-initialization. :param INITIALIZE: Event type for initialization. :param FINALIZE: Event type for finalization. :param BATCH_START: Event type for the start of a batch. @@ -38,7 +37,6 @@ class EventType(Enum): """ # training lifecycle - PRE_INIT = "pre_init" INITIALIZE = "initialize" FINALIZE = "finalize" @@ -51,35 +49,6 @@ class EventType(Enum): OPTIM_PRE_STEP = "optim_pre_step" OPTIM_POST_STEP = "optim_post_step" - def order(self) -> int: - """ - Returns the priority order of the current EventType. - Lower values have higher priority. - - :raises ValueError: if the event type is invalid. - :return: The order of the event type, lower has higher priority. - :rtype: int - """ - if self == EventType.PRE_INIT: - return 0 - elif self == EventType.INITIALIZE: - return 10 - elif self == EventType.FINALIZE: - return 20 - elif self == EventType.BATCH_START: - return 100 - elif self == EventType.LOSS_CALCULATED: - return 110 - elif self == EventType.OPTIM_PRE_STEP: - return 120 - elif self == EventType.OPTIM_POST_STEP: - return 130 - elif self == EventType.BATCH_END: - return 140 - else: - logger.error("Invalid event type: {}", self) - raise ValueError(f"Invalid event type {self}") - @dataclass class Event: From 14e47a52471ae42ea245c70be1c5ca75528a6bd2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 15:35:43 -0500 Subject: [PATCH 03/22] remove order test Signed-off-by: Kyle Sayers --- tests/unit/core/events/test_event.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/unit/core/events/test_event.py b/tests/unit/core/events/test_event.py index 06f78ec53..de18dcb28 100644 --- a/tests/unit/core/events/test_event.py +++ b/tests/unit/core/events/test_event.py @@ -3,13 +3,6 @@ from llmcompressor.core import Event, EventType -@pytest.mark.smoke -def test_event_type_order(): - assert EventType.PRE_INIT.order() == 0 - assert EventType.INITIALIZE.order() == 10 - assert EventType.FINALIZE.order() == 20 - - @pytest.mark.smoke def test_event_epoch_based(): event = Event(steps_per_epoch=10) From 6b882bb0f8ca1345b092e55b25598e9d38d72d7e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 19:40:15 -0500 Subject: [PATCH 04/22] consolodate saving Signed-off-by: Kyle Sayers --- .../pytorch/model_load/helpers.py | 68 ++---------- .../transformers/finetune/runner.py | 9 +- .../transformers/finetune/session_mixin.py | 44 ++------ .../transformers/finetune/text_generation.py | 18 ++-- .../compressed_tensors_utils.py | 44 ++++---- src/llmcompressor/utils/fsdp/helpers.py | 101 +----------------- .../obcq/test_consecutive_runs.py | 7 +- 7 files changed, 55 insertions(+), 236 deletions(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index ce6b186f0..19bdaca62 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py 
@@ -6,15 +6,14 @@ from loguru import logger from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel -from llmcompressor.core import active_session, create_session +from llmcompressor.core import active_session from llmcompressor.typing import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" __all__ = [ - "initialize_recipe", - "save_model_and_recipe", "copy_python_files_from_model_cache", "fallback_to_cpu", "parse_dtype", @@ -24,67 +23,22 @@ ] -def initialize_recipe(model: Module, recipe_path: str): - """ - Initializes a recipe that has been previously applied to the model - - :param model: PyTorch model to apply structure to - :param recipe_path: path to recipe to apply to the model - """ - if not active_session(): - create_session() - - # no need to reload if no recipe was applied - if recipe_path is None: - return - - session = active_session() - num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages) - msg = ( - "an unstaged recipe" - if num_stages == 1 - else f"a staged recipe with {num_stages} stages" - ) - logger.info(f"Applied {msg} to the model") - - -def save_model_and_recipe( - model: Module, +def save_checkpoint( save_path: str, - processor: Optional[Processor] = None, - save_safetensors: bool = False, - save_compressed: bool = False, + model: PreTrainedModel, + processor: Processor, + save_safetensors: bool = True, + save_compressed: bool = True, ): - """ - Save a model, processor and the currently loaded recipe to file - - :param model: pytorch model to save - :param save_path: path to save output to - :param processor: model processor or tokenizer to save - :param save_safetensors: whether to save as safetensors or pickle (bin) - :param save_compressed: whether to compress sparse weights on disk - """ - # avoid circular import - from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME - + # saving the model also saves the recipe model.save_pretrained( - save_path, save_compressed=save_compressed, safe_serialization=save_safetensors + save_path, + save_safetensors=save_safetensors, + save_compressed=save_compressed, ) - if processor is not None: processor.save_pretrained(save_path) - logger.info("Saving output to {}".format(os.path.abspath(save_path))) - - recipe_path = os.path.join(save_path, RECIPE_FILE_NAME) - session = active_session() - recipe_yaml_str = session.get_serialized_recipe() - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) - - # copy python files from cache dir to save_path if any - copy_python_files_from_model_cache(model, save_path) - def fallback_to_cpu(device: str) -> str: """ diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 1a8b01a3d..bf51cc4bb 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -17,6 +17,7 @@ from llmcompressor.pytorch.model_load.helpers import ( get_completed_stages, get_session_model, + save_checkpoint, save_completed_stages, ) from llmcompressor.pytorch.utils import tensors_to_device @@ -27,7 +28,7 @@ make_dataset_splits, ) from llmcompressor.typing import Processor -from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe +from llmcompressor.utils.fsdp.helpers import is_fsdp_model class StageRunner: @@ -258,14 +259,16 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): if ( self._training_args.output_dir != 
TrainingArguments.__dataclass_fields__["output_dir"].default + and self.trainer.accelerator.is_main_process ): - save_model_and_recipe( - model=self.trainer.model, + save_checkpoint( save_path=self._output_dir, + model=self.trainer.model, processor=self.processor, save_safetensors=self._training_args.save_safetensors, save_compressed=self._model_args.save_compressed, ) + self.trainer.accelerator.is_main_process.wait_for_everyone() # save stage to checkpoint dir if self.trainer.accelerator.is_main_process: diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 9b812ff27..97c77c720 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -23,15 +23,13 @@ from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( KDModelWrapper, ) -from llmcompressor.pytorch.model_load.helpers import get_session_model +from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint from llmcompressor.pytorch.utils import ModuleSparsificationInfo -from llmcompressor.transformers import RECIPE_FILE_NAME from llmcompressor.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, TrainingLoopCallbacks, ) from llmcompressor.utils.fsdp.context import summon_full_params_context -from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp from llmcompressor.utils.pytorch import qat_active if TYPE_CHECKING: @@ -445,44 +443,18 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): # knowledge distillation requires making wrappers transparent during if isinstance(self.model, KDModelWrapper): - self.model.prepare_for_save() + self.model.prepare_for_save() # TODO: move to finalize - if not is_fsdp_model(self.model): - self.model.save_pretrained( + # save checkpoint + if self.accelerator.is_main_process: + processor = getattr(self, "processing_class", self.tokenizer) + save_checkpoint( output_dir, - save_compressed=self.model_args.save_compressed, - safe_serialization=self.args.save_safetensors, - ) - else: # FSDP model - save_pretrained_fsdp( model=self.model, - accelerator=self.accelerator, - output_dir=output_dir, + processor=processor, + safe_serialization=self.args.save_safetensors, save_compressed=self.model_args.save_compressed, - save_safetensors=self.metadata.get("save_safetensors", False), ) - - self.save_state() - processor = getattr(self, "processing_class", self.tokenizer) - if processor is not None: - processor.save_pretrained(output_dir) - - if not self.recipe: - return - - if self.accelerator.is_main_process: - # save recipe, will contain modifiers from the model's original recipe as - # well as those added from self.recipe - recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME) - session = active_session() - recipe_yaml_str = session.get_serialized_recipe() - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) - - logger.info( - f"Saved LLM Compressor recipe with model state to {recipe_path}" - ) - self.accelerator.wait_for_everyone() if isinstance(self.model, KDModelWrapper): diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 9f4f81685..f7cb57c32 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -42,8 +42,8 @@ from llmcompressor.core import reset_session 
from llmcompressor.pytorch.model_load.helpers import ( fallback_to_cpu, - initialize_recipe, parse_dtype, + save_checkpoint, ) from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.runner import StageRunner @@ -383,9 +383,6 @@ def main( if isinstance(processor, str) or processor is None: processor = initialize_processor_from_path(model_args, model, teacher) - # initialize session manager - initialize_recipe(model, None) - # Load datasets stage_runner = StageRunner( model_args=model_args, @@ -452,16 +449,19 @@ def main( stage_runner.predict() # save if model was provided as a string or custom output_dir was set - if isinstance(model_args.model, str) or ( training_args.output_dir != TrainingArguments.__dataclass_fields__["output_dir"].default + and trainer.accelerator.is_main_process() ): - model.save_pretrained( - training_args.output_dir, save_compressed=model_args.save_compressed + save_checkpoint( + save_path=training_args.output_dir, + model=model, + processor=processor, + save_safetensors=True, + save_compressed=model_args.save_compressed, ) - if processor is not None: - processor.save_pretrained(training_args.output_dir) + trainer.accelerator.wait_for_everyone() # Clean up the CompressionSession before exit if requested if recipe_args.clear_sparse_session: diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index ec9951f6a..7994c786f 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -16,6 +16,7 @@ ) from loguru import logger from safetensors.torch import storage_ptr +from transformers import PreTrainedModel from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache @@ -26,11 +27,8 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.utils import RECIPE_FILE_NAME +from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path from llmcompressor.typing import Processor -from llmcompressor.utils.fsdp.helpers import ( - find_and_move_state_dicts_to_cpu, - unwrap_and_export_model, -) __all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"] @@ -54,6 +52,7 @@ def save_pretrained_compressed(save_pretrained_method): @wraps(original_save_pretrained) def save_pretrained_wrapper( save_directory: str, + save_compressed: bool = True, **kwargs, ): """ @@ -72,24 +71,7 @@ def save_pretrained_wrapper( saving a model in dense format :param kwargs: additional kwargs to pass on to model.save_pretrained """ - try: - trainer.save_model(output_dir=save_directory, _is_oneshot=True) - except AssertionError: - # fallback to this in the case of quantization - unwrap_and_export_model( - model=trainer.model, - accelerator=trainer.accelerator, - output_dir=save_directory, - processor=processor, - ) - # only allow the main process move the state - # dicts to cpu - if trainer.accelerator.is_main_process: - # assuming quantization is the last step - # we no longer need the original model - # and can safely delete it to save memory - del trainer.model - find_and_move_state_dicts_to_cpu(save_directory) + raise NotImplementedError("") save_pretrained_wrapper._overriden = True return save_pretrained_wrapper @@ -100,7 +82,7 @@ def save_pretrained_wrapper( ) -def modify_save_pretrained(model: torch.nn.Module): +def 
modify_save_pretrained(model: PreTrainedModel): """ Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that supports compression @@ -124,6 +106,7 @@ def save_pretrained_wrapper( sparsity_config: Optional[SparsityCompressionConfig] = None, quantization_format: Optional[str] = None, save_compressed: bool = True, + safe_serialization: bool = True, skip_compression_stats: bool = False, disable_sparse_compression: bool = False, **kwargs, @@ -189,13 +172,22 @@ def skip(*args, **kwargs): # make sure we're on the main process when saving if state_dict is not None and len(state_dict) > 0: compressed_state_dict = compressor.compress(model, state_dict) - - kwargs["safe_serialization"] = kwargs.get("safe_serialization", True) original_save_pretrained.__get__(model, model_class)( - save_directory, state_dict=compressed_state_dict, **kwargs + save_directory, + state_dict=compressed_state_dict, + safe_serialization=safe_serialization, + **kwargs, ) compressor.update_config(save_directory) + # save recipe + existing_recipe = infer_recipe_from_model_path( + model_path=model.name_or_path + ) + recipe_container = active_session().lifecycle.recipe_container + recipe_container.update(recipe=existing_recipe) + recipe_container.check_compile_recipe() + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) session = active_session() diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 3a3248fa5..51c08010b 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,34 +1,20 @@ import operator -from pathlib import Path from typing import Optional -from loguru import logger - try: - from torch.distributed.fsdp import ( - FullStateDictConfig, - FullyShardedDataParallel, - StateDictType, - ) + from torch.distributed.fsdp import FullyShardedDataParallel except ImportError: FullyShardedDataParallel = None -import torch from torch.nn import Module from llmcompressor.core.state import State -from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe -from llmcompressor.typing import Processor -from llmcompressor.utils.pytorch import set_layer __all__ = [ "is_fsdp_model", "maybe_get_wrapped", "set_wrapped_model", - "unwrap_and_export_model", - "save_pretrained_fsdp", "get_fsdp_parent", - "find_and_move_state_dicts_to_cpu", ] @@ -72,91 +58,6 @@ def set_wrapped_model(state: State, wrapped_model: Module): state.model = wrapped_model -def unwrap_and_export_model(model, accelerator, output_dir: str, processor: Processor): - """ - Recursively unwraps an FSDP model, then saves the unwrapped model and the - currently active recipe to disk - - :param model: model to unwrap - :param accelerator: Accelerator instance used to perform unwrapping - :param output_dir: where to save output model - :param processor: processor used by the model - """ - full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FullyShardedDataParallel.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - full_state_dict_config, - ): - unwrapped_model = accelerator.unwrap_model(model) - for name, module in unwrapped_model.named_modules(): - if isinstance(module, FullyShardedDataParallel): - set_layer(name, accelerator.unwrap_model(module), unwrapped_model) - - save_model_and_recipe( - model=unwrapped_model, - save_path=output_dir, - processor=processor, - ) - - -def find_and_move_state_dicts_to_cpu(output_dir: str): - """ - Looks for state dicts in the output directory and 
overwrites them - with cpu state dicts. - - this is needed for quantized models trained with FSDP as the state dict - contains device information, which can cause issues when loading the model - using transformers AutoModel.from_pretrained(...) if the device information - is not removed, assumes the state dicts are named pytorch_model*.bin - """ - - for model_file in Path(output_dir).rglob("pytorch_model*.bin"): - loaded_dict = torch.load(model_file) - for key, value in loaded_dict.items(): - if isinstance(value, torch.Tensor): - loaded_dict[key] = value.cpu() - - torch.save(loaded_dict, model_file) - logger.info(f"Moved state dict {model_file} to cpu") - - -def save_pretrained_fsdp( - model, - accelerator, - output_dir, - save_safetensors: bool = True, - save_compressed: bool = False, -): - full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - """ - Gathers the full FSDP state dict of the model onto rank0 GPU, then uses it to save - the pretrained FSDP model to disk - - :param model: model to save - :param accelerator: Accelerator instance used to perform unwrapping - :param output_dir: where to save output model - :param save_safetensors: True to safe in safetensors format, otherwise .bin - :param save_compressed: whether to compress sparse weights on disk - """ - with FullyShardedDataParallel.state_dict_type( - model, StateDictType.FULL_STATE_DICT, full_state_dict_config - ): - state_dict = accelerator.get_state_dict(model, unwrap=False) - - if accelerator.is_main_process: - accelerator.unwrap_model(model).save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=state_dict, - save_compressed=save_compressed, - safe_serialization=save_safetensors, - ) - - accelerator.wait_for_everyone() - - def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: """ Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 16c9003be..cafe00a42 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -9,7 +9,6 @@ from transformers.utils.quantization_config import CompressedTensorsConfig from llmcompressor.transformers.utils import is_model_ct_quantized_from_path -from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs" @@ -27,7 +26,6 @@ def _test_consecutive_runs( import math from llmcompressor.core import active_session - from llmcompressor.pytorch.model_load.helpers import initialize_recipe from llmcompressor.pytorch.utils.helpers import tensor_sparsity from llmcompressor.transformers import oneshot from llmcompressor.utils.pytorch import qat_active @@ -61,9 +59,8 @@ def _test_consecutive_runs( self.assertEqual(len(stages), 1) session.reset() - recipe = infer_recipe_from_model_path(model_path=self.output_first) - if recipe: - initialize_recipe(model=first_model, recipe_path=recipe) + # reuse the same session, do not construct a new one + # TODO: add test which uses new sessions? 
# reload saved model and up sparsity to 0.7 oneshot( From bb35a74feeb4a0d4aedecb53328e4ac2142770f1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 19:44:08 -0500 Subject: [PATCH 05/22] typos Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/runner.py | 2 +- src/llmcompressor/transformers/finetune/text_generation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index bf51cc4bb..aaed35d84 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -268,7 +268,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): save_safetensors=self._training_args.save_safetensors, save_compressed=self._model_args.save_compressed, ) - self.trainer.accelerator.is_main_process.wait_for_everyone() + self.trainer.accelerator.wait_for_everyone() # save stage to checkpoint dir if self.trainer.accelerator.is_main_process: diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index f7cb57c32..ab48655ce 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -452,7 +452,7 @@ def main( if isinstance(model_args.model, str) or ( training_args.output_dir != TrainingArguments.__dataclass_fields__["output_dir"].default - and trainer.accelerator.is_main_process() + and trainer.accelerator.is_main_process ): save_checkpoint( save_path=training_args.output_dir, From 71903ff8162c935c952aa16695d6f6539abe9314 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 19:49:43 -0500 Subject: [PATCH 06/22] add todos Signed-off-by: Kyle Sayers --- .../transformers/sparsification/compressed_tensors_utils.py | 2 +- tests/llmcompressor/transformers/obcq/test_consecutive_runs.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 7994c786f..6d8738d4d 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -185,7 +185,7 @@ def skip(*args, **kwargs): model_path=model.name_or_path ) recipe_container = active_session().lifecycle.recipe_container - recipe_container.update(recipe=existing_recipe) + recipe_container.update(recipe=existing_recipe) # TODO: append to beginning, not back recipe_container.check_compile_recipe() recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index cafe00a42..70340c178 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -95,6 +95,7 @@ def _test_consecutive_runs( self.assertEqual(len(stage_keys), 2) self.assertIn("test_stage_0", stage_keys) self.assertIn("test_stage_1", stage_keys) + # TODO: test order def tearDown(self): shutil.rmtree(self.output) From d39d37546ee5d2e3ed1d4e46541aa270957bbc48 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 21:08:51 -0500 Subject: [PATCH 07/22] dreggs, style Signed-off-by: Kyle Sayers --- src/llmcompressor/core/lifecycle.py | 2 +- 
src/llmcompressor/core/session_functions.py | 2 +- .../transformers/sparsification/compressed_tensors_utils.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index f41ab15df..cd415ad78 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -149,7 +149,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: logger.error("Cannot invoke event after finalizing") raise ValueError("Cannot invoke event after finalizing") - if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + if event_type in [EventType.INITIALIZE, EventType.FINALIZE]: logger.error( "Cannot invoke {} event. Use the corresponding method instead.", event_type, diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index 0dc522c7c..5f8fd1a0c 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -202,7 +202,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState: :param kwargs: additional kwargs to pass to the current session's event method :return: the modified state of the active session after invoking the event """ - if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + if event_type in [EventType.INITIALIZE, EventType.FINALIZE]: raise ValueError( f"Cannot invoke {event_type} event. " f"Use the corresponding method instead." diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 6d8738d4d..75243868e 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -185,7 +185,9 @@ def skip(*args, **kwargs): model_path=model.name_or_path ) recipe_container = active_session().lifecycle.recipe_container - recipe_container.update(recipe=existing_recipe) # TODO: append to beginning, not back + recipe_container.update( + recipe=existing_recipe + ) # TODO: append to beginning, not back recipe_container.check_compile_recipe() recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) From 7cc5a6dc0207a96de47761c2a4eb95db590493e8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 23:06:43 -0500 Subject: [PATCH 08/22] typo Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/session_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 97c77c720..807d7b8f5 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -452,7 +452,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): output_dir, model=self.model, processor=processor, - safe_serialization=self.args.save_safetensors, + save_safetensors=self.args.save_safetensors, save_compressed=self.model_args.save_compressed, ) self.accelerator.wait_for_everyone() From 9865fa3e703445962d60d41133a35cf37077b25f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 17 Feb 2025 23:23:40 -0500 Subject: [PATCH 09/22] adjust typehint Signed-off-by: Kyle Sayers --- src/llmcompressor/core/lifecycle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index cd415ad78..24cff13d1 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -38,7 +38,7 @@ class CompressionLifecycle: :type event_lifecycle: Optional[EventLifecycle] """ - state: Optional[State] = field(default_factory=State) + state: State = field(default_factory=State) recipe_container: RecipeContainer = field(default_factory=RecipeContainer) modifiers: List[StageModifiers] = field(default_factory=list) event_lifecycle: Optional[EventLifecycle] = None From 68ce62403983df6101ab16033ad0f96559d44e22 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 14:48:14 -0500 Subject: [PATCH 10/22] allow prepending Signed-off-by: Kyle Sayers --- src/llmcompressor/core/lifecycle.py | 41 +++---- src/llmcompressor/recipe/__init__.py | 5 +- src/llmcompressor/recipe/container.py | 110 +++++++++--------- src/llmcompressor/recipe/recipe.py | 72 +++++++++++- .../compressed_tensors_utils.py | 36 +++--- .../obcq/test_consecutive_runs.py | 5 - 6 files changed, 166 insertions(+), 103 deletions(-) diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index 24cff13d1..3e5542528 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -18,7 +18,12 @@ ) from llmcompressor.core.state import State from llmcompressor.modifiers import StageModifiers -from llmcompressor.recipe import RecipeContainer +from llmcompressor.recipe import ( + RecipeArgsInput, + RecipeContainer, + RecipeInput, + RecipeStageInput, +) __all__ = ["CompressionLifecycle"] @@ -66,7 +71,13 @@ def reset(self): self.__init__() logger.info("Compression lifecycle reset") - def initialize(self, **kwargs) -> List[Any]: + def initialize( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + **kwargs, + ) -> List[Any]: """ Initialize the compression lifecycle. 
@@ -75,14 +86,14 @@ def initialize(self, **kwargs) -> List[Any]: :rtype: List[Any] """ logger.debug("Initializing compression lifecycle") - extras = self.state.update(**kwargs) - extras = self.recipe_container.update(**extras) - - self._check_compile_recipe() + self.state.update(**kwargs) + self.recipe_container.append(recipe, recipe_stage, recipe_args) + self.modifiers = self.recipe_container.get_modifiers() self._set_model_layer_prefix() + mod_data = [] for mod in self.modifiers: - data = mod.initialize(state=self.state, **extras) + data = mod.initialize(state=self.state, **kwargs) logger.debug("Initialized modifier: {}", mod) if data is not None: mod_data.append(data) @@ -188,22 +199,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: return mod_data - def _check_compile_recipe(self): - if not self.recipe_container.check_compile_recipe(): - return - - logger.debug( - "Compiling recipe and creating modifiers for compression lifecycle" - ) - self.modifiers = self.recipe_container.compiled_recipe.create_modifier() - for mod in self.modifiers: - if mod.unique_id in self.recipe_container.applied_stages: - mod.applied = True - logger.info( - "Recipe compiled and {} modifiers created", - len(self.modifiers), - ) - def _check_setup_event_lifecycle(self, event_type: EventType): if self.event_lifecycle is not None: return diff --git a/src/llmcompressor/recipe/__init__.py b/src/llmcompressor/recipe/__init__.py index e02a18b39..bb4df06af 100644 --- a/src/llmcompressor/recipe/__init__.py +++ b/src/llmcompressor/recipe/__init__.py @@ -9,7 +9,7 @@ RecipeMetaData, ) from .modifier import RecipeModifier -from .recipe import Recipe, RecipeTuple +from .recipe import Recipe, RecipeArgsInput, RecipeInput, RecipeStageInput, RecipeTuple from .stage import RecipeStage, StageRunType __all__ = [ @@ -26,4 +26,7 @@ "Recipe", "RecipeTuple", "StageRunType", + "RecipeInput", + "RecipeStageInput", + "RecipeArgsInput", ] diff --git a/src/llmcompressor/recipe/container.py b/src/llmcompressor/recipe/container.py index 5cae0dd2c..90c9c1dad 100644 --- a/src/llmcompressor/recipe/container.py +++ b/src/llmcompressor/recipe/container.py @@ -1,8 +1,14 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional from llmcompressor.modifiers import Modifier -from llmcompressor.recipe.recipe import Recipe, RecipeTuple +from llmcompressor.recipe.recipe import ( + Recipe, + RecipeArgsInput, + RecipeInput, + RecipeStageInput, + RecipeTuple, +) __all__ = ["RecipeContainer"] @@ -22,52 +28,42 @@ class RecipeContainer: recipes: List[RecipeTuple] = field(default_factory=list) applied_stages: List[str] = field(default_factory=list) - def update( + def prepend( self, - recipe: Union[ - str, List[str], Recipe, List[Recipe], Modifier, List[Modifier], None - ] = None, - recipe_stage: Union[str, List[str], List[List[str]], None] = None, - recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None, - **kwargs, - ) -> Dict: - """ - Update the recipes in the container. If a recipe is provided, it will - reset any existing compiled_recipe in the container. Must call - `check_compile_recipe` to re-compile the recipes into a single compiled_recipe. - If no recipe is provided, does nothing and returns the kwargs. - - Can provide multiple recipes to update the container with: - >>> container = RecipeContainer() - >>> recipe_str_1 = ''' - ... test_stage: - ... pruning_modifiers: - ... ConstantPruningModifier: - ... start: 0.0 - ... end: 2.0 - ... 
targets: ['re:.*weight'] - ... ''' - >>> recipe_str_2 = ''' - ... test_stage: - ... pruning_modifiers: - ... ConstantPruningModifier: - ... start: 3.0 - ... end: 4.0 - ... targets: ['re:.*weight'] - ... ''' - >>> result = container.update(recipe=[recipe_str_1, recipe_str_2]) - - :param recipe: the recipe to update the container with - :param recipe_stage: the recipe stage to update the container with - :param recipe_args: the recipe args to update the recipe with - :param kwargs: additional kwargs to return - :return: the passed in kwargs - """ - if recipe is None or isinstance(recipe, list) and len(recipe) == 0: - return kwargs + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ): + recipe_tuples = self._prepare_tuples(recipe, recipe_stage, recipe_args) + self.recipes = recipe_tuples + self.recipes + self._check_compile_recipe() + + def append( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ): + recipe_tuples = self._prepare_tuples(recipe, recipe_stage, recipe_args) + self.recipes = self.recipes + recipe_tuples + self._check_compile_recipe() - self.compiled_recipe = None + def get_modifiers(self) -> List[Modifier]: + if self.compiled_recipe is None: + return [] + return self.compiled_recipe.create_modifier() + + def _prepare_tuples( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ) -> List[RecipeTuple]: + if recipe is None or (isinstance(recipe, list) and len(recipe) == 0): + return [] + + # prepare recipe if isinstance(recipe, Modifier) or ( isinstance(recipe, list) and all(isinstance(mod, Modifier) for mod in recipe) @@ -77,6 +73,12 @@ def update( if not isinstance(recipe, list): recipe = [recipe] + recipe = [ + Recipe.create_instance(rec) if isinstance(rec, str) else rec + for rec in recipe + ] + + # prepare stage if recipe_stage is None: recipe_stage = [None] * len(recipe) else: @@ -85,22 +87,23 @@ def update( if not isinstance(recipe_stage[0], list): recipe_stage = [recipe_stage] * len(recipe) + # prepare args if recipe_args is None: recipe_args = [{}] * len(recipe) elif not isinstance(recipe_args, list): recipe_args = [recipe_args] * len(recipe) + # validation if len(recipe) != len(recipe_stage) or len(recipe) != len(recipe_args): raise ValueError( "recipe, recipe_stage, and recipe_args must be the same length" ) - for rec, stage, args in zip(recipe, recipe_stage, recipe_args): - if isinstance(rec, str): - rec = Recipe.create_instance(rec) - self.recipes.append(RecipeTuple(rec, stage, args)) - - return kwargs + # create tuples + return [ + RecipeTuple(rec, stage, args) + for rec, stage, args in zip(recipe, recipe_stage, recipe_args) + ] def update_applied_stages(self, new_stages: List[str]): """ @@ -113,7 +116,7 @@ def update_applied_stages(self, new_stages: List[str]): if stage not in self.applied_stages: self.applied_stages.append(stage) - def check_compile_recipe(self) -> bool: + def _check_compile_recipe(self): """ Check if the recipes need to be compiled into a single recipe and compile them if they do. 
@@ -122,9 +125,6 @@ def check_compile_recipe(self) -> bool: """ if self.compiled_recipe is None and self.recipes: self.compiled_recipe = Recipe.simplify_combine_recipes(self.recipes) - return True - - return False def check_any_recipe_exists(self) -> bool: """ diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 1e9851ba8..9e2398944 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -14,7 +14,13 @@ from llmcompressor.recipe.metadata import RecipeMetaData from llmcompressor.recipe.stage import RecipeStage -__all__ = ["Recipe", "RecipeTuple"] +__all__ = [ + "Recipe", + "RecipeTuple", + "RecipeInput", + "RecipeStageInput", + "RecipeArgsInput", +] class Recipe(RecipeBase): @@ -150,7 +156,7 @@ def create_instance( @staticmethod def simplify_recipe( - recipe: Union["Recipe", "RecipeTuple"], shift: Optional[int] = None + recipe: Union[str, "Recipe", "RecipeTuple"], shift: Optional[int] = None ) -> "Recipe": """ Simplify a RecipeTuple by removing stages that are not in the target_stages @@ -177,6 +183,9 @@ def simplify_recipe( defaults to None (No shift) :return: The simplified Recipe instance """ + if isinstance(recipe, str): + recipe = Recipe.create_instance(recipe) + if isinstance(recipe, Recipe): recipe.evaluate(shift=shift) return recipe @@ -212,7 +221,7 @@ def simplify_recipe( @staticmethod def simplify_combine_recipes( - recipes: List[Union["Recipe", "RecipeTuple"]], + recipes: List[Union[str, "Recipe", "RecipeTuple"]], ) -> "Recipe": """ A method to combine multiple recipes into one recipe @@ -571,6 +580,11 @@ def _get_yaml_dict(self) -> Dict[str, Any]: return yaml_recipe_dict +RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]] +RecipeStageInput = Union[str, List[str], List[List[str]]] +RecipeArgsInput = Union[Dict[str, Any], List[Dict[str, Any]]] + + @dataclass class RecipeTuple: """ @@ -588,6 +602,58 @@ class RecipeTuple: target_stages: List[str] override_args: Dict[str, Any] + @staticmethod + def from_inputs( + cls, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ) -> List["RecipeTuple"]: + if recipe is None or recipe == []: + return [] + + # prepare recipe + if isinstance(recipe, Modifier) or ( + isinstance(recipe, list) + and all(isinstance(mod, Modifier) for mod in recipe) + ): + recipe = Recipe.create_instance(recipe) + + if not isinstance(recipe, list): + recipe = [recipe] + + recipe = [ + Recipe.create_instance(rec) if isinstance(rec, str) else rec + for rec in recipe + ] + + # prepare stage + if recipe_stage is None: + recipe_stage = [None] * len(recipe) + else: + if not isinstance(recipe_stage, list): + recipe_stage = [[recipe_stage]] * len(recipe) + if not isinstance(recipe_stage[0], list): + recipe_stage = [recipe_stage] * len(recipe) + + # prepare args + if recipe_args is None: + recipe_args = [{}] * len(recipe) + elif not isinstance(recipe_args, list): + recipe_args = [recipe_args] * len(recipe) + + # validation + if len(recipe) != len(recipe_stage) or len(recipe) != len(recipe_args): + raise ValueError( + "recipe, recipe_stage, and recipe_args must be the same length" + ) + + # create tuples + return [ + cls(rec, stage, args) + for rec, stage, args in zip(recipe, recipe_stage, recipe_args) + ] + def _load_json_or_yaml_string(content: str) -> Dict[str, Any]: # try loading as json first, then yaml diff --git 
a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 75243868e..c5db8554d 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -20,6 +20,7 @@ from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache +from llmcompressor.recipe.recipe import Recipe from llmcompressor.transformers.compression.quantization_format import ( infer_quantization_format, ) @@ -180,22 +181,8 @@ def skip(*args, **kwargs): ) compressor.update_config(save_directory) - # save recipe - existing_recipe = infer_recipe_from_model_path( - model_path=model.name_or_path - ) - recipe_container = active_session().lifecycle.recipe_container - recipe_container.update( - recipe=existing_recipe - ) # TODO: append to beginning, not back - recipe_container.check_compile_recipe() - - recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) - session = active_session() - - if (recipe_yaml_str := session.get_serialized_recipe()) is not None: - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) + # update existing recipe + update_and_save_recipe(model.name_or_path, save_directory) # copy python files from cache dir to save_path if any copy_python_files_from_model_cache(model, save_directory) @@ -314,3 +301,20 @@ def get_model_compressor( sparsity_config=sparsity_config, quantization_format=quantization_format, ) + + +def update_and_save_recipe(model_path: str, save_directory: str): + recipes_to_save = [] + existing_recipe = infer_recipe_from_model_path(model_path) + if existing_recipe is not None: + recipes_to_save.append(existing_recipe) + + new_recipe = active_session().lifecycle.recipe_container.compiled_recipe + if new_recipe is not None: + recipes_to_save.append(new_recipe) + + recipe = Recipe.simplify_combine_recipes(recipes_to_save) + + # save recipe + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) + recipe.yaml(recipe_path) diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 70340c178..0986406b4 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -84,11 +84,6 @@ def _test_consecutive_runs( assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance) assert qat_active(second_model) - session = active_session() - session_recipe = session.lifecycle.recipe_container.compiled_recipe - stages = [stage.group for stage in session_recipe.stages] - self.assertEqual(len(stages), 2) - recipe_path = self.output_second / "recipe.yaml" recipe_data = yaml.safe_load(recipe_path.read_text()) stage_keys = recipe_data.keys() From 5b7cc03a5bd00d5ef69892755b13557c1ad8c3d5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 15:03:44 -0500 Subject: [PATCH 11/22] check saved recipe contents Signed-off-by: Kyle Sayers --- .../obcq/test_consecutive_runs.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 0986406b4..fd8fdccdc 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ 
b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -8,6 +8,7 @@ from transformers import AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig +from llmcompressor.recipe import Recipe from llmcompressor.transformers.utils import is_model_ct_quantized_from_path from tests.testing_utils import parse_params, requires_gpu @@ -59,10 +60,7 @@ def _test_consecutive_runs( self.assertEqual(len(stages), 1) session.reset() - # reuse the same session, do not construct a new one - # TODO: add test which uses new sessions? - - # reload saved model and up sparsity to 0.7 + # reload saved model and increase sparsity to 0.7 oneshot( model=self.output_first, dataset=self.dataset, @@ -90,7 +88,24 @@ def _test_consecutive_runs( self.assertEqual(len(stage_keys), 2) self.assertIn("test_stage_0", stage_keys) self.assertIn("test_stage_1", stage_keys) - # TODO: test order + + # check saved modifier names are same + stage0_modifier_names = list( + list(recipe_data["test_stage_0"].values())[0].keys() + ) + exp_stage0_modifier_names = [ + mod.type + for mod in Recipe.create_instance(self.first_recipe).stages[0].modifiers + ] + stage1_modifier_names = list( + list(recipe_data["test_stage_1"].values())[0].keys() + ) + exp_stage1_modifier_names = [ + mod.type + for mod in Recipe.create_instance(self.second_recipe).stages[0].modifiers + ] + self.assertEqual(stage0_modifier_names, exp_stage0_modifier_names) + self.assertEqual(stage1_modifier_names, exp_stage1_modifier_names) def tearDown(self): shutil.rmtree(self.output) From bdc4fa5e8bdb768d0d503153bc9e7ad526a34604 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 15:32:16 -0500 Subject: [PATCH 12/22] consolidate saving paths Signed-off-by: Kyle Sayers --- .../pytorch/model_load/helpers.py | 44 ++------ .../transformers/finetune/runner.py | 10 +- .../transformers/finetune/session_mixin.py | 50 ++------- .../transformers/finetune/text_generation.py | 20 ++-- .../compressed_tensors_utils.py | 101 ++++-------------- 5 files changed, 61 insertions(+), 164 deletions(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 5ddc7ebd5..0f3c31df5 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -6,6 +6,7 @@ from loguru import logger from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.typing import Processor @@ -13,21 +14,19 @@ COMPLETED_STAGES_FILENAME = "completed_stages.json" __all__ = [ - "initialize_recipe", - "save_model_and_recipe", "copy_python_files_from_model_cache", "fallback_to_cpu", "parse_dtype", "get_session_model", "get_completed_stages", "save_completed_stages", + "save_checkpoint", ] def initialize_recipe(model: Module, recipe_path: str): """ Initializes a recipe that has been previously applied to the model - :param model: PyTorch model to apply structure to :param recipe_path: path to recipe to apply to the model """ @@ -49,43 +48,22 @@ def initialize_recipe(model: Module, recipe_path: str): logger.info(f"Applied {msg} to the model") -def save_model_and_recipe( - model: Module, +def save_checkpoint( save_path: str, - processor: Optional[Processor] = None, - save_safetensors: bool = False, - save_compressed: bool = False, + model: PreTrainedModel, + processor: Processor, + save_safetensors: bool = True, 
+ save_compressed: bool = True, ): - """ - Save a model, processor and the currently loaded recipe to file - - :param model: pytorch model to save - :param save_path: path to save output to - :param processor: model processor or tokenizer to save - :param save_safetensors: whether to save as safetensors or pickle (bin) - :param save_compressed: whether to compress sparse weights on disk - """ - # avoid circular import - from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME - + # saving the model also saves the recipe model.save_pretrained( - save_path, save_compressed=save_compressed, safe_serialization=save_safetensors + save_path, + save_safetensors=save_safetensors, + save_compressed=save_compressed, ) - if processor is not None: processor.save_pretrained(save_path) - logger.info("Saving output to {}".format(os.path.abspath(save_path))) - - recipe_path = os.path.join(save_path, RECIPE_FILE_NAME) - session = active_session() - recipe_yaml_str = session.get_serialized_recipe() - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) - - # copy python files from cache dir to save_path if any - copy_python_files_from_model_cache(model, save_path) - def fallback_to_cpu(device: str) -> str: """ diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 769b84248..92b3d345d 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -17,6 +17,7 @@ from llmcompressor.pytorch.model_load.helpers import ( get_completed_stages, get_session_model, + save_checkpoint, save_completed_stages, ) from llmcompressor.pytorch.utils import tensors_to_device @@ -27,7 +28,7 @@ make_dataset_splits, ) from llmcompressor.typing import Processor -from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe +from llmcompressor.utils.fsdp.helpers import is_fsdp_model class StageRunner: @@ -261,17 +262,20 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): self.train(checkpoint=checkpoint, stage=stage_name) checkpoint = None + # save model between stages if ( self._training_args.output_dir != TrainingArguments.__dataclass_fields__["output_dir"].default + and self.trainer.accelerator.is_main_process ): - save_model_and_recipe( - model=self.trainer.model, + save_checkpoint( save_path=self._output_dir, + model=self.trainer.model, processor=self.processor, save_safetensors=self._training_args.save_safetensors, save_compressed=self._model_args.save_compressed, ) + self.trainer.accelerator.wait_for_everyone() # save stage to checkpoint dir if self.trainer.accelerator.is_main_process: diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index e32c64f62..d9eab5d26 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -24,15 +24,13 @@ from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( KDModelWrapper, ) -from llmcompressor.pytorch.model_load.helpers import get_session_model +from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint from llmcompressor.pytorch.utils import ModuleSparsificationInfo -from llmcompressor.transformers import RECIPE_FILE_NAME from llmcompressor.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, TrainingLoopCallbacks, ) from llmcompressor.utils.fsdp.context import 
summon_full_params_context -from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp from llmcompressor.utils.pytorch import qat_active if TYPE_CHECKING: @@ -66,8 +64,8 @@ class SessionManagerMixIn: def __init__( self, recipe: str, + data_args: "DatasetArguments", model_args: "ModelArguments", - data_args: Optional["DatasetArguments"] = None, teacher: Optional[Union[Module, str]] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, **kwargs, @@ -171,7 +169,6 @@ def initialize_structure(self, stage: Optional[str] = None): """ Initialize any recipe structural changes such as quantization on the model, return immediately if session has already been initialized - :param stage: Optional stage of recipe to run, or None to run all stages """ session = active_session() @@ -401,7 +398,6 @@ def evaluate(self, *args, **kwargs): Run a sparsification evaluation cycle. Runs initialize_structure for the sparse session before calling super().evaluate() and finalization of the session after. - :param args: positional args to pass to super().evaluate() :param kwargs: keyword args to pass to super().evaluate() :return: the output from super.evaluate() @@ -418,12 +414,12 @@ def predict(self, *args, **kwargs): Run a sparsification prediction cycle. Runs initialize_structure for the sparse session before calling super().predict() and finalization of the session after. - :param args: positional args to pass to super().predict() :param kwargs: keyword args to pass to super().predict() :return: the output from super.predict() """ self.initialize_structure() + output = super().predict(*args, **kwargs) self.finalize_session() @@ -469,44 +465,18 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): # knowledge distillation requires making wrappers transparent during if isinstance(self.model, KDModelWrapper): - self.model.prepare_for_save() + self.model.prepare_for_save() # TODO: move to finalize - if not is_fsdp_model(self.model): - self.model.save_pretrained( + # save checkpoint + if self.accelerator.is_main_process: + processor = getattr(self, "processing_class", self.tokenizer) + save_checkpoint( output_dir, - save_compressed=self.model_args.save_compressed, - safe_serialization=self.args.save_safetensors, - ) - else: # FSDP model - save_pretrained_fsdp( model=self.model, - accelerator=self.accelerator, - output_dir=output_dir, + processor=processor, + save_safetensors=self.args.save_safetensors, save_compressed=self.model_args.save_compressed, - save_safetensors=self.metadata.get("save_safetensors", False), ) - - self.save_state() - processor = getattr(self, "processing_class", self.tokenizer) - if processor is not None: - processor.save_pretrained(output_dir) - - if not self.recipe: - return - - if self.accelerator.is_main_process: - # save recipe, will contain modifiers from the model's original recipe as - # well as those added from self.recipe - recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME) - session = active_session() - recipe_yaml_str = session.get_serialized_recipe() - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) - - logger.info( - f"Saved LLM Compressor recipe with model state to {recipe_path}" - ) - self.accelerator.wait_for_everyone() if isinstance(self.model, KDModelWrapper): diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index d79d8cbbe..be5b977f6 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ 
b/src/llmcompressor/transformers/finetune/text_generation.py @@ -45,12 +45,12 @@ get_session_model, initialize_recipe, parse_dtype, + save_checkpoint, ) from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.runner import StageRunner from llmcompressor.transformers.finetune.trainer import Trainer from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( - modify_fsdp_model_save_pretrained, modify_save_pretrained, patch_tied_tensors_bug, ) @@ -418,7 +418,10 @@ def main( # wrap model.save_pretrained if is_fsdp_model(model): - modify_fsdp_model_save_pretrained(trainer, processor) + raise NotImplementedError( + "FSDP models are not supported in the current release but will be " + "suported in future releases of LLM Compressor" + ) else: modify_save_pretrained(model) @@ -455,16 +458,19 @@ def main( stage_runner.predict() # save if model was provided as a string or custom output_dir was set - if isinstance(model_args.model, str) or ( training_args.output_dir != TrainingArguments.__dataclass_fields__["output_dir"].default + and trainer.accelerator.is_main_process ): - model.save_pretrained( - training_args.output_dir, save_compressed=model_args.save_compressed + save_checkpoint( + save_path=training_args.output_dir, + model=model, + processor=processor, + save_safetensors=True, + save_compressed=model_args.save_compressed, ) - if processor is not None: - processor.save_pretrained(training_args.output_dir) + trainer.accelerator.wait_for_everyone() # Clean up the CompressionSession before exit if requested if recipe_args.clear_sparse_session: diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index ec9951f6a..6ba938ffc 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -16,6 +16,7 @@ ) from loguru import logger from safetensors.torch import storage_ptr +from transformers import PreTrainedModel from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache @@ -26,81 +27,11 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.utils import RECIPE_FILE_NAME -from llmcompressor.typing import Processor -from llmcompressor.utils.fsdp.helpers import ( - find_and_move_state_dicts_to_cpu, - unwrap_and_export_model, -) - -__all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"] - - -def modify_fsdp_model_save_pretrained(trainer, processor: Processor): - """ - Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that - supports compression for fsdp model - """ - - def save_pretrained_compressed(save_pretrained_method): - if getattr(save_pretrained_method, "_overridden", False): - # `model.save_pretrained` has already been replaced, return. - return save_pretrained_method - - # Keep a weak reference to the model class and unbound save_pretrained - # method so we can call the original - original_save_pretrained = save_pretrained_method.__func__ - del save_pretrained_method - - @wraps(original_save_pretrained) - def save_pretrained_wrapper( - save_directory: str, - **kwargs, - ): - """ - Wrapper around PreTrainedModel.save_pretrained(), adds functionality for - saving models in a compressed format on disk. 
The compression format is - saved to the model's config file - :param save_directory: output directory to save model to - :param sparsity_config: optional sparsity config to compress model with, - if no config is provided it will be inferred from the model - :param quantization_format: optional compression format for quantized - models. If none is provided it will be inferred from the model - :param save_compressed: whether or not to compress the model on disk - :param skip_compression_stats: whether to skip the calculation of - compression statistics (such as global sparsity and sparsity structure) when - saving a model in dense format - :param kwargs: additional kwargs to pass on to model.save_pretrained - """ - try: - trainer.save_model(output_dir=save_directory, _is_oneshot=True) - except AssertionError: - # fallback to this in the case of quantization - unwrap_and_export_model( - model=trainer.model, - accelerator=trainer.accelerator, - output_dir=save_directory, - processor=processor, - ) - # only allow the main process move the state - # dicts to cpu - if trainer.accelerator.is_main_process: - # assuming quantization is the last step - # we no longer need the original model - # and can safely delete it to save memory - del trainer.model - find_and_move_state_dicts_to_cpu(save_directory) - - save_pretrained_wrapper._overriden = True - return save_pretrained_wrapper - - # wrap save_pretrained - trainer.model.save_pretrained = save_pretrained_compressed( - trainer.model.save_pretrained - ) +__all__ = ["modify_save_pretrained"] -def modify_save_pretrained(model: torch.nn.Module): +def modify_save_pretrained(model: PreTrainedModel): """ Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that supports compression @@ -124,6 +55,7 @@ def save_pretrained_wrapper( sparsity_config: Optional[SparsityCompressionConfig] = None, quantization_format: Optional[str] = None, save_compressed: bool = True, + safe_serialization: bool = True, skip_compression_stats: bool = False, disable_sparse_compression: bool = False, **kwargs, @@ -189,19 +121,16 @@ def skip(*args, **kwargs): # make sure we're on the main process when saving if state_dict is not None and len(state_dict) > 0: compressed_state_dict = compressor.compress(model, state_dict) - - kwargs["safe_serialization"] = kwargs.get("safe_serialization", True) original_save_pretrained.__get__(model, model_class)( - save_directory, state_dict=compressed_state_dict, **kwargs + save_directory, + state_dict=compressed_state_dict, + safe_serialization=safe_serialization, + **kwargs, ) compressor.update_config(save_directory) - recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) - session = active_session() - - if (recipe_yaml_str := session.get_serialized_recipe()) is not None: - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) + # TODO: update existing recipe + update_and_save_recipe(model.name_or_path, save_directory) # copy python files from cache dir to save_path if any copy_python_files_from_model_cache(model, save_directory) @@ -320,3 +249,13 @@ def get_model_compressor( sparsity_config=sparsity_config, quantization_format=quantization_format, ) + + +def update_and_save_recipe(model_path: str, save_directory: str): + # TODO: update existing recipe + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) + session = active_session() + + if (recipe_yaml_str := session.get_serialized_recipe()) is not None: + with open(recipe_path, "w") as fp: + fp.write(recipe_yaml_str) From 
a83b0aa1cde223567fc6c308fe1db582fcc63410 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 15:47:59 -0500 Subject: [PATCH 13/22] remove broken import Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/fsdp/helpers.py | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 3a3248fa5..53fc04ca8 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -17,15 +17,11 @@ from torch.nn import Module from llmcompressor.core.state import State -from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe -from llmcompressor.typing import Processor -from llmcompressor.utils.pytorch import set_layer __all__ = [ "is_fsdp_model", "maybe_get_wrapped", "set_wrapped_model", - "unwrap_and_export_model", "save_pretrained_fsdp", "get_fsdp_parent", "find_and_move_state_dicts_to_cpu", @@ -72,34 +68,6 @@ def set_wrapped_model(state: State, wrapped_model: Module): state.model = wrapped_model -def unwrap_and_export_model(model, accelerator, output_dir: str, processor: Processor): - """ - Recursively unwraps an FSDP model, then saves the unwrapped model and the - currently active recipe to disk - - :param model: model to unwrap - :param accelerator: Accelerator instance used to perform unwrapping - :param output_dir: where to save output model - :param processor: processor used by the model - """ - full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FullyShardedDataParallel.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - full_state_dict_config, - ): - unwrapped_model = accelerator.unwrap_model(model) - for name, module in unwrapped_model.named_modules(): - if isinstance(module, FullyShardedDataParallel): - set_layer(name, accelerator.unwrap_model(module), unwrapped_model) - - save_model_and_recipe( - model=unwrapped_model, - save_path=output_dir, - processor=processor, - ) - - def find_and_move_state_dicts_to_cpu(output_dir: str): """ Looks for state dicts in the output directory and overwrites them From b9f0bd14910b7c3188ee1e1847ec0f4e51f74046 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 15:50:59 -0500 Subject: [PATCH 14/22] add back def Signed-off-by: Kyle Sayers --- src/llmcompressor/pytorch/model_load/helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 0f3c31df5..e2e1a91b7 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -14,6 +14,7 @@ COMPLETED_STAGES_FILENAME = "completed_stages.json" __all__ = [ + "initialize_recipe", "copy_python_files_from_model_cache", "fallback_to_cpu", "parse_dtype", From 0a2642bbd9df81076e4f59b568debdfcde32571f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 15:59:46 -0500 Subject: [PATCH 15/22] save state Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/session_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index bd8382b64..1f60da25c 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -482,6 +482,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): 
self.model.prepare_for_save() # TODO: move to finalize # save checkpoint + self.save_state() if self.accelerator.is_main_process: processor = getattr(self, "processing_class", self.tokenizer) save_checkpoint( From 60371ef35538d7583f494b6dc8af7d72ca6d7a58 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 16:29:38 -0500 Subject: [PATCH 16/22] remove verbose messages Signed-off-by: Kyle Sayers --- src/llmcompressor/recipe/recipe.py | 52 ------------------- .../transformers/utils/helpers.py | 13 +---- 2 files changed, 1 insertion(+), 64 deletions(-) diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 9e2398944..f48c4a568 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -602,58 +602,6 @@ class RecipeTuple: target_stages: List[str] override_args: Dict[str, Any] - @staticmethod - def from_inputs( - cls, - recipe: Optional[RecipeInput] = None, - recipe_stage: Optional[RecipeStageInput] = None, - recipe_args: Optional[RecipeArgsInput] = None, - ) -> List["RecipeTuple"]: - if recipe is None or recipe == []: - return [] - - # prepare recipe - if isinstance(recipe, Modifier) or ( - isinstance(recipe, list) - and all(isinstance(mod, Modifier) for mod in recipe) - ): - recipe = Recipe.create_instance(recipe) - - if not isinstance(recipe, list): - recipe = [recipe] - - recipe = [ - Recipe.create_instance(rec) if isinstance(rec, str) else rec - for rec in recipe - ] - - # prepare stage - if recipe_stage is None: - recipe_stage = [None] * len(recipe) - else: - if not isinstance(recipe_stage, list): - recipe_stage = [[recipe_stage]] * len(recipe) - if not isinstance(recipe_stage[0], list): - recipe_stage = [recipe_stage] * len(recipe) - - # prepare args - if recipe_args is None: - recipe_args = [{}] * len(recipe) - elif not isinstance(recipe_args, list): - recipe_args = [recipe_args] * len(recipe) - - # validation - if len(recipe) != len(recipe_stage) or len(recipe) != len(recipe_args): - raise ValueError( - "recipe, recipe_stage, and recipe_args must be the same length" - ) - - # create tuples - return [ - cls(rec, stage, args) - for rec, stage, args in zip(recipe, recipe_stage, recipe_args) - ] - def _load_json_or_yaml_string(content: str) -> Dict[str, Any]: # try loading as json first, then yaml diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 349ff8f09..d77203b0b 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -140,20 +140,9 @@ def recipe_from_huggingface_model_id( return None try: - logger.info( - "Attempting to download a recipe ", - f"{hf_stub} " f"from {HUGGINGFACE_CO_URL_HOME}", - ) recipe = hf_hub_download(repo_id=hf_stub, filename=recipe_file_name) logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.") - except Exception as e: - logger.error( - ( - f"Unable to find recipe {recipe_file_name} " - f"for model ID: {hf_stub}: {e}." - "Skipping recipe resolution." 
- ) - ) + except Exception: # TODO: narrow acceptable exceptions recipe = None return recipe From bf9a8cdf698b5fa44ca3b533be3e2be505c0d063 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 18 Feb 2025 18:02:44 -0500 Subject: [PATCH 17/22] fix double initialization Signed-off-by: Kyle Sayers --- src/llmcompressor/core/lifecycle.py | 5 ++++- src/llmcompressor/transformers/utils/helpers.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index 3e5542528..e7274b21a 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -85,8 +85,11 @@ def initialize( :return: List of data returned from initialization of modifiers :rtype: List[Any] """ - logger.debug("Initializing compression lifecycle") self.state.update(**kwargs) + if self.initialized_: # TODO: do not initialize twice + return + + logger.debug("Initializing compression lifecycle") self.recipe_container.append(recipe, recipe_stage, recipe_args) self.modifiers = self.recipe_container.get_modifiers() self._set_model_layer_prefix() diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index d77203b0b..66cc11189 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -109,7 +109,7 @@ def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]: recipe = recipe_from_huggingface_model_id(hf_stub=model_path) if recipe is None: - logger.info("Failed to infer the recipe from the model_path") + logger.debug("Failed to infer the recipe from the model_path") return recipe From 3d64d578f817e1f9b74c4e422c9f702800d809e8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 25 Feb 2025 11:42:59 -0500 Subject: [PATCH 18/22] rename function Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/base.py | 4 +- .../qwen2_moe/configuration_qwen2_moe.py | 246 +++ .../tracing/qwen2_moe/modeling_qwen2_moe.py | 1659 +++++++++++++++++ .../pruning/sparsegpt/test_pytorch.py | 4 +- 4 files changed, 1909 insertions(+), 4 deletions(-) create mode 100644 src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py create mode 100644 src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b68be2c59..525ba1301 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -140,7 +140,7 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def _maybe_build_quant_modifier(self, model: torch.nn.Module): + def _check_build_quant_modifier(self, model: torch.nn.Module): """ Check the model's quantization state matches that expected by this modifier, adding a default quantization scheme if needed @@ -197,7 +197,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param state: session state storing input model and calibration data """ # build quantization modifier - self._maybe_build_quant_modifier(state.model) + self._check_build_quant_modifier(state.model) if self._quantization_modifier: self._quantization_modifier.initialize(state, **kwargs) diff --git a/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py b/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py new file mode 100644 index 000000000..b46fac4c8 
--- /dev/null +++ b/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a + Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. 
+ moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 5632): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 60): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `False`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + + ```python + >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig + + >>> # Initializing a Qwen2MoE style configuration + >>> configuration = Qwen2MoeConfig() + + >>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration + >>> model = Qwen2MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2Moe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + shared_expert_intermediate_size=5632, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + # CALIBRATION: add option to use inference-time activations + moe_calibrate_experts=False, + moe_eval_mode=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + 
self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + self.moe_calibrate_experts = moe_calibrate_experts + self.moe_eval_mode = moe_eval_mode + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen2MoeConfig"] \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py b/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py new file mode 100644 index 000000000..1f8937743 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py @@ -0,0 +1,1659 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2MoE model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from .configuration_qwen2_moe import Qwen2MoeConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B" +_CONFIG_FOR_DOC = "Qwen2MoeConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe +class Qwen2MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2Moe +class Qwen2MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = 
self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe +class Qwen2MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# no longer copied after attention refactors +class Qwen2MoeAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +# TODO cyril: modular +class Qwen2MoeFlashAttention2(Qwen2MoeAttention): + """ + Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` + as the weights of the module stays untouched. 
The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +# TODO cyril: modular +class Qwen2MoeSdpaAttention(Qwen2MoeAttention): + """ + Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2MoeAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2MOE_ATTENTION_CLASSES = { + "eager": Qwen2MoeAttention, + "flash_attention_2": Qwen2MoeFlashAttention2, + "sdpa": Qwen2MoeSdpaAttention, +} + + +class Qwen2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class Qwen2MoeDecoderLayer(nn.Module): + def __init__(self, config: Qwen2MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen2MoeSparseMoeBlock(config) + else: + self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +QWEN2MOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2MoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", + QWEN2MOE_START_DOCSTRING, +) +class Qwen2MoePreTrainedModel(PreTrainedModel): + config_class = Qwen2MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2MoeDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2MOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", + QWEN2MOE_START_DOCSTRING, +) +class Qwen2MoeModel(Qwen2MoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] + + Args: + config: Qwen2MoeConfig + """ + + def __init__(self, config: Qwen2MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2MoeRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2Moe + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2Moe. 
Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2Moe + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2MoeConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+ target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2MoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM + + >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +@add_start_docstrings( + """ + The Qwen2MoE Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2MoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2MoeModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2MoE Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE +class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = Qwen2MoeModel(config) # diff with Llama: transformer->model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen2MoeForCausalLM", + "Qwen2MoeForQuestionAnswering", + "Qwen2MoeModel", + "Qwen2MoePreTrainedModel", + "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", +] \ No newline at end of file diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 6d50bfff6..0c5ad534d 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -71,7 +71,7 @@ def test_create_default_quant_modifier(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier._maybe_build_quant_modifier(testing_harness.get_state().model) + modifier._check_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize assert isinstance(modifier._quantization_modifier, QuantizationModifier) modifier._quantization_modifier.create_init_config() @@ -138,7 +138,7 @@ def test_set_quant_in_gptq(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier._maybe_build_quant_modifier(testing_harness.get_state().model) + modifier._check_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) From 55984c4adf0110158ef606638981f319e29c77e3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 25 Feb 2025 11:44:55 -0500 Subject: [PATCH 19/22] remove accidentally added files Signed-off-by: Kyle Sayers --- .../qwen2_moe/configuration_qwen2_moe.py | 246 --- .../tracing/qwen2_moe/modeling_qwen2_moe.py | 1659 ----------------- 2 files changed, 1905 deletions(-) delete mode 100644 src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py delete mode 100644 src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py diff --git a/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py b/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py deleted file mode 100644 index b46fac4c8..000000000 --- a/src/llmcompressor/transformers/tracing/qwen2_moe/configuration_qwen2_moe.py +++ /dev/null @@ -1,246 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. 
All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Qwen2MoE model configuration""" - -from ...configuration_utils import PretrainedConfig -from ...modeling_rope_utils import rope_config_validation -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -class Qwen2MoeConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a - Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2MoeModel`] - hidden_size (`int`, *optional*, defaults to 2048): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 5632): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 24): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 16): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. 
- tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - decoder_sparse_step (`int`, *optional*, defaults to 1): - The frequency of the MoE layer. - moe_intermediate_size (`int`, *optional*, defaults to 1408): - Intermediate size of the routed expert. - shared_expert_intermediate_size (`int`, *optional*, defaults to 5632): - Intermediate size of the shared expert. 
- num_experts_per_tok (`int`, *optional*, defaults to 4): - Number of selected experts. - num_experts (`int`, *optional*, defaults to 60): - Number of routed experts. - norm_topk_prob (`bool`, *optional*, defaults to `False`): - Whether to normalize the topk probabilities. - output_router_logits (`bool`, *optional*, defaults to `False`): - Whether or not the router logits should be returned by the model. Enabeling this will also - allow the model to output the auxiliary loss, including load balancing loss and router z-loss. - router_aux_loss_coef (`float`, *optional*, defaults to 0.001): - The aux loss factor for the total loss. - mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): - Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock - The list contains layer index, from 0 to num_layers-1 if we have num_layers layers - If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. - - ```python - >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig - - >>> # Initializing a Qwen2MoE style configuration - >>> configuration = Qwen2MoeConfig() - - >>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration - >>> model = Qwen2MoeModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2_moe" - keys_to_ignore_at_inference = ["past_key_values"] - - # Default tensor parallel plan for base model `Qwen2Moe` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=151936, - hidden_size=2048, - intermediate_size=5632, - num_hidden_layers=24, - num_attention_heads=16, - num_key_value_heads=16, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - decoder_sparse_step=1, - moe_intermediate_size=1408, - shared_expert_intermediate_size=5632, - num_experts_per_tok=4, - num_experts=60, - norm_topk_prob=False, - output_router_logits=False, - router_aux_loss_coef=0.001, - mlp_only_layers=None, - # CALIBRATION: add option to use inference-time activations - moe_calibrate_experts=False, - moe_eval_mode=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window if use_sliding_window else None - self.max_window_layers = max_window_layers - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_dropout = attention_dropout - # 
Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) - - # MoE arguments - self.decoder_sparse_step = decoder_sparse_step - self.moe_intermediate_size = moe_intermediate_size - self.shared_expert_intermediate_size = shared_expert_intermediate_size - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.norm_topk_prob = norm_topk_prob - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers - self.moe_calibrate_experts = moe_calibrate_experts - self.moe_eval_mode = moe_eval_mode - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -__all__ = ["Qwen2MoeConfig"] \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py b/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py deleted file mode 100644 index 1f8937743..000000000 --- a/src/llmcompressor/transformers/tracing/qwen2_moe/modeling_qwen2_moe.py +++ /dev/null @@ -1,1659 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
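The configuration file deleted above closely follows the upstream `transformers` Qwen2MoeConfig, adding only the two calibration fields (`moe_calibrate_experts`, `moe_eval_mode`). A minimal sketch, assuming only the upstream class, of instantiating the config with the MoE arguments its docstring documents (values are the documented defaults; the two extra flags exist only in the removed copy):

```python
# Sketch only: uses the upstream transformers Qwen2MoeConfig, not the removed copy.
from transformers import Qwen2MoeConfig

config = Qwen2MoeConfig(
    num_experts=60,                        # routed experts (documented default)
    num_experts_per_tok=4,                 # top-k experts selected per token
    moe_intermediate_size=1408,            # intermediate size of each routed expert
    shared_expert_intermediate_size=5632,  # intermediate size of the shared expert
    norm_topk_prob=False,                  # keep the top-k routing weights unnormalized
)
print(config.num_experts, config.num_experts_per_tok)
```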
-"""PyTorch Qwen2MoE model.""" - -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from ...utils.deprecation import deprecate_kwarg -from .configuration_qwen2_moe import Qwen2MoeConfig - - -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B" -_CONFIG_FOR_DOC = "Qwen2MoeConfig" - - -# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func -def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], - num_experts: Optional[int] = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, int]: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. - - Args: - gate_logits: - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - num_experts: - Number of experts - top_k: - The number of experts to route per-token, can be also interpreted as the `top-k` routing - parameter. - attention_mask (`torch.Tensor`, *optional*): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. - - Returns: - The auxiliary loss. 
- """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) - - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) - - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) - - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) - - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) - .reshape(-1, top_k, num_experts) - .to(compute_device) - ) - - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) - - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) - - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe -class Qwen2MoeRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2MoeRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2Moe -class Qwen2MoeRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen2MoeConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = 
self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe -class Qwen2MoeMLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe -# no longer copied after attention refactors -class Qwen2MoeAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) - - # Ignore copy - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe -# TODO cyril: modular -class Qwen2MoeFlashAttention2(Qwen2MoeAttention): - """ - Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` - as the weights of the module stays untouched. 
The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe -# TODO cyril: modular -class Qwen2MoeSdpaAttention(Qwen2MoeAttention): - """ - Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2MoeAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2MOE_ATTENTION_CLASSES = { - "eager": Qwen2MoeAttention, - "flash_attention_2": Qwen2MoeFlashAttention2, - "sdpa": Qwen2MoeSdpaAttention, -} - - -class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__(self, config): - super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - # gating - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = nn.ModuleList( - [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] - ) - - self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - - final_hidden_states = final_hidden_states + shared_expert_output - - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits - - -class Qwen2MoeDecoderLayer(nn.Module): - def __init__(self, config: Qwen2MoeConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - if (layer_idx not in config.mlp_only_layers) and ( - config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 - ): - self.mlp = Qwen2MoeSparseMoeBlock(config) - else: - self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size) - - self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, - and should not be returned during inference. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - hidden_states = self.mlp(hidden_states) - if isinstance(hidden_states, tuple): - hidden_states, router_logits = hidden_states - else: - router_logits = None - - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs - - -QWEN2MOE_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2MoeConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", - QWEN2MOE_START_DOCSTRING, -) -class Qwen2MoePreTrainedModel(PreTrainedModel): - config_class = Qwen2MoeConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2MoeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2MOE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, and - should not be returned during inference. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", - QWEN2MOE_START_DOCSTRING, -) -class Qwen2MoeModel(Qwen2MoePreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] - - Args: - config: Qwen2MoeConfig - """ - - def __init__(self, config: Qwen2MoeConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2MoeRotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits and layer_outputs[-1] is not None: - all_router_logits += (layer_outputs[-1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) - - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2Moe - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2Moe. 
Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2Moe - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2MoeConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2MoeConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - - -class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2MoeModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.router_aux_loss_coef = config.router_aux_loss_coef - self.num_experts = config.num_experts - self.num_experts_per_tok = config.num_experts_per_tok - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) - 
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM - - >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -@add_start_docstrings( - """ - The Qwen2MoE Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2MOE_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE -class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2MoeModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - QWEN2MOE_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE -class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2MoeModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Qwen2MoE Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - QWEN2MOE_START_DOCSTRING, -) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE -class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): - base_model_prefix = "model" - - def __init__(self, config): - super().__init__(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.model = Qwen2MoeModel(config) # diff with Llama: transformer->model - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "Qwen2MoeForCausalLM", - "Qwen2MoeForQuestionAnswering", - "Qwen2MoeModel", - "Qwen2MoePreTrainedModel", - "Qwen2MoeForSequenceClassification", - "Qwen2MoeForTokenClassification", -] \ No newline at end of file From e02665826e6b0f70e52f0c52ba470bd4440656b2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 25 Feb 2025 12:28:58 -0500 Subject: [PATCH 20/22] add debug statement Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/utils/helpers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index c315ae1fb..cddd45d4f 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -142,7 +142,14 @@ def recipe_from_huggingface_model_id( try: recipe = hf_hub_download(repo_id=hf_stub, filename=recipe_file_name) logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.") - except Exception: # TODO: narrow acceptable exceptions + except Exception as e: # TODO: narrow acceptable exceptions + logger.debug( + ( + f"Unable to find recipe {recipe_file_name} " + f"for model ID: {hf_stub}: {e}." + "Skipping recipe resolution." 
+ ) + ) recipe = None return recipe From 00961a07e2815f0bf56df373d91a5c2e08607d7b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 25 Feb 2025 13:48:36 -0500 Subject: [PATCH 21/22] pass model to stage runner Signed-off-by: Kyle Sayers --- src/llmcompressor/entrypoints/oneshot.py | 1 + src/llmcompressor/transformers/finetune/runner.py | 6 ++++-- src/llmcompressor/transformers/finetune/text_generation.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 9c57b423a..73e743470 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -254,6 +254,7 @@ def _pre_process(self): if isinstance(self.model_args.model, (str, PosixPath)): self.model_args.model, _ = initialize_model_from_path(self.model_args) + breakpoint() patch_tied_tensors_bug(self.model_args.model) modify_save_pretrained(self.model_args.model) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index bee682bda..1735a99b8 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -6,6 +6,7 @@ import torch from loguru import logger from torch.utils.data import Dataset +from transformers import PreTrainedModel from llmcompressor.args import ( DatasetArguments, @@ -154,7 +155,9 @@ def train(self, checkpoint: str, stage: Optional[str] = None): # this includes saving the state, optimizer and scheduler self.trainer.save_model(output_dir=self._output_dir) - def run_sequential_stages(self, checkpoint: Optional[str] = None): + def run_sequential_stages( + self, model: PreTrainedModel, checkpoint: Optional[str] = None + ): """ Run the recipe stage by stage, allowing for alternating between one-shot and finetuning flows. Optionally save the model output at the end of each stage @@ -195,7 +198,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): if run_type is StageRunType.ONESHOT: from llmcompressor import Oneshot - model = get_session_model() self._model_args.model = model oneshot = Oneshot.from_args( diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 20616e7c1..9a3623f60 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -422,7 +422,7 @@ def main( checkpoint = None if last_checkpoint is not None: checkpoint = last_checkpoint - stage_runner.run_sequential_stages(checkpoint) + stage_runner.run_sequential_stages(model, checkpoint) # exit immediately return From c7678b03156fba69e5e0da75b4b03dfe19f72607 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 26 Feb 2025 10:34:00 -0500 Subject: [PATCH 22/22] remove breakpoint Signed-off-by: Kyle Sayers --- src/llmcompressor/entrypoints/oneshot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 73e743470..9c57b423a 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -254,7 +254,6 @@ def _pre_process(self): if isinstance(self.model_args.model, (str, PosixPath)): self.model_args.model, _ = initialize_model_from_path(self.model_args) - breakpoint() patch_tied_tensors_bug(self.model_args.model) modify_save_pretrained(self.model_args.model)
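
Note: the following is a rough, hypothetical sketch of the pattern PATCH 21 introduces (simplified, invented names, not llm-compressor's actual classes). It only illustrates the design choice: the stage runner now receives the loaded model explicitly instead of re-resolving it from the global session inside the oneshot branch.

```python
# Illustrative sketch only: hypothetical names, not the library's API.
# The point is that run_sequential_stages() takes the model as an argument,
# so oneshot and finetune stages all operate on the same model object the
# caller loaded, rather than fetching it from global session state.
from typing import Callable, List, Optional


class SequentialStageRunner:
    def __init__(self, stages: List[Callable]):
        # each stage is assumed to be a callable taking (model, checkpoint)
        self.stages = stages

    def run_sequential_stages(self, model, checkpoint: Optional[str] = None) -> None:
        for stage in self.stages:
            # the same model instance is threaded through every stage
            stage(model, checkpoint)


# Usage mirrors the updated call in text_generation.py's main():
#   stage_runner.run_sequential_stages(model, checkpoint)
```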