diff --git a/angelslim/compressor/speculative/__init__.py b/angelslim/compressor/speculative/__init__.py index f32ed4b5..b108957b 100644 --- a/angelslim/compressor/speculative/__init__.py +++ b/angelslim/compressor/speculative/__init__.py @@ -17,7 +17,9 @@ DatasetManager, DraftModelConfig, Eagle3TrainerFactory, + GaussianNoise, TargetHead, + TransformDataset, convert_sharegpt_data, convert_ultrachat_data, create_draft_model, @@ -39,6 +41,8 @@ "convert_sharegpt_data", "convert_ultrachat_data", "DatasetManager", + "GaussianNoise", + "TransformDataset", "get_supported_chat_template_type_strings", "TargetHead", "infer_model_params", diff --git a/angelslim/compressor/speculative/train/__init__.py b/angelslim/compressor/speculative/train/__init__.py index c9ecfd9b..b2b67165 100644 --- a/angelslim/compressor/speculative/train/__init__.py +++ b/angelslim/compressor/speculative/train/__init__.py @@ -1,5 +1,7 @@ from .data import ( DatasetManager, + GaussianNoise, + TransformDataset, convert_sharegpt_data, convert_ultrachat_data, data_generation_work_flow, @@ -23,6 +25,8 @@ "convert_sharegpt_data", "convert_ultrachat_data", "DatasetManager", + "GaussianNoise", + "TransformDataset", "get_supported_chat_template_type_strings", "TargetHead", "infer_model_params", diff --git a/angelslim/compressor/speculative/train/data/__init__.py b/angelslim/compressor/speculative/train/data/__init__.py index d1ae863d..194bdd80 100644 --- a/angelslim/compressor/speculative/train/data/__init__.py +++ b/angelslim/compressor/speculative/train/data/__init__.py @@ -16,9 +16,12 @@ from .data_generation import data_generation_work_flow from .data_utils import convert_sharegpt_data, convert_ultrachat_data from .dataset import DatasetManager +from .noise_transforms import GaussianNoise, TransformDataset __all__ = [ "DatasetManager", + "GaussianNoise", + "TransformDataset", "convert_sharegpt_data", "convert_ultrachat_data", "data_generation_work_flow", diff --git a/angelslim/compressor/speculative/train/data/noise_transforms.py b/angelslim/compressor/speculative/train/data/noise_transforms.py new file mode 100644 index 00000000..2a1a1d38 --- /dev/null +++ b/angelslim/compressor/speculative/train/data/noise_transforms.py @@ -0,0 +1,86 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hidden-state transform utilities for offline EAGLE3 training. + +Usage in a training script:: + + from angelslim.compressor.speculative.train.data.noise_transforms import ( + GaussianNoise, + TransformDataset, + ) + + train_dataset = TransformDataset(base_train_dataset, GaussianNoise(std=0.05)) + # eval_dataset is passed through unchanged + +Calibrating noise std +--------------------- +To verify the noise/signal ratio before committing to a full run, use +tools/test_dataloader.py:: + + python3 tools/test_dataloader.py \\ + --train-hidden-path $TRAIN_HIDDEN_PATH \\ + --eval-hidden-path $EVAL_HIDDEN_PATH \\ + --target-model $TARGET_MODEL_NAME_OR_PATH \\ + --hidden-noise-std 0.05 +""" + +import torch +from torch.utils.data import Dataset + + +class TransformDataset(Dataset): + """Wraps an offline dataset and applies a callable transform at ``__getitem__`` time. + + The transform is applied to a shallow copy of each sample so the underlying + cache is not mutated when ``cache_in_memory=True`` is used. + + Args: + dataset: The underlying offline dataset (e.g. OfflineVLMEagle3Dataset). + transform: Callable that takes a sample dict and returns a modified dict. + """ + + def __init__(self, dataset: Dataset, transform) -> None: + self.dataset = dataset + self.transform = transform + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> dict: + # Shallow copy prevents mutating the underlying cache when the inner + # dataset uses cache_in_memory=True. + sample = dict(self.dataset[idx]) + return self.transform(sample) + + +class GaussianNoise: + """Adds Gaussian noise to a tensor field of a sample dict. + + Noise is sampled fresh each call, so the model sees a different perturbation + of every sample on every epoch — effectively multiplying dataset diversity + without generating new hidden states. + + Args: + std: Standard deviation of the Gaussian noise. Set to 0.0 to disable. + field: Key of the tensor field to perturb (default: ``"hidden_states"``). + """ + + def __init__(self, std: float, field: str = "hidden_states") -> None: + self.std = std + self.field = field + + def __call__(self, sample: dict) -> dict: + sample[self.field] = sample[self.field] + torch.randn_like(sample[self.field]) * self.std + return sample diff --git a/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh index 11393853..e2e80853 100644 --- a/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh +++ b/scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh @@ -8,6 +8,7 @@ EVAL_HIDDEN_PATH= OUTPUT_DIR= RUN_NAME=hunyuan-ocr-eagle3-angelslim MODEL_MAX_LENGTH=8192 +HIDDEN_NOISE_STD=0.0 export MAX_PIXELS= torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ @@ -36,4 +37,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --run_name $RUN_NAME \ --num_proc 8 \ --training_time_test_length 4 \ + --hidden_noise_std $HIDDEN_NOISE_STD \ --bf16 diff --git a/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh b/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh index eebb6c82..13a018df 100644 --- a/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh +++ b/scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh @@ -8,6 +8,7 @@ EVAL_HIDDEN_PATH= OUTPUT_DIR= RUN_NAME=qwen3-4b-eagle3-angelslim MODEL_MAX_LENGTH=8192 +HIDDEN_NOISE_STD=0.0 export MAX_PIXELS= torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ @@ -36,4 +37,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --run_name $RUN_NAME \ --num_proc 8 \ --training_time_test_length 3 \ + --hidden_noise_std $HIDDEN_NOISE_STD \ --bf16 diff --git a/scripts/speculative/train_eagle3_offline.sh b/scripts/speculative/train_eagle3_offline.sh index 03ecc1ac..0b9654f2 100644 --- a/scripts/speculative/train_eagle3_offline.sh +++ b/scripts/speculative/train_eagle3_offline.sh @@ -11,6 +11,7 @@ export RUN_NAME= export MODEL_MAX_LENGTH=4096 export LM_HEAD_KEY= export CHAT_TEMPLATE_TYPE=qwen3 +export HIDDEN_NOISE_STD=0.0 torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ @@ -37,4 +38,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --report_to wandb \ --run_name $RUN_NAME \ --num_proc 48 \ + --hidden_noise_std $HIDDEN_NOISE_STD \ --bf16 \ No newline at end of file diff --git a/scripts/speculative/train_eagle3_vlm_offline.sh b/scripts/speculative/train_eagle3_vlm_offline.sh index 62e4e7ec..f715853b 100644 --- a/scripts/speculative/train_eagle3_vlm_offline.sh +++ b/scripts/speculative/train_eagle3_vlm_offline.sh @@ -11,6 +11,7 @@ export RUN_NAME= export MODEL_MAX_LENGTH=4096 export LM_HEAD_KEY= export CHAT_TEMPLATE_TYPE=qwen3_vl +export HIDDEN_NOISE_STD=0.0 torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --modal_type VLM \ @@ -38,4 +39,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ --report_to wandb \ --run_name $RUN_NAME \ --num_proc 48 \ + --hidden_noise_std $HIDDEN_NOISE_STD \ --bf16 \ No newline at end of file diff --git a/tools/test_dataloader.py b/tools/test_dataloader.py new file mode 100644 index 00000000..27f8a7ea --- /dev/null +++ b/tools/test_dataloader.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test hidden-state dataloader: report per-batch statistics and calibrate noise std. + +Loads one batch from each of the train and eval datasets (using the same +DatasetManager as training) and prints hidden_states shape, std, mean, and +abs-max. When --hidden-noise-std is set, also shows the noise/signal ratio +so you can calibrate the noise level before committing to a full run. + +Loads the tokenizer and .ckpt hidden state files only. + +Example: + python3 tools/test_dataloader.py \\ + --train-hidden-path /path/to/hidden_states_train \\ + --eval-hidden-path /path/to/hidden_states_eval \\ + --target-model /path/to/model/snapshot \\ + --hidden-noise-std 0.05 +""" + +import argparse +import sys +import types + +import torch +from torch.utils.data import DataLoader +from transformers import AutoProcessor + +from angelslim.compressor.speculative import ( + DatasetManager, + GaussianNoise, + TransformDataset, + infer_model_params, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--train-hidden-path", + required=True, + help="Path to training hidden states directory (same as TRAIN_HIDDEN_PATH in training).", + ) + parser.add_argument( + "--eval-hidden-path", + required=True, + help="Path to eval hidden states directory (same as EVAL_HIDDEN_PATH in training).", + ) + parser.add_argument( + "--target-model", + required=True, + help=( + "Local path to the target model snapshot. Used only to load the tokenizer " + "and infer chat_template_type — no model weights are loaded." + ), + ) + parser.add_argument( + "--hidden-noise-std", + type=float, + default=0.0, + help=( + "Noise std to test (same value as --hidden_noise_std in training). " + "0.0 = no augmentation (default). Reports noise/signal ratio when > 0." + ), + ) + parser.add_argument( + "--chat-template-type", + default=None, + help=( + "Chat template type (e.g. 'qwen3_vl'). Auto-detected from --target-model " + "when not set. Use this to override if auto-detection fails on local paths." + ), + ) + parser.add_argument( + "--model-max-length", + type=int, + default=16384, + help="Maximum sequence length passed to DatasetManager (default: 16384).", + ) + parser.add_argument( + "--target-model-type", + default="qwen3_vl", + help=( + "Target model type passed to DatasetManager for VLM collator selection " + "(default: qwen3_vl). Override when testing against a different VLM " + "(e.g. qwen2_5_vl, hunyuan_vl)." + ), + ) + return parser.parse_args() + + +def _stats(hs: torch.Tensor, label: str) -> float: + f = hs.float() + std = f.std().item() + print(f" {label}") + print(f" shape : {tuple(f.shape)}") + print(f" dtype : {hs.dtype}") + print(f" std : {std:.6f}") + print(f" mean : {f.mean().item():.6f}") + print(f" absmax : {f.abs().max().item():.6f}") + return std + + +def _make_data_args(args, chat_template_type: str) -> types.SimpleNamespace: + """Build a minimal data_args namespace for DatasetManager.""" + return types.SimpleNamespace( + training_mode="offline", + modal_type="VLM", + train_hidden_path=args.train_hidden_path, + eval_hidden_path=args.eval_hidden_path, + target_model_name_or_path=args.target_model, + train_data_path=None, + sample_num=None, + shuffle_seed=42, + display=False, + num_proc=4, + chat_template_type=chat_template_type, + ) + + +def main(): + args = parse_args() + + # Resolve chat_template_type — use explicit override if provided, else infer. + if args.chat_template_type: + chat_template_type = args.chat_template_type + print(f"chat_template_type : {chat_template_type} (from --chat-template-type)") + else: + print(f"Inferring model params from {args.target_model} ...") + _, _, chat_template_type = infer_model_params( + model_name_or_path=args.target_model, model_type=None + ) + if chat_template_type is None: + print( + "ERROR: could not infer chat_template_type. " + "Pass --chat-template-type explicitly (e.g. qwen3_vl).", + file=sys.stderr, + ) + sys.exit(1) + print(f" chat_template_type: {chat_template_type}") + + # Load tokenizer only (no model weights). + print("Loading tokenizer ...") + tokenizer = AutoProcessor.from_pretrained(args.target_model) + + # Build datasets via the same DatasetManager used in training. + print("Building datasets ...") + data_args = _make_data_args(args, chat_template_type) + dataset_manager = DatasetManager( + data_args=data_args, + tokenizer=tokenizer, + model_max_length=args.model_max_length, + chat_template_type=chat_template_type, + target_model_type=args.target_model_type, + ) + train_dataset, eval_dataset, data_collator = dataset_manager.create_offline_datasets() + print(f" train dataset size : {len(train_dataset)}") + print(f" eval dataset size : {len(eval_dataset) if eval_dataset else 'N/A'}") + + # Eval batch — no noise, measures raw signal scale. + raw_std = None + if eval_dataset: + print("\nEval batch (no noise):") + eval_loader = DataLoader(eval_dataset, batch_size=1, collate_fn=data_collator) + eval_batch = next(iter(eval_loader)) + raw_std = _stats(eval_batch["hidden_states"], "hidden_states") + + # Train batch — with noise if requested. + if args.hidden_noise_std > 0.0: + aug_dataset = TransformDataset(train_dataset, GaussianNoise(std=args.hidden_noise_std)) + noise_label = f"noise std={args.hidden_noise_std}" + else: + aug_dataset = train_dataset + noise_label = "no noise" + + print(f"\nTrain batch ({noise_label}):") + train_loader = DataLoader(aug_dataset, batch_size=1, collate_fn=data_collator) + train_batch = next(iter(train_loader)) + _stats(train_batch["hidden_states"], "hidden_states") + + # Noise/signal summary. + if args.hidden_noise_std > 0.0 and raw_std is not None: + ratio = args.hidden_noise_std / raw_std + print( + f"\nNoise/signal ratio : {ratio:.4f}" + f" (noise_std={args.hidden_noise_std} / eval_hs_std={raw_std:.6f})" + ) + if ratio > 0.2: + print(" WARNING : noise_std >20% of signal — consider reducing --hidden-noise-std") + elif ratio < 0.01: + print(" NOTE : noise_std <1% of signal — augmentation effect may be negligible") + else: + print(" OK : noise level looks reasonable") + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/tools/train_eagle3_offline.py b/tools/train_eagle3_offline.py index a1c34bf5..33146521 100644 --- a/tools/train_eagle3_offline.py +++ b/tools/train_eagle3_offline.py @@ -23,7 +23,9 @@ DatasetManager, DraftModelConfig, Eagle3TrainerFactory, + GaussianNoise, TargetHead, + TransformDataset, create_draft_model, get_supported_chat_template_type_strings, infer_model_params, @@ -157,6 +159,15 @@ def parse_args(): default=False, help="Display data samples during preprocessing (default: False)", ) + data_group.add_argument( + "--hidden_noise_std", + type=float, + default=0.0, + help=( + "Standard deviation of Gaussian noise added to hidden_states during training. " + "0.0 = no augmentation (default). Use tools/test_dataloader.py to calibrate." + ), + ) # Training arguments training_group = parser.add_argument_group("Training Arguments") @@ -365,6 +376,14 @@ def train(): f"{len(offline_eval_dataset) if offline_eval_dataset else 0}" ) + if args.hidden_noise_std > 0.0: + offline_train_dataset = TransformDataset( + offline_train_dataset, GaussianNoise(std=args.hidden_noise_std) + ) + rank0_print( + f"Applying GaussianNoise augmentation to train dataset (std={args.hidden_noise_std})" + ) + # Build vocabulary mapping for draft model from pre-computed vocab mapping rank0_print("Loading vocabulary mapping for draft model...") vocab_mapping_path = os.path.join(args.train_hidden_path, "vocab_mapping.pt")