diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index a1487ea952..4b55dde211 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -18,7 +18,7 @@
     // "workspaceFolder": "/home/coder/${localWorkspaceFolderBasename}",
     // "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/${localWorkspaceFolderBasename},type=bind,consistency=consistent",
     "mounts": [
-        "source=${localEnv:HOME}/.cache/huggingface,target=/huggingface,type=bind", // HF cache
+        //"source=${localEnv:HOME}/.cache/huggingface,target=/huggingface,type=bind", // HF cache
         "source=/home/scratch.trt_llm_data/,target=/home/scratch.trt_llm_data/,type=bind,consistency=consistent"
     ],
     // Note: sourcing .profile is required since we use a local user and the python interpreter is
diff --git a/tensorrt_llm/_torch/models/checkpoints/__init__.py b/tensorrt_llm/_torch/models/checkpoints/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tensorrt_llm/_torch/models/checkpoints/checkpoint_loader_bundle.py b/tensorrt_llm/_torch/models/checkpoints/checkpoint_loader_bundle.py
new file mode 100644
index 0000000000..aad68bdc3b
--- /dev/null
+++ b/tensorrt_llm/_torch/models/checkpoints/checkpoint_loader_bundle.py
@@ -0,0 +1,57 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from tensorrt_llm import logger
+from tensorrt_llm._torch.models.checkpoints.config_loader_interface import \
+    ConfigLoaderInterface
+from tensorrt_llm._torch.models.checkpoints.weights_loader_interface import \
+    WeightsLoaderInterface
+from tensorrt_llm._torch.models.checkpoints.weights_mapper_interface import \
+    WeightsMapperInterface
+from tensorrt_llm.mapping import Mapping
+
+
+def _default_weights_loader() -> WeightsLoaderInterface:
+    logger.warning(
+        "CheckpointLoaderBundle.weights_loader is None. Will use default HfWeightsLoader"
+    )
+    from tensorrt_llm._torch.models.checkpoints.hf.weights_loader import \
+        HfWeightsLoader
+    return HfWeightsLoader()
+
+
+def _default_weights_mapper_cls() -> type[WeightsMapperInterface]:
+    logger.warning(
+        "CheckpointLoaderBundle.weights_mapper_cls is None. Will use default HfWeightsMapper"
+    )
+    from tensorrt_llm._torch.models.checkpoints.hf.weights_mapper import \
+        HfWeightsMapper
+    return HfWeightsMapper
+
+
+def _default_config_loader() -> ConfigLoaderInterface:
+    logger.warning(
+        "CheckpointLoaderBundle.config_loader is None. Will use default HfConfigLoader"
+    )
+    from tensorrt_llm._torch.models.checkpoints.hf.config_loader import \
+        HfConfigLoader
+    return HfConfigLoader()
+
+
+@dataclass(kw_only=True)
+class CheckpointLoaderBundle:
+    weights_loader: WeightsLoaderInterface = field(
+        default_factory=_default_weights_loader)
+    weights_mapper_cls: type[WeightsMapperInterface] = field(
+        default_factory=_default_weights_mapper_cls)
+    config_loader: ConfigLoaderInterface = field(
+        default_factory=_default_config_loader)
+
+    def set_weights_loader_mapping(self, mapping: Optional[Mapping] = None):
+        if mapping is None:
+            mapping = Mapping()
+            logger.warning(
+                "No mapping was provided; falling back to a default Mapping() for the weights loader."
+ ) + + self.weights_loader.set_mapping(mapping) diff --git a/tensorrt_llm/_torch/models/checkpoints/config_loader_interface.py b/tensorrt_llm/_torch/models/checkpoints/config_loader_interface.py new file mode 100644 index 0000000000..6fbf20ecd9 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/config_loader_interface.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from tensorrt_llm._torch.model_config import ModelConfig + + +class ConfigLoaderInterface(ABC): + + @abstractmethod + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + pass diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/__init__.py b/tensorrt_llm/_torch/models/checkpoints/hf/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py new file mode 100644 index 0000000000..c44721ad80 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/hf/config_loader.py @@ -0,0 +1,9 @@ +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.config_loader_interface import \ + ConfigLoaderInterface + + +class HfConfigLoader(ConfigLoaderInterface): + + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + return ModelConfig.from_pretrained(checkpoint_dir, **kwargs) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weights_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/weights_mapper.py new file mode 100644 index 0000000000..68a8cb172c --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weights_mapper.py @@ -0,0 +1,95 @@ +from typing import Union + +import torch +from torch import nn + +from tensorrt_llm._torch.model_config import TConfig +from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM + +from ..weights_mapper_interface import WeightsMapperInterface + + +class HfWeightsMapper(WeightsMapperInterface): + + def __init__(self, model: Union[nn.Module, DecoderModelForCausalLM], + config: TConfig): + super().__init__(model, config) + self._callbacks = [ + self._duplicate_kv_weights, + ] + self.map_weights() + + def map_weights(self) -> None: + self._mapping.update({ + 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], + 'gate_up_proj': ['gate_proj', 'up_proj'] + }) + + def apply_callbacks(self, module_name: str, + module_names_breakdown: list[str], + weights: dict) -> list[dict]: + module_weights = [] + + for new_name in self._mapping[module_name]: + fw = self.filter_weights( + '.'.join(module_names_breakdown + [new_name]), weights) + for callback in self._callbacks: + fw = callback(new_name, fw) + module_weights.append(fw) + + return module_weights + + def should_skip_module(self, module_name: str) -> bool: + if self._model.config.tie_word_embeddings and module_name.startswith( + "lm_head"): + return True + + # Skip loading weights for embedding and lm_head if LoRA is enabled and has custom values + if hasattr(self._model, "model") and hasattr( + self._model.model, 'has_custom_embed_tokens' + ) and self._model.model.has_custom_embed_tokens and module_name == "model.embed_tokens": + return True + if hasattr( + self._model, 'has_custom_lm_head' + ) and self._model.has_custom_lm_head and module_name == "lm_head": + return True + + return super().should_skip_module(module_name) + + def _duplicate_kv_weights(self, new_name: str, weights: dict): + if new_name in ['k_proj', 'v_proj']: + processed_weights = { + k: + self._duplicate_kv(weight=v[:], + head_dim=self._head_dim, + 
tensor_parallel_size=self._tp_size) + if k in ["weight", "bias"] else v + for k, v in weights.items() + } + return processed_weights + + return weights + + def _duplicate_kv(self, weight: torch.Tensor, head_dim: int, + tensor_parallel_size: int): + num_kv_heads = weight.shape[0] // head_dim + + if num_kv_heads >= tensor_parallel_size: + assert num_kv_heads % tensor_parallel_size == 0 + return weight + + assert tensor_parallel_size % num_kv_heads == 0 + reps = tensor_parallel_size // num_kv_heads + + # bias + if weight.ndim == 1: + return weight.repeat_interleave(reps) + + # weight + weight = weight.reshape(num_kv_heads, head_dim, + -1)[:, + None, :, :].expand(num_kv_heads, reps, + head_dim, + weight.shape[1]) + return weight.reshape(num_kv_heads * reps * head_dim, + -1).clone().detach() diff --git a/tensorrt_llm/_torch/models/checkpoints/weights_loader_interface.py b/tensorrt_llm/_torch/models/checkpoints/weights_loader_interface.py new file mode 100644 index 0000000000..1f3059949a --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/weights_loader_interface.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from tensorrt_llm.mapping import Mapping + + +class WeightsLoaderInterface(ABC): + + def __init__(self, mapping: Optional[Mapping] = None): + """ + Initializes the WeightsLoader. + + Args: + mapping: Mapping object for distributed environments. + """ + self._mapping = mapping + + @abstractmethod + def load_weights(self, checkpoint_dir: str) -> Dict[str, Any]: + """ + Loads weights from a checkpoint directory. + + Args: + checkpoint_dir: A path to the checkpoint directory. + + Returns: + A dictionary where keys are tensor names and values are the tensors. + """ + + def set_mapping(self, mapping: Mapping): + self._mapping = mapping diff --git a/tensorrt_llm/_torch/models/checkpoints/weights_mapper_interface.py b/tensorrt_llm/_torch/models/checkpoints/weights_mapper_interface.py new file mode 100644 index 0000000000..258c340705 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/weights_mapper_interface.py @@ -0,0 +1,78 @@ +from abc import ABC, abstractmethod +from typing import Callable, List, Union + +from torch import nn + +from tensorrt_llm._torch.model_config import ModelConfig, TConfig +from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM + + +class WeightsMapperInterface(ABC): + + def __init__(self, model: Union[nn.Module, DecoderModelForCausalLM], + config: TConfig): + self._callbacks: list[Callable] = [] + self._mapping: dict = {} + self._skip_modules = [] + self._model = model + self._config = config + + if not hasattr(model, 'model_config') or not isinstance( + model.model_config, ModelConfig): + raise ValueError("model must have a model_config attribute") + if not hasattr(model, 'config'): + raise ValueError("model must have a config attribute") + + self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size + self._head_dim = getattr( + model.config, "head_dim", + model.config.hidden_size // model.config.num_attention_heads) + + @abstractmethod + def map_weights(self) -> None: + """ + Maps weights from a source state dictionary (e.g., Hugging Face) + to a TRT-LLM compatible state dictionary. 
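+
+        For example, HfWeightsMapper populates the mapping with
+        {'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
+         'gate_up_proj': ['gate_proj', 'up_proj']}, which tells
+        _load_weights_impl which per-projection checkpoint tensors feed each
+        fused TRT-LLM module.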
+ """ + + @abstractmethod + def apply_callbacks(self, module_name: str, + module_names_breakdown: list[str], + weights: dict) -> list[dict]: + """ + Applies a series of transformation functions to an internal representation + of weights or to guide the mapping process. The exact behavior might depend + on the implementation (e.g., storing callbacks to be applied later). + + Args: + module_name: The specific module name (e.g., 'qkv_proj', 'gate_up_proj') + module_names_breakdown: List of module path components for building full paths + weights: The weights dictionary to process + """ + + def should_apply_to_module(self, module_name: str) -> bool: + return module_name in self._mapping + + @property + def skip_modules(self) -> List[str]: + return self._skip_modules + + @skip_modules.setter + def skip_modules(self, value: List[str]) -> None: + self._skip_modules = value + + def should_skip_module(self, module_name: str) -> bool: + return any(skip_module in module_name + for skip_module in self._skip_modules) + + def filter_weights(self, prefix: str, weights: dict) -> dict: + result = {} + for k, v in weights.items(): + if k.startswith(prefix): + new_k = k[len(prefix) + 1:] + result[new_k] = v + return result + + @property + def mapping(self) -> dict: + return self._mapping diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 55591cc354..c11b1d4ca4 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -4,6 +4,8 @@ from torch import nn from transformers import LlamaConfig +from tensorrt_llm._torch.models.checkpoints.weights_mapper_interface import \ + WeightsMapperInterface from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata @@ -264,7 +266,8 @@ def forward( return_context_logits, ) - def load_weights(self, weights: Dict): + def load_weights(self, weights: Dict, + weights_mapper: WeightsMapperInterface): new_weights = {} for k, v in weights.items(): if 'lm_head' not in k: @@ -274,9 +277,11 @@ def load_weights(self, weights: Dict): new_k = k new_weights[new_k] = v if self.load_lm_head_from_target: - super().load_weights(new_weights, skip_modules=['lm_head']) + super().load_weights(new_weights, + weights_mapper, + skip_modules=['lm_head']) else: - super().load_weights(new_weights) + super().load_weights(new_weights, weights_mapper) def load_weights_from_target_model(self, target_model: torch.nn.Module) -> None: @@ -401,9 +406,13 @@ def forward( return logits - def load_weights(self, weights: Dict): - super().load_weights(weights, skip_modules=["draft_model"]) + def load_weights(self, weights: Dict, + weights_mapper: WeightsMapperInterface): + super().load_weights(weights, + weights_mapper, + skip_modules=["draft_model"]) - def load_draft_weights(self, weights: Dict): - self.draft_model.load_weights(weights) + def load_draft_weights(self, weights: Dict, + weights_mapper: WeightsMapperInterface): + self.draft_model.load_weights(weights, weights_mapper) self.draft_model.load_weights_from_target_model(self) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index aa8e1cd35d..11cd223f21 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -519,8 +519,11 @@ def forward( return_context_logits, ) - def load_weights(self, weights: Dict, skip_modules: List[str] = []): - _load_weights_impl(self, weights, 
skip_modules)
+    def load_weights(self,
+                     weights: Dict,
+                     weights_mapper: "WeightsMapperInterface",
+                     skip_modules: List[str] = []):
+        _load_weights_impl(self, weights, weights_mapper, skip_modules)
 
     def infer_max_seq_len(self) -> int:
         # Modified from tensorrt_llm/builder.py _init_max_seq_len
@@ -633,72 +636,37 @@ def filter_weights(prefix, weights: Dict):
 
 def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                        weights: Dict,
+                       weights_mapper: "WeightsMapperInterface",
                        skip_modules: List[str] = [],
                        params_map: Optional[Dict[str, str]] = None):
-    if not hasattr(model, 'model_config') or not isinstance(
-            model.model_config, ModelConfig):
-        raise ValueError("model must have a model_config attribute")
-    if not hasattr(model, 'config'):
-        raise ValueError("model must have a config attribute")
 
+    weights_mapper.skip_modules = skip_modules
+    # TODO smor- move params_map handling into the weights mapper once its call sites are confirmed.
     if params_map is not None:
         weights = rename_weights_with_regex(params_map, weights)
         logger.info(f"Renamed weights with params_map: {params_map}")
 
-    tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-    head_dim = getattr(
-        model.config, "head_dim",
-        model.config.hidden_size // model.config.num_attention_heads)
-
-    params_map = {
-        'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
-        'gate_up_proj': ['gate_proj', 'up_proj']
-    }
-
     for name, module in tqdm(list(model.named_modules()),
                              desc="Loading weights"):
         if len(module._parameters) > 0:
-            # skip load weights if module is in skip_modules
-            if any(skip_module in name for skip_module in skip_modules):
-                continue
-
-            # skip load weights if tie word embeddings is enabled and layer is lm_head
-            if model.config.tie_word_embeddings and name.startswith("lm_head"):
-                continue
-
-            # Skip loading weights for embedding and lm_head if LoRA is enabled and has custom values
-            if hasattr(model, "model") and hasattr(
-                    model.model, 'has_custom_embed_tokens'
-            ) and model.model.has_custom_embed_tokens and name == "model.embed_tokens":
-                continue
-            if hasattr(model, 'has_custom_lm_head'
-                       ) and model.has_custom_lm_head and name == "lm_head":
+            if weights_mapper.should_skip_module(name):
                 continue
 
             names = name.split('.')
+            module_names_breakdown, module_name = names[:-1], names[-1]
+
+            # TODO smor- move this llama-specific skip into a llama-specific weights mapper.
             # WAR: better solution is that llama has its own load_weights function.
- if names[-1] == 'next_layer_layernorm': + if module_name == 'next_layer_layernorm': continue - if names[-1] in params_map: - module_weights = [] - for new_name in params_map[names[-1]]: - fw = filter_weights('.'.join(names[:-1] + [new_name]), - weights) - if new_name in ['k_proj', 'v_proj']: - fw = { - k: - duplicate_kv_weight(weight=v[:], - head_dim=head_dim, - tensor_parallel_size=tp_size) - if k in ["weight", "bias"] else v - for k, v in fw.items() - } - - module_weights.append(fw) + + if weights_mapper.should_apply_to_module(module_name): + module_weights = weights_mapper.apply_callbacks( + module_name, module_names_breakdown, weights) module.load_weights(weights=module_weights) else: - module_weights = filter_weights(name, weights) + module_weights = weights_mapper.filter_weights(name, weights) if hasattr(module, 'load_weights'): + # TODO smor- add type hints here module.load_weights(weights=[module_weights]) else: for n, p in module._parameters.items(): diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index a34b946d26..b2e659d8b6 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -442,7 +442,6 @@ def create_py_executor_instance( num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) - # TODO smor- need to figure out how to set these values executor_config.peft_cache_config = trtllm.PeftCacheConfig( num_device_module_layer=max_lora_rank * num_lora_modules * lora_config.max_loras, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e2c1477c5e..5140b3de25 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2,29 +2,26 @@ import contextlib import functools import gc -import glob import inspect import itertools import math -import multiprocessing import os import traceback import weakref from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple -import psutil -import safetensors import torch import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.models.checkpoints.checkpoint_loader_bundle import \ + CheckpointLoaderBundle from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP -from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank, - local_mpi_size, nvtx_range, release_gc, +from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, torch_dtype_to_str, trace_func) from tensorrt_llm.bindings.executor import GuidedDecodingConfig from tensorrt_llm.logger import logger @@ -137,81 +134,6 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig, model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant -def _prefetch_one_file(file_name): - if os.path.exists(file_name): - logger.info(f"Prefetching {file_name} to memory...") - with open(file_name, 'rb') as f: - f.read() - logger.info(f"Finished prefetching {file_name}.") - - -def prefetch_files(file_names: List[str]): - """ - Prefetch safetensors files to memory so that the weight loading will be much faster. - When multiple ranks run in parallel, each rank will prefetch some files. 
- """ - - # Find out the files to prefetch for the current rank. - # Each rank loads files with indices local_rank, local_rank + local_mpi_size, local_rank + 2*local_mpi_size, etc. - local_file_names = file_names[local_mpi_rank()::local_mpi_size()] - if len(local_file_names) == 0: - return - - max_processes = min(multiprocessing.cpu_count() * 2, 16, - len(local_file_names)) - with multiprocessing.Pool(processes=max_processes) as pool: - pool.map(_prefetch_one_file, local_file_names) - - -def load_weights(checkpoint_dir: str): - weights = {} - weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors") - if weight_files: - # Prefetch the weight files to CPU memory if the size is less than 90% of the available memory. - # This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing. - prefetch_size = sum(os.path.getsize(file) for file in weight_files) - # If the layer number is overridden, it indicates that only a subset of layers are loaded. - # Prefetching all layers is unnecessary. - num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) - enable_prefetch = prefetch_size < psutil.virtual_memory( - ).available * 0.9 and num_layers == 0 - if enable_prefetch: - logger.info( - f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files." - ) - prefetch_files(weight_files) - for file in weight_files: - logger.info(f"Loading {file}") - part_weights = safetensors.torch.load_file(file) - weights.update(part_weights) - return weights - - weight_files = glob.glob(f"{checkpoint_dir}/*.bin") - if not weight_files: - weight_files = glob.glob(f"{checkpoint_dir}/*.pth") - - if weight_files: - for file in weight_files: - # try mmap first, if failed, turn off mmap - try: - part_weights = torch.load(file, - weights_only=True, - map_location='cpu', - mmap=True) - except Exception: - logger.warning( - f"Failed to load {file} with mmap=True, fallback to mmap=False" - ) - part_weights = torch.load(file, - weights_only=True, - map_location='cpu', - mmap=False) - weights.update(part_weights) - return weights - - raise RuntimeError(f"No weight files found in {checkpoint_dir}.") - - def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, @@ -328,6 +250,7 @@ def __init__( self, model_path: str, pytorch_backend_config: PyTorchConfig, + checkpoint_loader_bundle: CheckpointLoaderBundle, batch_size: int = 8, max_num_tokens: int = 8192, max_seq_len: Optional[int] = None, @@ -364,6 +287,7 @@ def __init__( self.model = self._load_model( model_path, mapping=self.mapping, + checkpoint_loader_bundle=checkpoint_loader_bundle, attn_backend=attn_backend, moe_backend=pytorch_backend_config.moe_backend, load_format=pytorch_backend_config.load_format, @@ -972,13 +896,19 @@ def __del__(self) -> None: def _load_model(self, checkpoint_dir: str, + checkpoint_loader_bundle: CheckpointLoaderBundle, load_format: LoadFormat, max_num_tokens: int, moe_max_num_tokens: Optional[int] = None, moe_load_balancer: Optional[MoeLoadBalancerConfig] = None, lora_config: Optional[LoraConfig] = None, **kwargs): - config = ModelConfig.from_pretrained( + # TODO smor- need to be able to get model config from CUSTOM checkpoints + # possibly need to adjust kwargs or add a new argument + # TODO smor- currently only one config loader- HF. We must try use MCore or something else to make sure it's reasonable + # TODO smor- perhaps other config loader will need different params? how do you allow it here? 
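+        # Any ConfigLoaderInterface implementation can be supplied through the
+        # bundle; the default HfConfigLoader simply delegates to
+        # ModelConfig.from_pretrained(checkpoint_dir, **kwargs).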
+
+        config = checkpoint_loader_bundle.config_loader.load(
             checkpoint_dir,
             trust_remote_code=True,
             enable_min_latency=self.pytorch_backend_config.enable_min_latency,
@@ -1028,18 +958,30 @@ def init_meta_tensor(t: torch.Tensor):
             logger.info(
                 f"Use {rank_model_storage / (1024**3):.2f} GB for model weights."
             )
-
         if load_format == LoadFormat.AUTO:
             if hasattr(model, 'llm_checkpoint_dir'):
-                weights = load_weights(model.llm_checkpoint_dir)
+                # TODO smor- the llm_checkpoint_dir path has not been exercised
+                # with the checkpoint loader bundle yet.
+                logger.warning(
+                    "Loading weights from llm_checkpoint_dir is untested with the checkpoint loader bundle."
+                )
+                weights = checkpoint_loader_bundle.weights_loader.load_weights(
+                    model.llm_checkpoint_dir)
             else:
-                weights = load_weights(checkpoint_dir)
+                weights = checkpoint_loader_bundle.weights_loader.load_weights(
+                    checkpoint_dir)
 
-            model.load_weights(weights)
+            weights_mapper = checkpoint_loader_bundle.weights_mapper_cls(
+                model, config)
+            model.load_weights(weights, weights_mapper=weights_mapper)
 
             if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
             ):
-                weights = load_weights(self.spec_config.draft_model_path)
+                # TODO smor- draft-model weight loading has not been verified
+                # with the checkpoint loader bundle yet.
+                logger.warning(
+                    "Loading draft weights via the checkpoint loader bundle is not verified yet."
+                )
+                weights = checkpoint_loader_bundle.weights_loader.load_weights(
+                    self.spec_config.draft_model_path)
-                model.load_draft_weights(weights)
+                # Reuse the target model's weights mapper for the draft weights.
+                model.load_draft_weights(weights, weights_mapper=weights_mapper)
 
         elif load_format == LoadFormat.DUMMY:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 9f23bddd26..f5f108367c 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -218,6 +218,7 @@ def create_py_executor(
         spec_config=spec_config,
         guided_decoding_config=executor_config.guided_decoding_config,
         lora_config=lora_config,
+        checkpoint_loader_bundle=executor_config.checkpoint_loader_bundle,
     )
 
     if has_draft_model_engine:
@@ -240,6 +241,8 @@ def create_py_executor(
             attn_runtime_features=attn_runtime_features,
             dist=dist,
             spec_config=draft_spec_config,
+            checkpoint_loader_bundle=executor_config.
+            checkpoint_loader_bundle,
             is_draft_model=True,
         )
         draft_model_engine.kv_cache_manager_key = DRAFT_KV_CACHE_MANAGER_KEY
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 4ddf97b665..a506f8d67e 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -689,6 +689,16 @@ def _build_model(self):
                 trt_engine_dir=self._engine_dir,
                 max_input_len=self.args.max_input_len,
                 max_seq_len=max_seq_len)
+
+        if self.args.checkpoint_loader_bundle is None:
+            from tensorrt_llm._torch.models.checkpoints.checkpoint_loader_bundle import \
+                CheckpointLoaderBundle
+            self.args.checkpoint_loader_bundle = CheckpointLoaderBundle()
+
+        self.args.checkpoint_loader_bundle.set_weights_loader_mapping(
+            mapping=getattr(self._executor_config, 'mapping', None))
+        self._executor_config.checkpoint_loader_bundle = self.args.checkpoint_loader_bundle
+
         self._executor_config.llm_parallel_config = self.args.parallel_config
         return_logits = (self.args.gather_generation_logits
                          or (self.args.build_config
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 93dda72179..a4c88e01d1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -974,6 +974,12 @@ class BaseLlmArgs(BaseModel):
         exclude=True,
         alias="_mpi_session")
 
+    checkpoint_loader_bundle: Optional[object] = Field(
+        default=None,
+        description=
+        "The checkpoint loader bundle to use for this LLM instance.",
+        json_schema_extra={"type": "Optional[CheckpointLoaderBundle]"})
+
     backend: Optional[str] = Field(
         default=None,
         description="The backend to use for this LLM instance.",
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index 389e735402..faa7eb245c 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -304,3 +304,24 @@ def test_codellama_fp8_with_bf16_lora() -> None:
                         lora_request=lora_requests)
 
     assert len(outputs) == 2
+
+
+def test_checkpoint_loader_bundle():
+    from tensorrt_llm._torch.llm import LLM
+    from tensorrt_llm._torch.models.checkpoints.checkpoint_loader_bundle import \
+        CheckpointLoaderBundle
+    from tensorrt_llm._torch.models.checkpoints.hf.weights_loader import \
+        HfWeightsLoader
+    from tensorrt_llm._torch.models.checkpoints.hf.weights_mapper import \
+        HfWeightsMapper
+
+    non_quantized_model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    llm = LLM(model=non_quantized_model_path,
+              checkpoint_loader_bundle=CheckpointLoaderBundle(
+                  weights_loader=HfWeightsLoader(),
+                  weights_mapper_cls=HfWeightsMapper))
+
+    prompts = ["Hello, how are you?"]
+    outputs = llm.generate(prompts)
+    assert len(outputs) == 1
+    assert outputs[0].outputs[0].text
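
Note for reviewers: the HfWeightsLoader referenced throughout this change (tensorrt_llm/_torch/models/checkpoints/hf/weights_loader.py) is not included in the hunks above. The sketch below is illustrative only and shows how the new interfaces are intended to compose: a custom WeightsLoaderInterface implementation is dropped into a CheckpointLoaderBundle and handed to the LLM API, which forwards it to the PyTorch model engine. The loader class, the checkpoint path, and the safetensors-only layout are assumptions made for the example, not part of this diff.

import glob
from typing import Any, Dict

from safetensors.torch import load_file

from tensorrt_llm._torch.llm import LLM
from tensorrt_llm._torch.models.checkpoints.checkpoint_loader_bundle import \
    CheckpointLoaderBundle
from tensorrt_llm._torch.models.checkpoints.hf.weights_mapper import \
    HfWeightsMapper
from tensorrt_llm._torch.models.checkpoints.weights_loader_interface import \
    WeightsLoaderInterface


class SafetensorsWeightsLoader(WeightsLoaderInterface):
    """Hypothetical loader that reads every *.safetensors shard in a directory."""

    def load_weights(self, checkpoint_dir: str) -> Dict[str, Any]:
        weights: Dict[str, Any] = {}
        for shard in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
            # Each shard maps tensor names to tensors; later shards extend the dict.
            weights.update(load_file(shard))
        return weights


# The config loader is left to its default factory (HfConfigLoader); only the
# weights loader and the mapper class are supplied explicitly here.
bundle = CheckpointLoaderBundle(weights_loader=SafetensorsWeightsLoader(),
                                weights_mapper_cls=HfWeightsMapper)
llm = LLM(model="/path/to/hf/checkpoint", checkpoint_loader_bundle=bundle)
outputs = llm.generate(["Hello, how are you?"])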