diff --git a/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py index 8526e70b..ee19852f 100644 --- a/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py @@ -47,10 +47,10 @@ def __init__(self, draft_model: nn.Module, length: int, **kwargs): super().__init__(model=draft_model, **kwargs) self.length = length self._train_start_time = None - self._pending_log: dict = ( - {} - ) # cache acc/ploss log for merging with base Trainer's loss log - self._pending_log_count: int = 0 # accumulated batch count for averaging the cached log + self._train_pending_log: dict = {} + self._train_pending_log_count: int = 0 + self._eval_pending_log: dict = {} + self._eval_pending_log_count: int = 0 def train(self, *args, **kwargs): """Override train method to record training start time for estimating remaining time.""" @@ -59,12 +59,11 @@ def train(self, *args, **kwargs): def log(self, logs: dict, start_time: Optional[float] = None) -> None: """ - rewrite log method to merge acc/ploss log with base Trainer's loss log. + Merge acc/ploss accumulators with the base Trainer's loss log. """ - if "loss" in logs and self._pending_log: - # merge cached acc/ploss data (average) - count = max(self._pending_log_count, 1) - acc_ploss = {k: v / count for k, v in self._pending_log.items()} + if "loss" in logs and self._train_pending_log: + train_count = max(self._train_pending_log_count, 1) + acc_ploss = {k: v / train_count for k, v in self._train_pending_log.items()} merged = {} # step @@ -85,9 +84,16 @@ def log(self, logs: dict, start_time: Optional[float] = None) -> None: if "learning_rate" in logs: merged["lr"] = logs["learning_rate"] - # acc/ploss + # train acc/ploss merged.update(acc_ploss) + # eval acc/ploss — merged when a training log fires + if self._eval_pending_log: + eval_count = max(self._eval_pending_log_count, 1) + merged.update({k: v / eval_count for k, v in self._eval_pending_log.items()}) + self._eval_pending_log.clear() + self._eval_pending_log_count = 0 + # remaining_time if ( self.state is not None @@ -102,8 +108,8 @@ def log(self, logs: dict, start_time: Optional[float] = None) -> None: minutes, seconds = divmod(remainder, 60) merged["remaining_time"] = f"{hours:02d}h:{minutes:02d}m:{seconds:02d}s" - self._pending_log.clear() - self._pending_log_count = 0 + self._train_pending_log.clear() + self._train_pending_log_count = 0 super().log(merged, start_time) else: super().log(logs, start_time) @@ -294,10 +300,15 @@ def draft_model_training_time_test( log = {f"{log_prefix}/acc_{i}": acces[i] for i in range(len(acces))} log.update({f"{log_prefix}/ploss_{i}": plosses[i].item() for i in range(len(plosses))}) - # Cache log for merging with base Trainer's loss log - for k, v in log.items(): - self._pending_log[k] = self._pending_log.get(k, 0.0) + v - self._pending_log_count += 1 + # Route into the appropriate accumulator. + if log_prefix == "eval": + for k, v in log.items(): + self._eval_pending_log[k] = self._eval_pending_log.get(k, 0.0) + v + self._eval_pending_log_count += 1 + else: + for k, v in log.items(): + self._train_pending_log[k] = self._train_pending_log.get(k, 0.0) + v + self._train_pending_log_count += 1 # Step 9: Return loss return ploss