From 98fce8421a0b9ea6646c08fbaebf5a86b8f9ef93 Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Wed, 19 Mar 2025 18:19:50 +0000
Subject: [PATCH] moving model loading into FMS wrapper

Signed-off-by: Yannick Schnider
---
 .../model_executor/model_loader/spyre.py     | 286 ++++++++++--------
 vllm_spyre/v1/worker/spyre_model_runner.py   |   2 +-
 vllm_spyre/worker/spyre_model_runner.py      |   2 +-
 3 files changed, 159 insertions(+), 131 deletions(-)

diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index 161c0b68e..ca7ba0c2d 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -44,22 +44,29 @@ class SpyreCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        max_prompt_length: int,
+        max_decode_length: int,
     ) -> None:
         super().__init__()
-        self.config = config
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                logits_as_input=True)
+
+        self.logits_processor = LogitsProcessor(
+            model_config.hf_config.vocab_size, logits_as_input=True)
         self.sampler = get_sampler()
-        self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
-            'sendnn_decoder' else torch.float32
+
         # boolean tensor of length batch size with indices:
         # True for unfinished sequences and
         # False for finished or padded sequences
         self.indices = None
 
-        # Lazy initialized (FMS Wrapper Model)
-        self.model: FmsModelWrapper
+        # FMS Wrapper Model
+        self.model = FmsModelWrapper(
+            model_config,
+            parallel_config,
+            max_prompt_length,
+            max_decode_length,
+        )
 
         # horizontal offset in physical KV cache memory block
         self.tkv: int = 0
@@ -173,8 +180,11 @@ def forward(
 
         return logits
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
         logits = self.logits_processor(None, hidden_states, sampling_metadata)
         return logits
 
@@ -186,116 +196,18 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
-                     max_decode_length: int,
-                     distributed_strategy: Optional[str], **kwargs):
-
-        if self.dtype is not model_config.dtype:
-            logger.info(
-                "Ignoring user-provided dtype=%s and using dtype=%s instead.",
-                model_config.dtype, self.dtype)
-
-        if model_config.quantization == "gptq":
-
-            # note, we have to find a better way to package this
-            # shouldn't it be part of FMS?
-            sys.path.append("/home/senuser/aiu-fms")
-
-            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
-                from aiu_as_addon import aiu_adapter, aiu_linear  # noqa: F401
-                linear_type = "gptq_aiu"
-                print("Loaded `aiu_as_addon` functionalities")
-            else:
-                from cpu_addon import cpu_linear  # noqa: F401
-                linear_type = "gptq_cpu"
-                print("Loaded `cpu_addon` functionalities")
-
-            quant_cfg = model_config._parse_quant_hf_config()
-
-            linear_config = {
-                "linear_type": linear_type,
-                "group_size": quant_cfg['group_size'],
-                "desc_act": quant_cfg['desc_act'],
-            }
-            data_type = None
-            model_source = "hf_gptq_aiu"
-        else:
-            linear_config = {"linear_type": "torch_linear"}
-            data_type = self.dtype
-            model_source = "hf"
-
-        is_local = os.path.isdir(model_config.model)
-        model_path = model_config.model
-        # Get location of model from HF cache.
-        if not is_local:
-            model_path = download_weights_from_hf(
-                model_name_or_path=model_path,
-                cache_dir=None,
-                allow_patterns=["*.safetensors", "*.bin", "*.pt"],
-                revision=model_config.revision)
-
-        # we can use fused weights unless running on Spyre
-        fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"
-
-        self.model = FmsModelWrapper(self.config)
-        self.model.model = get_model(architecture="hf_configured",
-                                     variant=model_config.model,
-                                     model_path=model_path,
-                                     source=model_source,
-                                     data_type=data_type,
-                                     distributed_strategy=distributed_strategy,
-                                     group=dist.group.WORLD,
-                                     fused_weights=fused_weights,
-                                     linear_config=linear_config)
-
-        compile_mode = "default"
-
-        self.model.model.eval()
-        torch.set_grad_enabled(False)
-
-        _target_cache_size = max(int(max_decode_length * 2),
-                                 int(max_prompt_length * 2.5))
-        if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \
-            _target_cache_size > torch._dynamo.config.\
-                accumulated_cache_size_limit:
-            _prev = torch._dynamo.config.accumulated_cache_size_limit
-            torch._dynamo.config.accumulated_cache_size_limit = \
-                _target_cache_size
-            print("NOTICE: Adjusting "
-                  "torch._dynamo.config.accumulated_cache_size_limit"
-                  f" from {_prev} to "
-                  f"{torch._dynamo.config.accumulated_cache_size_limit} "
-                  f"to accommodate prompt size of {max_prompt_length} "
-                  f"and decode tokens of {max_decode_length}")
-
-        if _target_cache_size > torch._dynamo.config.cache_size_limit:
-            _prev = torch._dynamo.config.cache_size_limit
-            torch._dynamo.config.cache_size_limit = _target_cache_size
-            print(
-                "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
-                f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
-                f"accommodate prompt size of {max_prompt_length} and "
-                f"decode tokens of {max_decode_length}")
-
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
-            self.model.model = torch.compile(
-                self.model.model,
-                mode=compile_mode,
-                backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)
-
 
-def get_spyre_model(model_config: ModelConfig, parallel_config: ParallelConfig,
-                    max_prompt_length, max_decode_length) -> nn.Module:
+def get_spyre_model(
+    model_config: ModelConfig,
+    parallel_config: ParallelConfig,
+    max_prompt_length,
+    max_decode_length,
+) -> nn.Module:
 
     # Create a model instance.
-    model = SpyreCausalLM(model_config.hf_config)
-
-    # Load the weights from the cached or downloaded files.
-    model.load_weights(
-        model_config,
-        max_prompt_length=max_prompt_length,
-        max_decode_length=max_decode_length,
-        distributed_strategy="tp" if parallel_config.world_size > 1 else None)
+    model = SpyreCausalLM(model_config, parallel_config, max_prompt_length,
+                          max_decode_length)
+
     return model
 
 
@@ -303,13 +215,27 @@ class FmsModelWrapper(nn.Module):
 
     def __init__(
        self,
-        config,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        max_prompt_length: int,
+        max_decode_length: int,
     ) -> None:
         super().__init__()
 
-        # Lazy initialized actual FMS model
+        self.config: PretrainedConfig = model_config.hf_config
+        self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
+            'sendnn_decoder' else torch.float32
+
+        # Actual FMS model
         self.model: nn.Module
 
+        # Load the weights from the cached or downloaded files.
+        self.load_weights(model_config=model_config,
+                          max_prompt_length=max_prompt_length,
+                          max_decode_length=max_decode_length,
+                          distributed_strategy="tp"
+                          if parallel_config.world_size > 1 else None)
+
         # physical KV cache (fms wrapper/ AIU Spyre)
         # lives in SpyreCausalLM only for convenient model access
         warmup_shapes = current_platform.get_warmup_shapes()
@@ -317,20 +243,21 @@ def __init__(
         max_prompt_length = max(shape["prompt_length"]
                                 for shape in warmup_shapes)
         max_new_tokens = max(shape["new_tokens"] for shape in warmup_shapes)
-        # Eventually max_model_len = config.max_position_embeddings,
+        # Eventually max_model_len = self.config.max_position_embeddings,
         # but saving some memory here to only allocate the max in practise
         max_model_len = max_prompt_length + max_new_tokens
 
-        if config.model_type == 'llama':
-            num_layers = config.num_hidden_layers
-            num_kv_heads = config.num_key_value_heads
-            head_dim = config.hidden_size // config.num_attention_heads
-        elif config.model_type == 'gpt_bigcode':
-            num_layers = config.n_layer
-            num_kv_heads = 1 if config.multi_query else config.n_head
-            head_dim = config.n_embd // config.n_head
+        if self.config.model_type == 'llama':
+            num_layers = self.config.num_hidden_layers
+            num_kv_heads = self.config.num_key_value_heads
+            head_dim = self.config.hidden_size // \
+                self.config.num_attention_heads
+        elif self.config.model_type == 'gpt_bigcode':
+            num_layers = self.config.n_layer
+            num_kv_heads = 1 if self.config.multi_query else self.config.n_head
+            head_dim = self.config.n_embd // self.config.n_head
         else:
-            print(f"[SpyreCausalLM] model type {config.model_type} "
+            print(f"[SpyreCausalLM] model type {self.config.model_type} "
                   f"not supported in FMS wrapper")
 
         # (layers)x(k,v)x[max_batch, num_kv_heads, max_model_len, head_dim]
@@ -415,3 +342,104 @@ def update_sample_inputs(
                 1).clone().detach().reshape((1, 1))
         self.sample_mask = torch.nn.functional.pad(
             self.sample_mask[0, -1, :].unsqueeze(0), (0, 1)).unsqueeze(0)
+
+    def load_weights(
+        self,
+        model_config: ModelConfig,
+        max_prompt_length: int,
+        max_decode_length: int,
+        distributed_strategy: Optional[str],
+        **kwargs,
+    ) -> None:
+
+        if self.dtype is not model_config.dtype:
+            logger.info(
+                "Ignoring user-provided dtype=%s and using dtype=%s instead.",
+                model_config.dtype, self.dtype)
+
+        if model_config.quantization == "gptq":
+
+            # note, we have to find a better way to package this
+            # shouldn't it be part of FMS?
+            sys.path.append("/home/senuser/aiu-fms")
+
+            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+                from aiu_as_addon import aiu_adapter, aiu_linear  # noqa: F401
+                linear_type = "gptq_aiu"
+                print("Loaded `aiu_as_addon` functionalities")
+            else:
+                from cpu_addon import cpu_linear  # noqa: F401
+                linear_type = "gptq_cpu"
+                print("Loaded `cpu_addon` functionalities")
+
+            quant_cfg = model_config._parse_quant_hf_config()
+
+            linear_config = {
+                "linear_type": linear_type,
+                "group_size": quant_cfg['group_size'],
+                "desc_act": quant_cfg['desc_act'],
+            }
+            data_type = None
+            model_source = "hf_gptq_aiu"
+        else:
+            linear_config = {"linear_type": "torch_linear"}
+            data_type = self.dtype
+            model_source = "hf"
+
+        is_local = os.path.isdir(model_config.model)
+        model_path = model_config.model
+        # Get location of model from HF cache.
+        if not is_local:
+            model_path = download_weights_from_hf(
+                model_name_or_path=model_path,
+                cache_dir=None,
+                allow_patterns=["*.safetensors", "*.bin", "*.pt"],
+                revision=model_config.revision)
+
+        # we can use fused weights unless running on Spyre
+        fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"
+
+        self.model = get_model(architecture="hf_configured",
+                               variant=model_config.model,
+                               model_path=model_path,
+                               source=model_source,
+                               data_type=data_type,
+                               distributed_strategy=distributed_strategy,
+                               group=dist.group.WORLD,
+                               fused_weights=fused_weights,
+                               linear_config=linear_config)
+
+        compile_mode = "default"
+
+        self.model.eval()
+        torch.set_grad_enabled(False)
+
+        _target_cache_size = max(int(max_decode_length * 2),
+                                 int(max_prompt_length * 2.5))
+        if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \
+            _target_cache_size > torch._dynamo.config.\
+                accumulated_cache_size_limit:
+            _prev = torch._dynamo.config.accumulated_cache_size_limit
+            torch._dynamo.config.accumulated_cache_size_limit = \
+                _target_cache_size
+            print("NOTICE: Adjusting "
+                  "torch._dynamo.config.accumulated_cache_size_limit"
+                  f" from {_prev} to "
+                  f"{torch._dynamo.config.accumulated_cache_size_limit} "
+                  f"to accommodate prompt size of {max_prompt_length} "
+                  f"and decode tokens of {max_decode_length}")
+
+        if _target_cache_size > torch._dynamo.config.cache_size_limit:
+            _prev = torch._dynamo.config.cache_size_limit
+            torch._dynamo.config.cache_size_limit = _target_cache_size
+            print(
+                "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
+                f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
+                f"accommodate prompt size of {max_prompt_length} and "
+                f"decode tokens of {max_decode_length}")
+
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
+            self.model = torch.compile(
+                self.model,
+                mode=compile_mode,
+                backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 5c4ee3b66..37821a155 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -438,7 +438,7 @@ def pad_input_ids(
         # this is a causal mask for generation
         mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
         mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
-        mask = mask.to(self.model.dtype)
+        mask = mask.to(self.model.model.dtype)
         position_ids = torch.stack(position_ids_list)
 
         return input_ids, position_ids, mask
diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py
index 88f7404b7..dccbdc1e1 100644
--- a/vllm_spyre/worker/spyre_model_runner.py
+++ b/vllm_spyre/worker/spyre_model_runner.py
@@ -399,7 +399,7 @@ def pad_input_ids(
         # this is a causal mask for generation
         mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
         mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
-        mask = mask.to(self.model.dtype)
+        mask = mask.to(self.model.model.dtype)
         position_ids = torch.stack(position_ids_list)
 
         return input_ids, position_ids, mask
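
For reviewers, a minimal usage sketch of the refactored entry point (not part of the patch itself). It assumes `model_config`, `parallel_config`, `max_prompt_length` and `max_decode_length` are already provided by the worker's existing setup; only `get_spyre_model` and the `model.model.dtype` access are taken from this change:

    # Illustration only: weight loading now happens inside
    # FmsModelWrapper.__init__, so constructing the model is a single call.
    from vllm_spyre.model_executor.model_loader.spyre import get_spyre_model

    model = get_spyre_model(model_config,       # vllm ModelConfig (assumed given)
                            parallel_config,    # vllm ParallelConfig (assumed given)
                            max_prompt_length,  # from warmup shapes (assumed given)
                            max_decode_length)

    # The dtype now lives on the wrapped FMS model, hence the runner changes:
    mask = mask.to(model.model.dtype)           # previously: model.dtype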