29 changes: 28 additions & 1 deletion vllm_spyre/__init__.py
@@ -1,3 +1,30 @@
from logging.config import dictConfig

from vllm.logger import DEFAULT_LOGGING_CONFIG


def register():
"""Register the Spyre platform."""
return "vllm_spyre.platform.SpyrePlatform"
return "vllm_spyre.platform.SpyrePlatform"


def _init_logging():
"""Setup logging, extending from the vLLM logging config"""
config = {**DEFAULT_LOGGING_CONFIG}

# Copy the vLLM logging configurations for our package
config["formatters"]["vllm_spyre"] = DEFAULT_LOGGING_CONFIG["formatters"][
"vllm"]

handler_config = DEFAULT_LOGGING_CONFIG["handlers"]["vllm"]
handler_config["formatter"] = "vllm_spyre"
config["handlers"]["vllm_spyre"] = handler_config

logger_config = DEFAULT_LOGGING_CONFIG["loggers"]["vllm"]
logger_config["handlers"] = ["vllm_spyre"]
config["loggers"]["vllm_spyre"] = logger_config

dictConfig(config)


_init_logging()
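
Modules elsewhere in the package pick up this configuration simply by creating their loggers under the vllm_spyre namespace, as the worker diff further down does with init_logger. A minimal sketch of the consumer side (the log message is illustrative):

    from vllm.logger import init_logger

    # __name__ is e.g. "vllm_spyre.platform", a child of the "vllm_spyre"
    # logger registered by _init_logging() above, so records propagate to
    # the vllm_spyre handler and formatter.
    logger = init_logger(__name__)

    logger.info("example message routed through the vllm_spyre handler")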
30 changes: 17 additions & 13 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -129,11 +129,11 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from aiu_as_addon import aiu_adapter, aiu_linear # noqa: F401
linear_type = "gptq_aiu"
print("Loaded `aiu_as_addon` functionalities")
logger.info("Loaded `aiu_as_addon` functionalities")
else:
from cpu_addon import cpu_linear # noqa: F401
linear_type = "gptq_cpu"
print("Loaded `cpu_addon` functionalities")
logger.info("Loaded `cpu_addon` functionalities")

quant_cfg = model_config._parse_quant_hf_config()

@@ -185,21 +185,25 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = \
_target_cache_size
print("NOTICE: Adjusting "
"torch._dynamo.config.accumulated_cache_size_limit"
f" from {_prev} to "
f"{torch._dynamo.config.accumulated_cache_size_limit} "
f"to accommodate prompt size of {max_prompt_length} "
f"and decode tokens of {max_decode_length}")
logger.info(
"NOTICE: Adjusting "
"torch._dynamo.config.accumulated_cache_size_limit "
"from %s to %s "
"to accommodate prompt size of %d "
"and decode tokens of %d", _prev,
torch._dynamo.config.accumulated_cache_size_limit,
max_prompt_length, max_decode_length)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
print(
"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
f"accommodate prompt size of {max_prompt_length} and "
f"decode tokens of {max_decode_length}")
logger.info(
"NOTICE: Adjusting torch._dynamo.config.cache_size_limit "
"from %s to %s "
"to accommodate prompt size of %d "
"and decode tokens of %d", _prev,
torch._dynamo.config.cache_size_limit,
max_prompt_length, max_decode_length)

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
self.model = torch.compile(
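
Beyond removing the prints, switching to %-style arguments means the logging machinery interpolates the values only when a record is actually emitted at the active level, whereas an f-string is formatted unconditionally. A small standard-library sketch of the difference (values are illustrative):

    import logging

    logger = logging.getLogger("vllm_spyre.example")  # illustrative name

    prev, new = 8, 160
    # Eager: the string is built even if INFO records are filtered out.
    logger.info(f"Adjusting cache_size_limit from {prev} to {new}")
    # Lazy: formatting is deferred until the record is actually emitted.
    logger.info("Adjusting cache_size_limit from %s to %s", prev, new)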
9 changes: 3 additions & 6 deletions vllm_spyre/platform.py
@@ -111,12 +111,9 @@ def set_warmup_shapes(cls, scheduler_config) -> None:
"The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
"VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length")

print("[SchedulerConfig] VLLM_SPYRE_WARMUP_PROMPT_LENS =",
wup_prompt_lens)
print("[SchedulerConfig] VLLM_SPYRE_WARMUP_NEW_TOKENS =",
wup_new_tokens)
print("[SchedulerConfig] VLLM_SPYRE_WARMUP_BATCH_SIZES =",
wup_batch_sizes)
logger.info("VLLM_SPYRE_WARMUP_PROMPT_LENS = %s", wup_prompt_lens)
logger.info("VLLM_SPYRE_WARMUP_NEW_TOKENS = %s", wup_new_tokens)
logger.info("VLLM_SPYRE_WARMUP_BATCH_SIZES = %s", wup_batch_sizes)

cls.spyre_warmup_shapes = tuple(
sorted([{
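
For context, the three warmup lists logged above are zipped element-wise into per-shape dicts that the worker later iterates over (the prompt_length, new_tokens, and batch_size keys appear in the worker diff below). A rough sketch of that relationship, with illustrative values and an assumed sort key, since the full expression is truncated in this diff:

    # Illustrative values; the real lists come from the
    # VLLM_SPYRE_WARMUP_* environment variables.
    wup_prompt_lens = [64, 128]
    wup_new_tokens = [20, 20]
    wup_batch_sizes = [1, 4]

    spyre_warmup_shapes = tuple(
        sorted(
            [
                {"prompt_length": pl, "new_tokens": nt, "batch_size": bs}
                for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                                      wup_batch_sizes)
            ],
            # Assumed ordering; the actual sort key is not shown here.
            key=lambda s: (s["prompt_length"], s["batch_size"]),
        )
    )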
8 changes: 4 additions & 4 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -368,8 +368,7 @@ def execute_model(
sampling_metadata=model_input.sampling_metadata,
)
t1 = time.time() - t0
print("[spyre_model_runner:execute_model] t_token: %.2fms" %
(t1 * 1000))
logger.debug("t_token: %.2fms", (t1 * 1000))

model_output = ModelRunnerOutput(
req_ids=list(self._req_ids2idx.keys()),
@@ -400,8 +399,9 @@ def _prepare_pad_input_ids(
for input_ids_i in input_ids_list:
seq_len = input_ids_i.size(0)
if max_len > seq_len:
print(f"[SpyreModelRunner] INFO: Padding request of length "
f"{seq_len} tokens to {max_len} tokens.")
logger.info(
"Padding request of length %d tokens to %d tokens.",
seq_len, max_len)
pads = torch.ones(max_len - seq_len,
dtype=torch.long,
device=input_ids_i.device) * self.pad_token_id
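
Note that the per-token timing now goes to logger.debug, so it disappears at the default INFO level. One way to surface it again from application code, using only the standard library (vLLM also has its own logging controls, e.g. a VLLM_LOGGING_LEVEL environment variable):

    import logging

    # Raise verbosity for the whole vllm_spyre namespace so the
    # "t_token: ..." debug records from the model runner are emitted.
    logging.getLogger("vllm_spyre").setLevel(logging.DEBUG)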
33 changes: 18 additions & 15 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -11,6 +11,7 @@
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.core.scheduler import SchedulerOutput
@@ -25,6 +26,8 @@
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.v1.worker.spyre_model_runner import SpyreModelRunner

logger = init_logger(__name__)


class SpyreWorker(WorkerBaseV1):
"""A worker class that executes the model on a group of Spyre cores.
@@ -47,9 +50,9 @@ def compile_or_warm_up_model(self) -> None:
s["new_tokens"])
for s in spyre_warmup_shapes])

print(f"[SpyreWorker] Start warming up "
f"{len(wup_new_tokens)} "
f"different prompt/decode/batchsize-shape combinations.")
logger.info(
"Start warming up %d different "
"prompt/decode/batchsize-shape combinations.", len(wup_new_tokens))
all_warmup_start_t = time.time()
for i, (prompt_len, num_decode_tokens, batch_size) in enumerate([
(s["prompt_length"], s["new_tokens"], s["batch_size"])
@@ -62,20 +65,20 @@
"VLLM_SPYRE_WARMUP_NEW_TOKENS must be "
"at least 2 (spyre requirement).")
# warmup individual combination
print(f"[SpyreWorker] Warmup {i+1}/"
f"{len(wup_new_tokens)} "
f"prompt/decode/batchsize-shape combinations...")
print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
f"decoding {num_decode_tokens} tokens with batch "
f"size {batch_size}")
logger.info(
"Warmup %d/%d prompt/decode/batchsize-shape "
"combinations...", i + 1, len(wup_new_tokens))
logger.info(
"Warming up for prompt length %d, decoding %d tokens with "
"batch size %d", prompt_len, num_decode_tokens, batch_size)
self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
self.restricted_tokens, batch_size)
all_warmup_end_t = time.time()
all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
print(f"[SpyreWorker] All warmups for "
f"{len(wup_new_tokens)} different "
f"prompt/decode/batchsize-shape combinations finished. "
f"Total warmup time {all_warmup_total_t}s.")
logger.info(
"All warmups for %d different prompt/decode/batchsize-shape "
"combinations finished. Total warmup time %.3fs.",
len(wup_new_tokens), all_warmup_total_t)

def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
@@ -192,7 +195,7 @@ def load_model(self):

self.restricted_tokens = restricted_tokens

print("[SpyreWorker] load model...")
logger.info("load model...")
# TODO: check additionally if the Spyre card has enough memory
# for all requested model warmups
# printing env variables for debugging purposes
@@ -207,7 +210,7 @@

load_model_end_t = time.time()
load_model_total_t = load_model_end_t - load_model_start_t
print(f"\tload model took {load_model_total_t}s")
logger.info("load model took %.3fs", load_model_total_t)

def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
special_token_ids, batch_size):