
Commit

add validation, global step
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Mar 12, 2025
1 parent 75657a5 commit 0377ffd
Showing 3 changed files with 53 additions and 32 deletions.
48 changes: 46 additions & 2 deletions src/llmcompressor/core/lifecycle.py
@@ -43,6 +43,21 @@ class CompressionLifecycle:
     initialized_: bool = False
     finalized: bool = False
 
+    # event order validation
+    _last_event_type: Optional[EventType] = EventType.BATCH_END
+    _event_order: List[EventType] = field(
+        default_factory=lambda: [
+            EventType.BATCH_START,
+            EventType.LOSS_CALCULATED,
+            EventType.OPTIM_PRE_STEP,
+            EventType.OPTIM_POST_STEP,
+            EventType.BATCH_END,
+        ]
+    )
+
+    # track global step in training (could be epoch/batch)
+    global_step: int = 0
+
     def reset(self):
         """
         Reset the compression lifecycle, finalizing any active modifiers
@@ -134,7 +149,9 @@ def finalize(self, **kwargs) -> List[Any]:
 
         return mod_data
 
-    def event(self, event_type: EventType, **kwargs) -> List[Any]:
+    def event(
+        self, event_type: EventType, global_step: Optional[int] = 0, **kwargs
+    ) -> List[Any]:
         """
         Handle a compression event.
Expand Down Expand Up @@ -164,6 +181,12 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
f"Use the corresponding method instead."
)

if not self._validate_event_order(event_type):
raise ValueError(
f"Lifecycle events must appear following order: {self._event_order}. "
f"Instead, {self._last_event_type} was called before {event_type}"
)

if event_type == EventType.LOSS_CALCULATED and (
"loss" not in kwargs or kwargs["loss"] is None
):
@@ -172,7 +195,11 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
 
         logger.debug("Handling event: {}", event_type)
 
-        event = Event(event_type=event_type)
+        # update global step
+        if global_step is not None:
+            self.global_step = global_step
+
+        event = Event(type_=event_type)
         mod_data = []
         for mod in self.modifiers:
             data = mod.update_event(state=self.state, event=event, **kwargs)
@@ -186,6 +213,23 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
 
         return mod_data
 
+    def _validate_event_order(self, event_type: EventType) -> bool:
+        if event_type not in self._event_order:
+            # for unhandled events, do not save last event
+            return True
+
+        if event_type == EventType.BATCH_START:
+            valid = self._last_event_type != EventType.BATCH_START
+
+        else:
+            last_event_index = self._event_order.index(self._last_event_type)
+            curr_event_index = self._event_order.index(event_type)
+            valid = last_event_index <= curr_event_index
+
+        if valid:
+            self._last_event_type = event_type
+        return valid
+
     def _set_model_layer_prefix(self):
         compiled_recipe = self.recipe_container.compiled_recipe
         if (
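The ordering rule these hunks introduce is easiest to see in isolation. Below is a minimal, self-contained sketch that mirrors _validate_event_order (a toy reimplementation for illustration, not the library code). Because the comparison is <=, repeats of the same event pass, and BATCH_START is rejected only when it immediately follows another BATCH_START:

    from enum import Enum
    from typing import List


    class EventType(Enum):  # local stand-in for llmcompressor's EventType
        BATCH_START = "batch_start"
        LOSS_CALCULATED = "loss_calculated"
        OPTIM_PRE_STEP = "optim_pre_step"
        OPTIM_POST_STEP = "optim_post_step"
        BATCH_END = "batch_end"


    EVENT_ORDER: List[EventType] = list(EventType)

    # starting from BATCH_END means the very first BATCH_START is accepted
    last_event_type = EventType.BATCH_END


    def validate_event_order(event_type: EventType) -> bool:
        global last_event_type
        if event_type not in EVENT_ORDER:
            return True  # unhandled events are ignored and do not update state
        if event_type == EventType.BATCH_START:
            # rejected only when the previous event was also BATCH_START
            valid = last_event_type != EventType.BATCH_START
        else:
            # repeats pass (e.g. LOSS_CALCULATED fired several times under
            # gradient accumulation); moving backwards in the order fails
            valid = EVENT_ORDER.index(last_event_type) <= EVENT_ORDER.index(event_type)
        if valid:
            last_event_type = event_type
        return valid


    assert validate_event_order(EventType.BATCH_START)      # BATCH_END -> BATCH_START
    assert not validate_event_order(EventType.BATCH_START)  # BATCH_START twice in a row
    assert validate_event_order(EventType.LOSS_CALCULATED)  # forward progress
    assert validate_event_order(EventType.LOSS_CALCULATED)  # repeat allowed
    assert validate_event_order(EventType.OPTIM_POST_STEP)  # skipping ahead allowed
    assert validate_event_order(EventType.BATCH_END)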
35 changes: 6 additions & 29 deletions src/llmcompressor/core/session.py
@@ -223,57 +223,34 @@ def get_serialized_recipe(self) -> Optional[str]:
 
     def _log_model_info(self):
         # Log model level logs if cadence reached
-        event_lifecycle = self._lifecycle.event_lifecycle
-        if event_lifecycle is None:
-            # event lifecycle not available
-            # when recipe is not provided
-            return
-
-        epoch = event_lifecycle.current_index
+        current_index = self._lifecycle.global_step
 
         if (
             should_log_model_info(
                 model=self.state.model,
                 loggers=self.state.loggers,
-                current_log_step=epoch,
+                current_log_step=current_index,
                 last_log_step=self.state._last_log_step,
             )
             and self.state.loggers.frequency_manager.is_epoch_frequency_manager
         ):
             log_model_info(
                 state=self.state,
-                current_log_step=epoch,
+                current_log_step=current_index,
             )
             # update last log epoch
-            self.state.loggers.log_written(epoch)
+            self.state.loggers.log_written(current_index)
 
     def _log_loss(self, event_type: EventType, loss: Any):
         if event_type != EventType.LOSS_CALCULATED:
             # only log loss when loss is calculated
             return
-        event_lifecycle = self._lifecycle.event_lifecycle
-
-        if event_lifecycle is None:
-            # event lifecycle not available
-            # when recipe is not provided
-            return
-
-        epoch = event_lifecycle.current_index
-        if self.state.loggers.frequency_manager.is_optim_frequency_manager:
-            # log integer step for optimizer frequency manager
-            current_step = int(
-                self.state.loggers.epoch_to_step(
-                    epoch=epoch,
-                    steps_per_epoch=len(self.state.data.train),
-                )
-            )
-        else:
-            # log float epoch for epoch frequency manager
-            current_step = epoch
+        current_index = self._lifecycle.global_step
 
         # always log loss if available
         if loss is not None:
            loss = loss if isinstance(loss, dict) else {"loss": loss}
            self.state.loggers.metric.log_scalars(
-                tag="Loss", values=loss, step=current_step
+                tag="Loss", values=loss, step=current_index
            )
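Both logging helpers now read the lifecycle's global_step, which event() stores, instead of deriving a step from event_lifecycle.current_index via epoch_to_step. A hedged sketch of the resulting call pattern from a custom training loop follows; the import path and the assumption that each callback forwards its keyword arguments through to CompressionLifecycle.event are inferred from this diff, the callback names follow the event types above, and the loop requires an active, initialized compression session:

    from llmcompressor.core import callbacks  # import path assumed from session_mixin.py

    # placeholders; a real loop would use a torch DataLoader, model, and optimizer
    dataloader = [{"input": 1}, {"input": 2}]

    def compute_loss(batch):
        return 0.0  # placeholder

    for step, batch in enumerate(dataloader):
        callbacks.batch_start(batch_data=batch, global_step=step)
        loss = compute_loss(batch)
        callbacks.loss_calculated(loss=loss)  # _log_loss now logs at lifecycle.global_step
        callbacks.optim_pre_step()
        # optimizer.step() would run here
        callbacks.optim_post_step()
        callbacks.batch_end()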
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
@@ -260,7 +260,7 @@ def training_step(
         """
         self._check_super_defined("training_step")
 
-        callbacks.batch_start(batch_data=inputs)
+        callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
         model_outputs = super().training_step(
             model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
         )
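Here the trainer forwards its epoch counter as the lifecycle's global step. In a Hugging Face Trainer, self.state.epoch is a float that advances fractionally within an epoch, so the stored CompressionLifecycle.global_step grows monotonically across batches even though the field is annotated int. A library-free illustration with assumed values:

    class TrainerState:  # stand-in for transformers.TrainerState
        epoch: float = 0.0

    state = TrainerState()
    steps_per_epoch = 4  # assumed value
    for step in range(6):
        state.epoch = step / steps_per_epoch  # roughly how the HF Trainer updates it
        # this is the value training_step passes as global_step
        print(f"step {step}: global_step={state.epoch:.2f}")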
