
Feat/unify checkpoints loading #5372


Draft · wants to merge 6 commits into main
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -18,7 +18,7 @@
// "workspaceFolder": "/home/coder/${localWorkspaceFolderBasename}",
// "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/${localWorkspaceFolderBasename},type=bind,consistency=consistent",
"mounts": [
"source=${localEnv:HOME}/.cache/huggingface,target=/huggingface,type=bind", // HF cache
//"source=${localEnv:HOME}/.cache/huggingface,target=/huggingface,type=bind", // HF cache
"source=/home/scratch.trt_llm_data/,target=/home/scratch.trt_llm_data/,type=bind,consistency=consistent"
],
// Note: sourcing .profile is required since we use a local user and the python interpreter is
Empty file.
@@ -0,0 +1,57 @@
from dataclasses import dataclass, field
from typing import Optional

from tensorrt_llm import logger
from tensorrt_llm._torch.models.checkpoints.config_loader_interface import \
ConfigLoaderInterface
from tensorrt_llm._torch.models.checkpoints.weights_loader_interface import \
WeightsLoaderInterface
from tensorrt_llm._torch.models.checkpoints.weights_mapper_interface import \
WeightsMapperInterface
from tensorrt_llm.mapping import Mapping


def _default_weights_loader() -> WeightsLoaderInterface:
logger.warning(
"CheckpointBundle.weights_loader is None. Will use default HfWeightsLoader"
)
from tensorrt_llm._torch.models.checkpoints.hf.weights_loader import \
HfWeightsLoader
return HfWeightsLoader()


def _default_weights_mapper_cls() -> type[WeightsMapperInterface]:
logger.warning(
"CheckpointBundle.weights_mapper_cls is None. Will use default HfWeightsMapper"
)
from tensorrt_llm._torch.models.checkpoints.hf.weights_mapper import \
HfWeightsMapper
return HfWeightsMapper


def _default_config_loader() -> ConfigLoaderInterface:
logger.warning(
"CheckpointBundle.config_loader is None. Will use default HfConfigLoader"
)
from tensorrt_llm._torch.models.checkpoints.hf.config_loader import \
HfConfigLoader
return HfConfigLoader()


@dataclass(kw_only=True)
class CheckpointLoaderBundle:
weights_loader: WeightsLoaderInterface = field(
default_factory=_default_weights_loader)
weights_mapper_cls: type[WeightsMapperInterface] = field(
default_factory=_default_weights_mapper_cls)
config_loader: ConfigLoaderInterface = field(
default_factory=_default_config_loader)

def set_weights_loader_mapping(self, mapping: Optional[Mapping] = None):
if mapping is None:
mapping = Mapping()
logger.warning(
"mapping not found, using default Mapping() for HfWeightsLoader fallback."
)

self.weights_loader.set_mapping(mapping)
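
For orientation, a minimal usage sketch of the bundle under the HF defaults added in this PR. The import path of CheckpointLoaderBundle itself is not visible in this diff, so it is assumed here, as are the checkpoint path and the TP-2 Mapping values.

from tensorrt_llm._torch.models.checkpoints.checkpoint_loader_bundle import \
    CheckpointLoaderBundle  # assumed module path; not shown in this diff
from tensorrt_llm._torch.models.checkpoints.hf.config_loader import HfConfigLoader
from tensorrt_llm._torch.models.checkpoints.hf.weights_loader import HfWeightsLoader
from tensorrt_llm._torch.models.checkpoints.hf.weights_mapper import HfWeightsMapper
from tensorrt_llm.mapping import Mapping

# All-defaults bundle: every field not passed falls back to its HF implementation
# and the corresponding default factory logs a warning.
bundle = CheckpointLoaderBundle()

# Fully explicit bundle, e.g. for a TP-2 run (illustrative Mapping values).
bundle = CheckpointLoaderBundle(weights_loader=HfWeightsLoader(),
                                weights_mapper_cls=HfWeightsMapper,
                                config_loader=HfConfigLoader())
bundle.set_weights_loader_mapping(Mapping(world_size=2, tp_size=2, rank=0))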
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/config_loader_interface.py
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod

from tensorrt_llm._torch.model_config import ModelConfig


class ConfigLoaderInterface(ABC):

@abstractmethod
def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
pass
Empty file.
9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py
@@ -0,0 +1,9 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.config_loader_interface import \
ConfigLoaderInterface


class HfConfigLoader(ConfigLoaderInterface):

def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
return ModelConfig.from_pretrained(checkpoint_dir, **kwargs)
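
A short usage sketch for the HF config loader; the checkpoint path is illustrative, and any extra keyword arguments are forwarded unchanged to ModelConfig.from_pretrained.

from tensorrt_llm._torch.models.checkpoints.hf.config_loader import HfConfigLoader

config_loader = HfConfigLoader()
# Returns a tensorrt_llm._torch.model_config.ModelConfig built from the HF checkpoint.
model_config = config_loader.load("/path/to/hf/checkpoint")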
95 changes: 95 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/hf/weights_mapper.py
@@ -0,0 +1,95 @@
from typing import Union

import torch
from torch import nn

from tensorrt_llm._torch.model_config import TConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM

from ..weights_mapper_interface import WeightsMapperInterface


class HfWeightsMapper(WeightsMapperInterface):

def __init__(self, model: Union[nn.Module, DecoderModelForCausalLM],
config: TConfig):
super().__init__(model, config)
self._callbacks = [
self._duplicate_kv_weights,
]
self.map_weights()

def map_weights(self) -> None:
self._mapping.update({
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
})

def apply_callbacks(self, module_name: str,
module_names_breakdown: list[str],
weights: dict) -> list[dict]:
module_weights = []

for new_name in self._mapping[module_name]:
fw = self.filter_weights(
'.'.join(module_names_breakdown + [new_name]), weights)
for callback in self._callbacks:
fw = callback(new_name, fw)
module_weights.append(fw)

return module_weights

def should_skip_module(self, module_name: str) -> bool:
if self._model.config.tie_word_embeddings and module_name.startswith(
"lm_head"):
return True

# Skip loading weights for embedding and lm_head if LoRA is enabled and has custom values
if hasattr(self._model, "model") and hasattr(
self._model.model, 'has_custom_embed_tokens'
) and self._model.model.has_custom_embed_tokens and module_name == "model.embed_tokens":
return True
if hasattr(
self._model, 'has_custom_lm_head'
) and self._model.has_custom_lm_head and module_name == "lm_head":
return True

return super().should_skip_module(module_name)

def _duplicate_kv_weights(self, new_name: str, weights: dict):
if new_name in ['k_proj', 'v_proj']:
processed_weights = {
k:
self._duplicate_kv(weight=v[:],
head_dim=self._head_dim,
tensor_parallel_size=self._tp_size)
if k in ["weight", "bias"] else v
for k, v in weights.items()
}
return processed_weights

return weights

def _duplicate_kv(self, weight: torch.Tensor, head_dim: int,
tensor_parallel_size: int):
num_kv_heads = weight.shape[0] // head_dim

if num_kv_heads >= tensor_parallel_size:
assert num_kv_heads % tensor_parallel_size == 0
return weight

assert tensor_parallel_size % num_kv_heads == 0
reps = tensor_parallel_size // num_kv_heads

# bias
if weight.ndim == 1:
return weight.repeat_interleave(reps)

# weight
weight = weight.reshape(num_kv_heads, head_dim,
-1)[:,
None, :, :].expand(num_kv_heads, reps,
head_dim,
weight.shape[1])
return weight.reshape(num_kv_heads * reps * head_dim,
-1).clone().detach()
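
A standalone sketch (not part of the PR) of the KV-head duplication performed by _duplicate_kv, with small illustrative sizes: 2 KV heads and tp_size=8 give reps=4, so each head is repeated four times and every TP rank receives a copy.

import torch

num_kv_heads, head_dim, hidden_size = 2, 4, 8
tp_size = 8
reps = tp_size // num_kv_heads  # 4 copies of each KV head

weight = torch.randn(num_kv_heads * head_dim, hidden_size)  # [8, 8]
duplicated = weight.reshape(num_kv_heads, head_dim, -1)[:, None, :, :] \
    .expand(num_kv_heads, reps, head_dim, weight.shape[1]) \
    .reshape(num_kv_heads * reps * head_dim, -1)

assert duplicated.shape == (num_kv_heads * reps * head_dim, hidden_size)  # [32, 8]
# The first two head_dim-sized row blocks are both copies of KV head 0.
assert torch.equal(duplicated[:head_dim], weight[:head_dim])
assert torch.equal(duplicated[head_dim:2 * head_dim], weight[:head_dim])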
31 changes: 31 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/weights_loader_interface.py
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from tensorrt_llm.mapping import Mapping


class WeightsLoaderInterface(ABC):

def __init__(self, mapping: Optional[Mapping] = None):
"""
Initializes the WeightsLoader.

Args:
mapping: Mapping object for distributed environments.
"""
self._mapping = mapping

@abstractmethod
def load_weights(self, checkpoint_dir: str) -> Dict[str, Any]:
"""
Loads weights from a checkpoint directory.

Args:
checkpoint_dir: A path to the checkpoint directory.

Returns:
A dictionary where keys are tensor names and values are the tensors.
"""

def set_mapping(self, mapping: Mapping):
self._mapping = mapping
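
For illustration only, a hypothetical loader built on this interface that reads plain torch shards; the class name and shard layout are assumptions, not part of the PR.

import glob
import os
from typing import Any, Dict

import torch

from tensorrt_llm._torch.models.checkpoints.weights_loader_interface import \
    WeightsLoaderInterface


class TorchShardWeightsLoader(WeightsLoaderInterface):  # hypothetical example

    def load_weights(self, checkpoint_dir: str) -> Dict[str, Any]:
        weights: Dict[str, Any] = {}
        # Each "*.pt" shard is assumed to be a flat name -> tensor state dict.
        for shard in sorted(glob.glob(os.path.join(checkpoint_dir, "*.pt"))):
            weights.update(torch.load(shard, map_location="cpu"))
        return weights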
78 changes: 78 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/weights_mapper_interface.py
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Union

from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig, TConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM


class WeightsMapperInterface(ABC):

def __init__(self, model: Union[nn.Module, DecoderModelForCausalLM],
config: TConfig):
self._callbacks: list[Callable] = []
self._mapping: dict = {}
self._skip_modules = []
self._model = model
self._config = config

if not hasattr(model, 'model_config') or not isinstance(
model.model_config, ModelConfig):
raise ValueError("model must have a model_config attribute")
if not hasattr(model, 'config'):
raise ValueError("model must have a config attribute")

self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
self._head_dim = getattr(
model.config, "head_dim",
model.config.hidden_size // model.config.num_attention_heads)

@abstractmethod
def map_weights(self) -> None:
"""
Populates the module-name mapping used to translate source checkpoint
names (e.g., Hugging Face) into TRT-LLM module names.
"""

@abstractmethod
def apply_callbacks(self, module_name: str,
module_names_breakdown: list[str],
weights: dict) -> list[dict]:
"""
Applies the mapper's registered callback transformations to the weights of a
mapped module and returns the processed weights for each submodule that the
module maps to.

Args:
module_name: The specific module name (e.g., 'qkv_proj', 'gate_up_proj')
module_names_breakdown: List of module path components for building full paths
weights: The weights dictionary to process
"""

def should_apply_to_module(self, module_name: str) -> bool:
return module_name in self._mapping

@property
def skip_modules(self) -> List[str]:
return self._skip_modules

@skip_modules.setter
def skip_modules(self, value: List[str]) -> None:
self._skip_modules = value

def should_skip_module(self, module_name: str) -> bool:
return any(skip_module in module_name
for skip_module in self._skip_modules)

def filter_weights(self, prefix: str, weights: dict) -> dict:
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result

@property
def mapping(self) -> dict:
return self._mapping
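
A small standalone illustration of the filter_weights contract (sketch only; the dict comprehension inlines the same prefix slicing the method performs, since building a full model here is impractical).

import torch

weights = {
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.q_proj.bias": torch.zeros(8),
    "model.layers.0.self_attn.k_proj.weight": torch.zeros(8, 8),
}
prefix = "model.layers.0.self_attn.q_proj"
# Keep keys under the prefix and strip "<prefix>." from each, as filter_weights does.
filtered = {
    k[len(prefix) + 1:]: v
    for k, v in weights.items() if k.startswith(prefix)
}
assert set(filtered) == {"weight", "bias"}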
23 changes: 16 additions & 7 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -4,6 +4,8 @@
from torch import nn
from transformers import LlamaConfig

from tensorrt_llm._torch.models.checkpoints.weights_mapper_interface import \
WeightsMapperInterface
from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
Expand Down Expand Up @@ -264,7 +266,8 @@ def forward(
return_context_logits,
)

-def load_weights(self, weights: Dict):
+def load_weights(self, weights: Dict,
+                 weights_mapper: WeightsMapperInterface):
new_weights = {}
for k, v in weights.items():
if 'lm_head' not in k:
@@ -274,9 +277,11 @@ def load_weights(self, weights: Dict):
new_k = k
new_weights[new_k] = v
if self.load_lm_head_from_target:
-super().load_weights(new_weights, skip_modules=['lm_head'])
+super().load_weights(new_weights,
+                     weights_mapper,
+                     skip_modules=['lm_head'])
else:
-super().load_weights(new_weights)
+super().load_weights(new_weights, weights_mapper)

def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
@@ -401,9 +406,13 @@ def forward(

return logits

-def load_weights(self, weights: Dict):
-    super().load_weights(weights, skip_modules=["draft_model"])
+def load_weights(self, weights: Dict,
+                 weights_mapper: WeightsMapperInterface):
+    super().load_weights(weights,
+                         weights_mapper,
+                         skip_modules=["draft_model"])

-def load_draft_weights(self, weights: Dict):
-    self.draft_model.load_weights(weights)
+def load_draft_weights(self, weights: Dict,
+                       weights_mapper: WeightsMapperInterface):
+    self.draft_model.load_weights(weights, weights_mapper)
self.draft_model.load_weights_from_target_model(self)
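
A hedged sketch of the call pattern these signature changes imply; the bundle, model, and checkpoint path are illustrative, and the real call sites live outside this diff.

# `bundle` as in the CheckpointLoaderBundle sketch above; `model` is assumed to be
# an already constructed TRT-LLM model exposing `config` and `model_config`.
weights = bundle.weights_loader.load_weights("/path/to/checkpoint")
weights_mapper = bundle.weights_mapper_cls(model, model.config)
model.load_weights(weights, weights_mapper)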