22 changes: 10 additions & 12 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -143,18 +143,16 @@ def __init__(
"""Initialize the engine with model and sequence information."""
# NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
# This is not correctly declared in the base ModelEngine class though...
self.pytorch_backend_config = SimpleNamespace()
self.pytorch_backend_config.print_iter_log = False
self.pytorch_backend_config.enable_iter_perf_stats = False
self.pytorch_backend_config.enable_iter_req_stats = False
self.pytorch_backend_config.stream_interval = 1
self.pytorch_backend_config.attention_dp_enable_balance = False
self.pytorch_backend_config.attention_dp_time_out_iters = 50
self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
self.pytorch_backend_config.batch_wait_timeout_ms = 0
self.pytorch_backend_config.batch_wait_timeout_iters = 0
self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
self.pytorch_backend_config.max_num_tokens = seq_info.max_num_tokens
self.llm_args = SimpleNamespace()
self.llm_args.print_iter_log = False
self.llm_args.enable_iter_perf_stats = False
self.llm_args.enable_iter_req_stats = False
self.llm_args.stream_interval = 1
self.llm_args.attention_dp_config = None
self.llm_args.batch_wait_timeout_ms = 0
self.llm_args.batch_wait_timeout_iters = 0
self.llm_args.batch_wait_max_tokens_ratio = 0.0
self.llm_args.max_num_tokens = seq_info.max_num_tokens
self.iter_counter = 0

# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
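Note on the hunk above: the engine still fakes the config object with a `SimpleNamespace`, but the fields now mirror `TorchLlmArgs`, with a single `attention_dp_config` object replacing the three flat attention-DP knobs. A minimal sketch of that pattern follows; the helper name is hypothetical and the values are illustrative.

```python
# Minimal sketch of the fake-llm_args pattern above; field names follow the
# diff, values are illustrative, and make_fake_llm_args is a hypothetical helper.
from types import SimpleNamespace


def make_fake_llm_args(max_num_tokens: int) -> SimpleNamespace:
    args = SimpleNamespace()
    args.print_iter_log = False
    args.enable_iter_perf_stats = False
    args.enable_iter_req_stats = False
    args.stream_interval = 1
    args.attention_dp_config = None  # replaces the three flat attention-DP fields
    args.batch_wait_timeout_ms = 0
    args.batch_wait_timeout_iters = 0
    args.batch_wait_max_tokens_ratio = 0.0
    args.max_num_tokens = max_num_tokens
    return args


fake_args = make_fake_llm_args(max_num_tokens=8192)
assert fake_args.attention_dp_config is None
```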
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -36,9 +36,8 @@ def __init__(self, engine: "PyTorchModelEngine"):
self.engine_ref = weakref.ref(engine)

# High-level configuration
config = engine.pytorch_backend_config
self.enabled = config.use_cuda_graph
self.padding_enabled = config.cuda_graph_padding_enabled
self.enabled = engine.llm_args.cuda_graph_config is not None
self.padding_enabled = engine._cuda_graph_padding_enabled
self.supported_batch_sizes = engine._cuda_graph_batch_sizes
self.max_supported_batch_size = engine._max_cuda_graph_batch_size
self.max_beam_width = engine.max_beam_width
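The runner no longer reads a `use_cuda_graph` flag; CUDA graphs count as enabled whenever a `cuda_graph_config` object is present on `llm_args`. A small sketch of that check, using stand-in dataclasses rather than the real `TorchLlmArgs`/`CudaGraphConfig`:

```python
# Stand-in types only; the real classes live in tensorrt_llm.llmapi.llm_args.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeCudaGraphConfig:
    enable_padding: bool = False
    batch_sizes: Optional[List[int]] = None


@dataclass
class FakeLlmArgs:
    cuda_graph_config: Optional[FakeCudaGraphConfig] = None


with_graphs = FakeLlmArgs(cuda_graph_config=FakeCudaGraphConfig())
without_graphs = FakeLlmArgs()

print(with_graphs.cuda_graph_config is not None)     # True  -> CUDA graphs enabled
print(without_graphs.cuda_graph_config is not None)  # False -> CUDA graphs disabled
```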
80 changes: 47 additions & 33 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -20,7 +20,8 @@
MultimodalRuntimeData)
from tensorrt_llm.inputs.registry import (create_input_processor,
create_input_processor_with_hash)
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, TorchCompileConfig,
TorchLlmArgs)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraModelConfig
@@ -53,7 +54,7 @@
from ..utils import (get_model_extra_attrs,
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
from .config import PyTorchConfig, _construct_checkpoint_loader
from .config import _construct_checkpoint_loader
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
from .guided_decoder import CapturableGuidedDecoder
@@ -131,7 +132,7 @@ def __init__(
self,
*,
model_path: str,
pytorch_backend_config: PyTorchConfig,
llm_args: TorchLlmArgs,
mapping: Optional[Mapping] = None,
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
@@ -140,10 +141,7 @@ def __init__(
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
torch.nn.Module]] = None,
model: Optional[torch.nn.Module] = None,
llm_args: Optional[TorchLlmArgs] = None,
):
assert llm_args is not None, "llm_args must be provided for PyTorchModelEngine"

self.forward_pass_callable = None
self.ub_buffers = None
(
@@ -168,7 +166,7 @@ def __init__(
self.dist = dist
if dist is not None:
ExpertStatistic.create(self.dist.rank)
self.pytorch_backend_config = pytorch_backend_config
self.llm_args = llm_args
self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

@@ -192,7 +190,7 @@ def __init__(
lora_config: Optional[
LoraConfig] = None if is_draft_model else llm_args.lora_config
loader = ModelLoader(
pytorch_backend_config=pytorch_backend_config,
llm_args=llm_args,
mapping=self.mapping,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
@@ -215,7 +213,7 @@ def __init__(
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
self.model.extra_attrs = {}
if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
if self.llm_args.enable_layerwise_nvtx_marker:
layerwise_nvtx_marker = LayerwiseNvtxMarker()
module_prefix = 'Model'
if self.model.model_config and self.model.model_config.pretrained_config and self.model.model_config.pretrained_config.architectures:
@@ -224,19 +222,39 @@ def __init__(
layerwise_nvtx_marker.register_hooks(self.model, module_prefix)

self.enable_attention_dp = self.model.model_config.mapping.enable_attention_dp
self._disable_overlap_scheduler = self.pytorch_backend_config.disable_overlap_scheduler
self._disable_overlap_scheduler = self.llm_args.disable_overlap_scheduler
self._torch_compile_backend = None
self.dtype = self.model.config.torch_dtype
self._init_model_capacity()

self._torch_compile_backend = None
self.cuda_graph_config = self.llm_args.cuda_graph_config
cuda_graph_batch_sizes = self.cuda_graph_config.batch_sizes if self.cuda_graph_config else CudaGraphConfig.model_fields[
'batch_sizes'].default
cuda_graph_padding_enabled = self.cuda_graph_config.enable_padding if self.cuda_graph_config else CudaGraphConfig.model_fields[
'enable_padding'].default

self.torch_compile_config = self.llm_args.torch_compile_config
torch_compile_enabled = bool(self.torch_compile_config is not None)
torch_compile_fullgraph = self.torch_compile_config.enable_fullgraph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'enable_fullgraph'].default
torch_compile_inductor_enabled = self.torch_compile_config.enable_inductor if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'enable_inductor'].default
torch_compile_piecewise_cuda_graph = self.torch_compile_config.enable_piecewise_cuda_graph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'enable_piecewise_cuda_graph'].default
torch_compile_piecewise_cuda_graph_num_tokens = self.torch_compile_config.capture_num_tokens if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'capture_num_tokens'].default
torch_compile_enable_userbuffers = self.torch_compile_config.enable_userbuffers if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'enable_userbuffers'].default
torch_compile_max_num_streams = self.torch_compile_config.max_num_streams if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
'max_num_streams'].default

# Eagle3 draft model now does not support torch.compile
self._torch_compile_enabled = pytorch_backend_config.torch_compile_enabled and not is_draft_model
self._torch_compile_piecewise_cuda_graph = pytorch_backend_config.torch_compile_piecewise_cuda_graph
self._torch_compile_enabled = torch_compile_enabled
self._torch_compile_piecewise_cuda_graph = torch_compile_piecewise_cuda_graph

piecewise_cuda_graph_num_tokens = (
pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens
or pytorch_backend_config.cuda_graph_batch_sizes or [])
torch_compile_piecewise_cuda_graph_num_tokens
or cuda_graph_batch_sizes or [])

self._piecewise_cuda_graph_num_tokens = [
i for i in piecewise_cuda_graph_num_tokens
@@ -245,33 +263,30 @@ def __init__(

try:
use_ub_for_nccl = (
pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
and self._init_userbuffers(self.model.config.hidden_size))
if self._torch_compile_enabled:
set_torch_compiling(True)
use_ub = not use_ub_for_nccl and (
pytorch_backend_config.torch_compile_enable_userbuffers
torch_compile_enable_userbuffers
and self._init_userbuffers(self.model.config.hidden_size))
self._torch_compile_backend = Backend(
pytorch_backend_config.torch_compile_inductor_enabled,
torch_compile_inductor_enabled,
enable_userbuffers=use_ub,
enable_piecewise_cuda_graph=self.
_torch_compile_piecewise_cuda_graph,
capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
max_num_streams=pytorch_backend_config.
torch_compile_max_num_streams)
max_num_streams=torch_compile_max_num_streams)
if isinstance(self.model, DecoderModelForCausalLM):
self.model.model = torch.compile(
self.model.model,
backend=self._torch_compile_backend,
fullgraph=pytorch_backend_config.torch_compile_fullgraph
)
fullgraph=torch_compile_fullgraph)
else:
self.model = torch.compile(
self.model,
backend=self._torch_compile_backend,
fullgraph=pytorch_backend_config.torch_compile_fullgraph
)
fullgraph=torch_compile_fullgraph)
torch._dynamo.config.cache_size_limit = 16
else:
set_torch_compiling(False)
@@ -283,7 +298,7 @@ def __init__(
self.is_warmup = False

self.attn_backend = get_attention_backend(
pytorch_backend_config.attn_backend,
self.llm_args.attn_backend,
sparse_attn_config=self.sparse_attention_config)

if self.is_spec_decode:
@@ -329,13 +344,12 @@ def __init__(
self.iter_states = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None

self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
self._cuda_graph_padding_enabled = cuda_graph_padding_enabled

self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
self.max_num_tokens, self.original_max_total_draft_tokens,
self._cuda_graph_padding_enabled
) if pytorch_backend_config.cuda_graph_batch_sizes else []
cuda_graph_batch_sizes, self.batch_size, self.max_num_tokens,
self.original_max_total_draft_tokens,
self._cuda_graph_padding_enabled) if cuda_graph_batch_sizes else []

self._max_cuda_graph_batch_size = (self._cuda_graph_batch_sizes[-1] if
self._cuda_graph_batch_sizes else 0)
@@ -554,7 +568,7 @@ def _run_torch_compile_warmup(self, resource_manager: ResourceManager):

def _run_autotuner_warmup(self, resource_manager: ResourceManager):
"""Runs a forward pass to populate the autotuner cache."""
if not self.pytorch_backend_config.enable_autotuner:
if not self.llm_args.enable_autotuner:
return

logger.info("Running autotuner warmup...")
@@ -2299,7 +2313,7 @@ def forward(

with MoeLoadBalancerIterContext(moe_load_balancer):
# Special handling for multimodal encoder only mode
if self.pytorch_backend_config.mm_encoder_only:
if self.llm_args.mm_encoder_only:
return self._forward_step_mm_encoder_only(
inputs, scheduled_requests)
else:
@@ -2463,7 +2477,7 @@ def _init_userbuffers(self, hidden_size):
# Disable UB for unsupported platforms
if not ub.ub_supported():
return False
use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
ub.initialize_userbuffers_manager(
self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
self.mapping.rank, self.mapping.gpus_per_node,
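A recurring pattern in the `model_engine.py` hunks: each optional sub-config (`cuda_graph_config`, `torch_compile_config`) may be `None`, in which case the engine falls back to the pydantic field default via `model_fields[...].default`. A small sketch of that fallback, using a stand-in model whose defaults are illustrative rather than the real `TorchCompileConfig` values:

```python
# Sketch of the None-fallback pattern; FakeTorchCompileConfig is a stand-in,
# not the real TorchCompileConfig, and its defaults are illustrative.
from typing import Optional

from pydantic import BaseModel


class FakeTorchCompileConfig(BaseModel):
    enable_fullgraph: bool = True
    enable_inductor: bool = False
    max_num_streams: int = 1


def resolve_fullgraph(cfg: Optional[FakeTorchCompileConfig]) -> bool:
    # Use the configured value when a config object exists, otherwise the
    # class-level default recorded by pydantic.
    if cfg is not None:
        return cfg.enable_fullgraph
    return FakeTorchCompileConfig.model_fields["enable_fullgraph"].default


print(resolve_fullgraph(None))                                            # True (default)
print(resolve_fullgraph(FakeTorchCompileConfig(enable_fullgraph=False)))  # False
```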
42 changes: 21 additions & 21 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -7,6 +7,7 @@
import torch

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@@ -19,7 +20,7 @@
from ..models.modeling_utils import MetaInitMode, timing
from ..modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer, maybe_create_moe_load_balancer)
from .config import LoadFormat, PyTorchConfig
from .config import LoadFormat

_KV_CACHE_MAP = {
"fp8": QuantAlgo.FP8.value,
@@ -157,7 +158,7 @@ class ModelLoader:
"""

def __init__(self,
pytorch_backend_config: PyTorchConfig,
llm_args: TorchLlmArgs,
mapping: Mapping,
spec_config: Optional["DecodingBaseConfig"],
sparse_attention_config: Optional["SparseAttentionConfig"],
@@ -168,14 +169,14 @@ def __init__(self,
Initializes the ModelLoader.

Args:
pytorch_backend_config: Configuration for the PyTorch backend.
llm_args: Configuration for the PyTorch backend.
mapping: The distributed mapping configuration.
spec_config: Configuration for speculative decoding.
max_num_tokens: The maximum number of tokens the engine will handle.
max_seq_len: The maximum sequence length.
lora_config: Configuration for LoRA.
"""
self.pytorch_backend_config = pytorch_backend_config
self.llm_args = llm_args
self.mapping = mapping
self.spec_config = spec_config
self.sparse_attention_config = sparse_attention_config
@@ -200,7 +201,7 @@ def load(
"""
config = self._load_and_validate_config(checkpoint_dir,
checkpoint_loader)
load_format = self.pytorch_backend_config.load_format
load_format = self.llm_args.load_format

with timing("Model init total"), maybe_create_moe_load_balancer(
config, self.mapping) as moe_load_balancer:
@@ -291,30 +292,29 @@ def _load_and_validate_config(
checkpoint_dir,
trust_remote_code=True,
mapping=self.mapping,
enable_min_latency=self.pytorch_backend_config.enable_min_latency,
use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
force_dynamic_quantization=self.pytorch_backend_config.
force_dynamic_quantization,
enable_min_latency=self.llm_args.enable_min_latency,
use_cuda_graph=self.llm_args.cuda_graph_config is not None,
force_dynamic_quantization=self.llm_args.force_dynamic_quantization,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
max_num_tokens=self.max_num_tokens,
max_seq_len=self.max_seq_len,
moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens,
moe_load_balancer=self.pytorch_backend_config.moe_load_balancer,
moe_max_num_tokens=self.llm_args.moe_config.max_num_tokens,
moe_load_balancer=self.llm_args.moe_config.load_balancer,
lora_config=self.lora_config,
allreduce_strategy=self.pytorch_backend_config.allreduce_strategy,
mm_encoder_only=self.pytorch_backend_config.mm_encoder_only,
attn_backend=self.pytorch_backend_config.attn_backend,
moe_backend=self.pytorch_backend_config.moe_backend,
moe_disable_finalize_fusion=self.pytorch_backend_config.
moe_disable_finalize_fusion,
use_low_precision_moe_combine=self.pytorch_backend_config.
allreduce_strategy=self.llm_args.allreduce_strategy,
mm_encoder_only=self.llm_args.mm_encoder_only,
attn_backend=self.llm_args.attn_backend,
moe_backend=self.llm_args.moe_config.backend,
moe_disable_finalize_fusion=self.llm_args.moe_config.
disable_finalize_fusion,
use_low_precision_moe_combine=self.llm_args.moe_config.
use_low_precision_moe_combine)

validate_and_set_kv_cache_quant(
config, self.pytorch_backend_config.kv_cache_dtype)
validate_and_set_kv_cache_quant(config,
self.llm_args.kv_cache_config.dtype)
validate_and_set_mamba_ssm_cache_dtype(
config, self.pytorch_backend_config.mamba_ssm_cache_dtype)
config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)

# Allow overriding the number of layers via environment variable
num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
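The loader hunks above follow the same flat-to-nested mapping: fields that used to sit directly on `PyTorchConfig` (`moe_max_num_tokens`, `kv_cache_dtype`, and so on) are now read from sub-configs of `llm_args`. A sketch of that access pattern with a stand-in namespace; attribute values are illustrative only.

```python
# Stand-in for TorchLlmArgs; only the nested attributes touched in the diff
# are modeled, and the values are illustrative.
from types import SimpleNamespace

llm_args = SimpleNamespace(
    moe_config=SimpleNamespace(
        max_num_tokens=None,
        load_balancer=None,
        backend="CUTLASS",
        disable_finalize_fusion=False,
        use_low_precision_moe_combine=False,
    ),
    kv_cache_config=SimpleNamespace(dtype="auto", mamba_ssm_cache_dtype="auto"),
)

# Former pytorch_backend_config.moe_max_num_tokens:
print(llm_args.moe_config.max_num_tokens)
# Former pytorch_backend_config.kv_cache_dtype:
print(llm_args.kv_cache_config.dtype)
```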