Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions vllm_spyre/v1/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,43 @@ def __init__(
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
activities = [torch.profiler.ProfilerActivity.CPU]
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
from torch_sendnn import torch_sendnn
torch.utils.rename_privateuse1_backend("aiu")
torch._register_device_module("aiu",
torch_sendnn.sendnn_backend)
torch.utils.generate_methods_for_privateuse1_backend()
activities.append(torch.profiler.ProfilerActivity.PrivateUse1)
logger.info("Traces will contain AIU events if PyTorch with"
" AIU profiling support is installed.")
os.environ["ProfilerActivity"] = "PrivateUse1" # noqa: SIM112

# Get the current value of DT_OPT and autopilot
dt_opt = os.environ.get("DT_OPT", "")
options = dict(
opt.split('=') for opt in dt_opt.split(',') if '=' in opt)
autopilot_opt = options.get(
"autopilot", "1") # autopilot defaults to 1 if not set
if autopilot_opt == "1":
logger.warning(
"autopilot on detected with profiling enabled. Add "
"autpilot=0 to DT_OPT to see individual AIU-kernel "
"execution in the trace.")

logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)

self.profiler = torch.profiler.profile(
activities=activities,
record_shapes=True,
with_stack=True,
activities=[torch.profiler.ProfilerActivity.CPU],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
print(
"[SpyreWorker] Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
else:
self.profiler = None

Expand Down