Skip to content

Commit

Permalink
Merge branch 'main' into kylesayrs/remove-leave_enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs authored Feb 26, 2025
2 parents 2bd0f6e + 29ddedb commit 54eb85c
Show file tree
Hide file tree
Showing 24 changed files with 192 additions and 471 deletions.
34 changes: 33 additions & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,41 @@ env:
CLEARML_API_SECRET_KEY: ${{ secrets.CLEARML_API_SECRET_KEY }}

jobs:
detect-changes:
runs-on: ubuntu-latest

outputs:
changes-present: ${{ steps.changed-files.outputs.any_modified }}

steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
**
!examples/**
!tests/e2e/**
!tests/lmeval/**
!tests/examples/**
!**/*.md
!.github/**
.github/workflows/test-check-transformers.yaml
- name: Log relevant output
run: |
echo "changes-present: ${{ steps.changed-files.outputs.any_modified }}"
echo "all modified files: ${{ steps.changed-files.outputs.all_modified_files }}"
shell: bash

transformers-tests:
needs: [detect-changes]
runs-on: gcp-k8s-vllm-l4-solo
if: contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push'
if: (contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push') && needs.detect-changes.outputs.changes-present == 'true'
steps:
- uses: actions/setup-python@v5
with:
Expand Down
1 change: 0 additions & 1 deletion src/llmcompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.entrypoints import Oneshot, oneshot
2 changes: 0 additions & 2 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
Expand All @@ -36,7 +35,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class EventType(Enum):
The purpose of each EventType is to trigger the corresponding
modifier callback during training or post training pipelines.
:param PRE_INIT: Event type for pre-initialization.
:param INITIALIZE: Event type for initialization.
:param FINALIZE: Event type for finalization.
:param BATCH_START: Event type for the start of a batch.
Expand All @@ -38,7 +37,6 @@ class EventType(Enum):
"""

# training lifecycle
PRE_INIT = "pre_init"
INITIALIZE = "initialize"
FINALIZE = "finalize"

Expand All @@ -51,35 +49,6 @@ class EventType(Enum):
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"

def order(self) -> int:
"""
Returns the priority order of the current EventType.
Lower values have higher priority.
:raises ValueError: if the event type is invalid.
:return: The order of the event type, lower has higher priority.
:rtype: int
"""
if self == EventType.PRE_INIT:
return 0
elif self == EventType.INITIALIZE:
return 10
elif self == EventType.FINALIZE:
return 20
elif self == EventType.BATCH_START:
return 100
elif self == EventType.LOSS_CALCULATED:
return 110
elif self == EventType.OPTIM_PRE_STEP:
return 120
elif self == EventType.OPTIM_POST_STEP:
return 130
elif self == EventType.BATCH_END:
return 140
else:
logger.error("Invalid event type: {}", self)
raise ValueError(f"Invalid event type {self}")


@dataclass
class Event:
Expand Down
95 changes: 24 additions & 71 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import RecipeContainer
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)

__all__ = ["CompressionLifecycle"]

Expand All @@ -38,7 +43,7 @@ class CompressionLifecycle:
:type event_lifecycle: Optional[EventLifecycle]
"""

state: Optional[State] = None
state: State = field(default_factory=State)
recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
modifiers: List[StageModifiers] = field(default_factory=list)
event_lifecycle: Optional[EventLifecycle] = None
Expand All @@ -62,63 +67,35 @@ def reset(self):
except Exception as e:
logger.warning(f"Exception during finalizing modifier: {e}")

self.state = None
self.recipe_container = RecipeContainer()
self.modifiers = []
self.event_lifecycle = None

self.initialized_ = False
self.finalized = False
self.__init__()
logger.info("Compression lifecycle reset")

def pre_initialize_structure(self, **kwargs) -> List[Any]:
"""
Pre-initialize the structure of the compression lifecycle.
:param kwargs: Additional arguments to update the state with
:return: List of data returned from pre-initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Pre-initializing structure")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
mod_data = []
for mod in self.modifiers:
data = mod.pre_initialize_structure(state=self.state, **extras)
logger.debug("Pre-initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)

applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)
logger.info(
"Compression lifecycle structure pre-initialized for {} modifiers",
len(self.modifiers),
)

return mod_data

def initialize(self, **kwargs) -> List[Any]:
def initialize(
self,
recipe: Optional[RecipeInput] = None,
recipe_stage: Optional[RecipeStageInput] = None,
recipe_args: Optional[RecipeArgsInput] = None,
**kwargs,
) -> List[Any]:
"""
Initialize the compression lifecycle.
:param kwargs: Additional arguments to update the state with
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Initializing compression lifecycle")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return

self._check_compile_recipe()
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
for mod in self.modifiers:
data = mod.initialize(state=self.state, **extras)
data = mod.initialize(state=self.state, **kwargs)
logger.debug("Initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)
Expand Down Expand Up @@ -185,7 +162,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
logger.error("Cannot invoke event after finalizing")
raise ValueError("Cannot invoke event after finalizing")

if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
logger.error(
"Cannot invoke {} event. Use the corresponding method instead.",
event_type,
Expand Down Expand Up @@ -223,30 +200,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:

return mod_data

def _check_create_state(self):
if self.state is not None:
return

logger.debug("Creating new State instance for compression lifecycle")
self.state = State()
logger.info("State created for compression lifecycle")

def _check_compile_recipe(self):
if not self.recipe_container.check_compile_recipe():
return

logger.debug(
"Compiling recipe and creating modifiers for compression lifecycle"
)
self.modifiers = self.recipe_container.compiled_recipe.create_modifier()
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True
logger.info(
"Recipe compiled and {} modifiers created",
len(self.modifiers),
)

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
return
Expand Down
39 changes: 0 additions & 39 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,6 @@ def state(self) -> State:
"""
return self._lifecycle.state

def pre_initialize_structure(
self,
model: Any,
recipe: Union[str, List[str], Recipe, List[Recipe], None] = None,
recipe_stage: Union[str, List[str], None] = None,
recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
**kwargs,
) -> ModifiedState:
"""
A method to pre-initialize the structure of the model for compression.
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previously modified by a modifier.
:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the compression, can be a path to a
recipe file, a raw recipe string, a recipe object, or a list
of recipe objects.
:param recipe_stage: the stage to use for the compression
:param recipe_args: the args to use for overriding the recipe defaults
:return: A ModifiedState instance holding the modified model and modifier_data
after pre-initializing the structure
"""
mod_data = self._lifecycle.pre_initialize_structure(
model=model,
recipe=recipe,
recipe_stage=recipe_stage,
recipe_args=recipe_args,
**kwargs,
)

return ModifiedState(
model=self.state.model,
optimizer=None,
loss=None,
modifier_data=mod_data,
)

def initialize(
self,
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
Expand Down
13 changes: 1 addition & 12 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"callbacks",
Expand Down Expand Up @@ -59,16 +58,6 @@ def reset_session():
session._lifecycle.reset()


def pre_initialize_structure(**kwargs):
"""
A method to pre-initialize the structure of the model for the active session
:param kwargs: the kwargs to pass to the active session's pre-initialize-structure
method
"""
active_session().pre_initialize_structure(**kwargs)


def initialize(
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
recipe_stage: Union[str, List[str], None] = None,
Expand Down Expand Up @@ -156,7 +145,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the current session's event method
:return: the modified state of the active session after invoking the event
"""
if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
raise ValueError(
f"Cannot invoke {event_type} event. "
f"Use the corresponding method instead."
Expand Down
19 changes: 0 additions & 19 deletions src/llmcompressor/modifiers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ class ModifierInterface(ABC):
Defines the contract that all modifiers must implement
"""

@property
@abstractmethod
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
raise NotImplementedError()

@property
@abstractmethod
def initialized(self) -> bool:
Expand Down Expand Up @@ -58,16 +49,6 @@ def calculate_end(self) -> float:
"""
raise NotImplementedError()

@abstractmethod
def pre_initialize_structure(self, state: State, **kwargs):
"""
Apply the modifier structure to the model
:param state: The current state of the model
:param kwargs: Additional arguments for the modifier
"""
raise NotImplementedError()

@abstractmethod
def initialize(self, state: State, **kwargs):
"""
Expand Down
Loading

0 comments on commit 54eb85c

Please sign in to comment.