Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions angelslim/compressor/speculative/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
DatasetManager,
DraftModelConfig,
Eagle3TrainerFactory,
GaussianNoise,
TargetHead,
TransformDataset,
convert_sharegpt_data,
convert_ultrachat_data,
create_draft_model,
Expand All @@ -39,6 +41,8 @@
"convert_sharegpt_data",
"convert_ultrachat_data",
"DatasetManager",
"GaussianNoise",
"TransformDataset",
"get_supported_chat_template_type_strings",
"TargetHead",
"infer_model_params",
Expand Down
4 changes: 4 additions & 0 deletions angelslim/compressor/speculative/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .data import (
DatasetManager,
GaussianNoise,
TransformDataset,
convert_sharegpt_data,
convert_ultrachat_data,
data_generation_work_flow,
Expand All @@ -23,6 +25,8 @@
"convert_sharegpt_data",
"convert_ultrachat_data",
"DatasetManager",
"GaussianNoise",
"TransformDataset",
"get_supported_chat_template_type_strings",
"TargetHead",
"infer_model_params",
Expand Down
3 changes: 3 additions & 0 deletions angelslim/compressor/speculative/train/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
from .data_generation import data_generation_work_flow
from .data_utils import convert_sharegpt_data, convert_ultrachat_data
from .dataset import DatasetManager
from .noise_transforms import GaussianNoise, TransformDataset

__all__ = [
"DatasetManager",
"GaussianNoise",
"TransformDataset",
"convert_sharegpt_data",
"convert_ultrachat_data",
"data_generation_work_flow",
Expand Down
86 changes: 86 additions & 0 deletions angelslim/compressor/speculative/train/data/noise_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hidden-state transform utilities for offline EAGLE3 training.

Usage in a training script::

from angelslim.compressor.speculative.train.data.noise_transforms import (
GaussianNoise,
TransformDataset,
)

train_dataset = TransformDataset(base_train_dataset, GaussianNoise(std=0.05))
# eval_dataset is passed through unchanged

Calibrating noise std
---------------------
To verify the noise/signal ratio before committing to a full run, use
tools/test_dataloader.py::

python3 tools/test_dataloader.py \\
--train-hidden-path $TRAIN_HIDDEN_PATH \\
--eval-hidden-path $EVAL_HIDDEN_PATH \\
--target-model $TARGET_MODEL_NAME_OR_PATH \\
--hidden-noise-std 0.05
"""

import torch
from torch.utils.data import Dataset


class TransformDataset(Dataset):
    """Dataset wrapper that runs a transform over every sample it serves.

    Each lookup fetches the sample from the wrapped dataset, makes a shallow
    ``dict`` copy of it, and hands that copy to ``transform``. Because the
    transform only ever sees the copy, rebinding keys inside the transform
    does not alter the mapping held by the wrapped dataset (relevant when the
    wrapped dataset keeps samples cached, e.g. with ``cache_in_memory=True``).
    The copy is shallow: values are shared, so a transform that mutates a
    value in place would still affect the wrapped dataset's object.

    Args:
        dataset: The underlying offline dataset (e.g. OfflineVLMEagle3Dataset).
        transform: Callable that takes a sample dict and returns a modified dict.
    """

    def __init__(self, dataset: Dataset, transform) -> None:
        self.transform = transform
        self.dataset = dataset

    def __len__(self) -> int:
        # The wrapper never adds or drops samples.
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict:
        original = self.dataset[idx]
        # Copy the top-level mapping so the transform cannot rebind keys on
        # the object owned by the wrapped dataset.
        detached = dict(original)
        return self.transform(detached)


class GaussianNoise:
    """Adds zero-mean Gaussian noise to a tensor field of a sample dict.

    Noise is sampled fresh on each call, so the model sees a different
    perturbation of every sample on every epoch, which increases the
    effective diversity of the data without generating new hidden states.

    When ``std`` is ``0`` the transform is a true no-op: the sample is
    returned untouched, no noise tensor is allocated, and the global RNG
    state is not advanced.

    Args:
        std: Standard deviation of the Gaussian noise. Set to 0.0 to disable.
        field: Key of the tensor field to perturb (default: ``"hidden_states"``).
    """

    def __init__(self, std: float, field: str = "hidden_states") -> None:
        self.std = std
        self.field = field

    def __call__(self, sample: dict) -> dict:
        """Perturb ``sample[self.field]`` and return the same dict.

        The field is rebound to a new tensor; the original tensor is not
        modified in place.

        Args:
            sample: Sample dict containing a tensor under ``self.field``.

        Returns:
            The same ``sample`` dict, with the field replaced by its noisy
            version (or left unchanged when ``std`` is 0).

        Raises:
            KeyError: If ``self.field`` is missing and noise is enabled.
        """
        # Disabled: skip the per-sample randn allocation and tensor copy, and
        # leave the RNG stream exactly as it would be without this transform.
        if self.std == 0:
            return sample

        clean = sample[self.field]
        # Out-of-place add so a tensor shared with an upstream cache is never
        # modified; randn_like matches the field's shape, dtype and device.
        sample[self.field] = clean + torch.randn_like(clean) * self.std
        return sample
2 changes: 2 additions & 0 deletions scripts/speculative/hunyuan_ocr/train_eagle3_vlm_offline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ EVAL_HIDDEN_PATH=
OUTPUT_DIR=
RUN_NAME=hunyuan-ocr-eagle3-angelslim
MODEL_MAX_LENGTH=8192
HIDDEN_NOISE_STD=0.0
export MAX_PIXELS=

torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
Expand Down Expand Up @@ -36,4 +37,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--run_name $RUN_NAME \
--num_proc 8 \
--training_time_test_length 4 \
--hidden_noise_std $HIDDEN_NOISE_STD \
--bf16
2 changes: 2 additions & 0 deletions scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ EVAL_HIDDEN_PATH=
OUTPUT_DIR=
RUN_NAME=qwen3-4b-eagle3-angelslim
MODEL_MAX_LENGTH=8192
HIDDEN_NOISE_STD=0.0
export MAX_PIXELS=

torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
Expand Down Expand Up @@ -36,4 +37,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--run_name $RUN_NAME \
--num_proc 8 \
--training_time_test_length 3 \
--hidden_noise_std $HIDDEN_NOISE_STD \
--bf16
2 changes: 2 additions & 0 deletions scripts/speculative/train_eagle3_offline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export RUN_NAME=
export MODEL_MAX_LENGTH=4096
export LM_HEAD_KEY=
export CHAT_TEMPLATE_TYPE=qwen3
export HIDDEN_NOISE_STD=0.0

torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \
Expand All @@ -37,4 +38,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--report_to wandb \
--run_name $RUN_NAME \
--num_proc 48 \
--hidden_noise_std $HIDDEN_NOISE_STD \
--bf16
2 changes: 2 additions & 0 deletions scripts/speculative/train_eagle3_vlm_offline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export RUN_NAME=
export MODEL_MAX_LENGTH=4096
export LM_HEAD_KEY=
export CHAT_TEMPLATE_TYPE=qwen3_vl
export HIDDEN_NOISE_STD=0.0

torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--modal_type VLM \
Expand Down Expand Up @@ -38,4 +39,5 @@ torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \
--report_to wandb \
--run_name $RUN_NAME \
--num_proc 48 \
--hidden_noise_std $HIDDEN_NOISE_STD \
--bf16
Loading
Loading