diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index fbb885420..f23a024cf 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -7,6 +7,7 @@
 
 import torch
 import torch.distributed as dist
+import vllm.envs as envs
 from huggingface_hub import hf_hub_download
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -162,6 +163,31 @@ def __init__(
         self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes(
             self.vllm_config.scheduler_config)
         self._env_initialized = False
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            activities = [torch.profiler.ProfilerActivity.CPU]
+
+            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+                from torch_sendnn import torch_sendnn
+                torch.utils.rename_privateuse1_backend("aiu")
+                torch._register_device_module("aiu",
+                                              torch_sendnn.sendnn_backend)
+                torch.utils.generate_methods_for_privateuse1_backend()
+                activities.append(torch.profiler.ProfilerActivity.PrivateUse1)
+
+            self.profiler = torch.profiler.profile(
+                activities=activities,
+                record_shapes=True,
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+            print(
+                "[SpyreWorker] Profiling enabled. "
+                f"Traces will be saved to: {torch_profiler_trace_dir}")
+        else:
+            self.profiler = None
 
     def init_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
@@ -517,6 +543,15 @@ def _warmup_model_forward_pass(
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
 
+    def profile(self, is_start=True):
+        """Start (is_start=True) or stop (is_start=False) the profiler."""
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        if is_start:
+            self.profiler.start()
+        else:
+            self.profiler.stop()
+
     @property
     def do_metadata_broadcast(self) -> bool:
         return True
diff --git a/vllm_spyre/worker/spyre_worker.py b/vllm_spyre/worker/spyre_worker.py
index 2b8fff584..0dd56c349 100644
--- a/vllm_spyre/worker/spyre_worker.py
+++ b/vllm_spyre/worker/spyre_worker.py
@@ -7,6 +7,7 @@
 
 import torch
 import torch.distributed as dist
+import vllm.envs as envs
 from huggingface_hub import hf_hub_download
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -75,6 +76,43 @@ def __init__(
         self._env_initialized = False
         self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes(
             self.scheduler_config)
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            activities = [torch.profiler.ProfilerActivity.CPU]
+
+            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+                from torch_sendnn import torch_sendnn
+                torch.utils.rename_privateuse1_backend("aiu")
+                torch._register_device_module("aiu",
+                                              torch_sendnn.sendnn_backend)
+                torch.utils.generate_methods_for_privateuse1_backend()
+                activities.append(torch.profiler.ProfilerActivity.PrivateUse1)
+
+            self.profiler = torch.profiler.profile(
+                activities=activities,
+                record_shapes=True,
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+            print(
+                "[SpyreWorker] Profiling enabled. "
+                f"Traces will be saved to: {torch_profiler_trace_dir}")
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        """Start the torch profiler; RuntimeError if it is not enabled."""
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        """Stop the torch profiler; RuntimeError if it is not enabled."""
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
 
     def init_distributed_environment(self) -> None:
         """Initialize the distributed environment."""