286 changes: 157 additions & 129 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -44,22 +44,29 @@ class SpyreCausalLM(nn.Module):

def __init__(
self,
config: PretrainedConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
max_prompt_length: int,
max_decode_length: int,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)

self.logits_processor = LogitsProcessor(
model_config.hf_config.vocab_size, logits_as_input=True)
self.sampler = get_sampler()
self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
'sendnn_decoder' else torch.float32

# boolean tensor of length batch size with indices:
# True for unfinished sequences and
# False for finished or padded sequences
self.indices = None

# Lazy initialized (FMS Wrapper Model)
self.model: FmsModelWrapper
# FMS Wrapper Model
self.model = FmsModelWrapper(
model_config,
parallel_config,
max_prompt_length,
max_decode_length,
)

# horizontal offset in physical KV cache memory block
self.tkv: int = 0
@@ -173,8 +180,11 @@ def forward(

return logits

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits

@@ -186,151 +196,68 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
max_decode_length: int,
distributed_strategy: Optional[str], **kwargs):

if self.dtype is not model_config.dtype:
logger.info(
"Ignoring user-provided dtype=%s and using dtype=%s instead.",
model_config.dtype, self.dtype)

if model_config.quantization == "gptq":

# note, we have to find a better way to package this
# shouldn't it be part of FMS?
sys.path.append("/home/senuser/aiu-fms")

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from aiu_as_addon import aiu_adapter, aiu_linear # noqa: F401
linear_type = "gptq_aiu"
print("Loaded `aiu_as_addon` functionalities")
else:
from cpu_addon import cpu_linear # noqa: F401
linear_type = "gptq_cpu"
print("Loaded `cpu_addon` functionalities")

quant_cfg = model_config._parse_quant_hf_config()

linear_config = {
"linear_type": linear_type,
"group_size": quant_cfg['group_size'],
"desc_act": quant_cfg['desc_act'],
}
data_type = None
model_source = "hf_gptq_aiu"
else:
linear_config = {"linear_type": "torch_linear"}
data_type = self.dtype
model_source = "hf"

is_local = os.path.isdir(model_config.model)
model_path = model_config.model
# Get location of model from HF cache.
if not is_local:
model_path = download_weights_from_hf(
model_name_or_path=model_path,
cache_dir=None,
allow_patterns=["*.safetensors", "*.bin", "*.pt"],
revision=model_config.revision)

# we can use fused weights unless running on Spyre
fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"

self.model = FmsModelWrapper(self.config)
self.model.model = get_model(architecture="hf_configured",
variant=model_config.model,
model_path=model_path,
source=model_source,
data_type=data_type,
distributed_strategy=distributed_strategy,
group=dist.group.WORLD,
fused_weights=fused_weights,
linear_config=linear_config)

compile_mode = "default"

self.model.model.eval()
torch.set_grad_enabled(False)

_target_cache_size = max(int(max_decode_length * 2),
int(max_prompt_length * 2.5))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \
_target_cache_size > torch._dynamo.config.\
accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = \
_target_cache_size
print("NOTICE: Adjusting "
"torch._dynamo.config.accumulated_cache_size_limit"
f" from {_prev} to "
f"{torch._dynamo.config.accumulated_cache_size_limit} "
f"to accommodate prompt size of {max_prompt_length} "
f"and decode tokens of {max_decode_length}")

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
print(
"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
f"accommodate prompt size of {max_prompt_length} and "
f"decode tokens of {max_decode_length}")

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
self.model.model = torch.compile(
self.model.model,
mode=compile_mode,
backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)


def get_spyre_model(model_config: ModelConfig, parallel_config: ParallelConfig,
max_prompt_length, max_decode_length) -> nn.Module:
def get_spyre_model(
model_config: ModelConfig,
parallel_config: ParallelConfig,
max_prompt_length,
max_decode_length,
) -> nn.Module:

# Create a model instance.
model = SpyreCausalLM(model_config.hf_config)

# Load the weights from the cached or downloaded files.
model.load_weights(
model_config,
max_prompt_length=max_prompt_length,
max_decode_length=max_decode_length,
distributed_strategy="tp" if parallel_config.world_size > 1 else None)
model = SpyreCausalLM(model_config, parallel_config, max_prompt_length,
max_decode_length)

return model


class FmsModelWrapper(nn.Module):

def __init__(
self,
config,
model_config: ModelConfig,
parallel_config: ParallelConfig,
max_prompt_length: int,
max_decode_length: int,
) -> None:
super().__init__()

# Lazy initialized actual FMS model
self.config: PretrainedConfig = model_config.hf_config
self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
'sendnn_decoder' else torch.float32

# Actual FMS model
self.model: nn.Module

# Load the weights from the cached or downloaded files.
self.load_weights(model_config=model_config,
max_prompt_length=max_prompt_length,
max_decode_length=max_decode_length,
distributed_strategy="tp"
if parallel_config.world_size > 1 else None)

# physical KV cache (fms wrapper/ AIU Spyre)
# lives in SpyreCausalLM only for convenient model access
warmup_shapes = current_platform.get_warmup_shapes()
max_batch = max(shape["batch_size"] for shape in warmup_shapes)
max_prompt_length = max(shape["prompt_length"]
for shape in warmup_shapes)
max_new_tokens = max(shape["new_tokens"] for shape in warmup_shapes)
# Eventually max_model_len = config.max_position_embeddings,
# Eventually max_model_len = self.config.max_position_embeddings,
# but saving some memory here to only allocate the max in practice
max_model_len = max_prompt_length + max_new_tokens

if config.model_type == 'llama':
num_layers = config.num_hidden_layers
num_kv_heads = config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads
elif config.model_type == 'gpt_bigcode':
num_layers = config.n_layer
num_kv_heads = 1 if config.multi_query else config.n_head
head_dim = config.n_embd // config.n_head
if self.config.model_type == 'llama':
num_layers = self.config.num_hidden_layers
num_kv_heads = self.config.num_key_value_heads
head_dim = self.config.hidden_size // \
self.config.num_attention_heads
elif self.config.model_type == 'gpt_bigcode':
num_layers = self.config.n_layer
num_kv_heads = 1 if self.config.multi_query else self.config.n_head
head_dim = self.config.n_embd // self.config.n_head
else:
print(f"[SpyreCausalLM] model type {config.model_type} "
print(f"[SpyreCausalLM] model type {self.config.model_type} "
f"not supported in FMS wrapper")

# (layers)x(k,v)x[max_batch, num_kv_heads, max_model_len, head_dim]
@@ -415,3 +342,104 @@ def update_sample_inputs(
1).clone().detach().reshape((1, 1))
self.sample_mask = torch.nn.functional.pad(
self.sample_mask[0, -1, :].unsqueeze(0), (0, 1)).unsqueeze(0)

def load_weights(
self,
model_config: ModelConfig,
max_prompt_length: int,
max_decode_length: int,
distributed_strategy: Optional[str],
**kwargs,
) -> None:

if self.dtype is not model_config.dtype:
logger.info(
"Ignoring user-provided dtype=%s and using dtype=%s instead.",
model_config.dtype, self.dtype)

if model_config.quantization == "gptq":

# note, we have to find a better way to package this
# shouldn't it be part of FMS?
sys.path.append("/home/senuser/aiu-fms")

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from aiu_as_addon import aiu_adapter, aiu_linear # noqa: F401
linear_type = "gptq_aiu"
print("Loaded `aiu_as_addon` functionalities")
else:
from cpu_addon import cpu_linear # noqa: F401
linear_type = "gptq_cpu"
print("Loaded `cpu_addon` functionalities")

quant_cfg = model_config._parse_quant_hf_config()

linear_config = {
"linear_type": linear_type,
"group_size": quant_cfg['group_size'],
"desc_act": quant_cfg['desc_act'],
}
data_type = None
model_source = "hf_gptq_aiu"
else:
linear_config = {"linear_type": "torch_linear"}
data_type = self.dtype
model_source = "hf"

is_local = os.path.isdir(model_config.model)
model_path = model_config.model
# Get location of model from HF cache.
if not is_local:
model_path = download_weights_from_hf(
model_name_or_path=model_path,
cache_dir=None,
allow_patterns=["*.safetensors", "*.bin", "*.pt"],
revision=model_config.revision)

# we can use fused weights unless running on Spyre
fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"

self.model = get_model(architecture="hf_configured",
variant=model_config.model,
model_path=model_path,
source=model_source,
data_type=data_type,
distributed_strategy=distributed_strategy,
group=dist.group.WORLD,
fused_weights=fused_weights,
linear_config=linear_config)

compile_mode = "default"

self.model.eval()
torch.set_grad_enabled(False)

_target_cache_size = max(int(max_decode_length * 2),
int(max_prompt_length * 2.5))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \
_target_cache_size > torch._dynamo.config.\
accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = \
_target_cache_size
print("NOTICE: Adjusting "
"torch._dynamo.config.accumulated_cache_size_limit"
f" from {_prev} to "
f"{torch._dynamo.config.accumulated_cache_size_limit} "
f"to accommodate prompt size of {max_prompt_length} "
f"and decode tokens of {max_decode_length}")

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
print(
"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
f"accommodate prompt size of {max_prompt_length} and "
f"decode tokens of {max_decode_length}")

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
self.model = torch.compile(
self.model,
mode=compile_mode,
backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)
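
The sketch below illustrates how a caller constructs the model after this refactor. The `get_spyre_model` signature and the `model.model` attribute come from the diff above; the `model_config` and `parallel_config` objects and the two length values are assumed to be supplied by the vLLM engine, so treat this as a usage sketch rather than the exact call site.

from vllm_spyre.model_executor.model_loader.spyre import get_spyre_model

# Minimal usage sketch, assuming `model_config` (vLLM ModelConfig) and
# `parallel_config` (vLLM ParallelConfig) are already provided by the engine;
# the length values are illustrative warmup-shape sizes.
model = get_spyre_model(
    model_config,
    parallel_config,
    max_prompt_length=64,
    max_decode_length=20,
)
# Weight loading now happens inside FmsModelWrapper.__init__, so the returned
# SpyreCausalLM is ready to run. The wrapped FMS module and its dtype live at
# model.model, which is why the runner diffs below read self.model.model.dtype
# instead of self.model.dtype.
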
2 changes: 1 addition & 1 deletion vllm_spyre/v1/worker/spyre_model_runner.py
@@ -438,7 +438,7 @@ def pad_input_ids(
# this is a causal mask for generation
mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
mask = mask.to(self.model.dtype)
mask = mask.to(self.model.model.dtype)
position_ids = torch.stack(position_ids_list)

return input_ids, position_ids, mask
2 changes: 1 addition & 1 deletion vllm_spyre/worker/spyre_model_runner.py
@@ -399,7 +399,7 @@ def pad_input_ids(
# this is a causal mask for generation
mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
mask = mask.to(self.model.dtype)
mask = mask.to(self.model.model.dtype)
position_ids = torch.stack(position_ids_list)

return input_ids, position_ids, mask