Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions scripts/patch/patch_sglang.sh
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,40 @@ utils_text = replace_once(
utils_path.write_text(utils_text)

# ---------------------------------------------------------------------------
# tokenizer_manager.py — expose the already-tokenized prompt via meta_info so
# serving_chat can emit input_token_ids without paying per-prompt-token
# logprob compute.
# tokenizer_manager.py — persist canonical prompt token IDs on ReqState right
# after tokenization, then expose them via meta_info so serving_chat can emit
# input_token_ids without paying per-prompt-token logprob compute.
# ---------------------------------------------------------------------------
tokenizer_manager_text = tokenizer_manager_path.read_text()
tokenizer_manager_text = replace_once(
tokenizer_manager_text,
" # Build meta_info and return value\n"
" meta_info = {\n"
" \"id\": rid,\n"
" \"finish_reason\": recv_obj.finished_reasons[i],\n"
" \"prompt_tokens\": recv_obj.prompt_tokens[i],\n"
" \"weight_version\": self.server_args.weight_version,\n"
" \"total_retractions\": recv_obj.retraction_counts[i],\n"
" }\n",
" output_ids: List[int] = dataclasses.field(default_factory=list)\n"
" input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)\n",
" output_ids: List[int] = dataclasses.field(default_factory=list)\n"
" input_token_ids: List[int] = dataclasses.field(default_factory=list)\n"
" input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)\n",
label=str(tokenizer_manager_path),
)
tokenizer_manager_text = replace_once(
tokenizer_manager_text,
" tokenized_obj.time_stats = self.rid_to_state[obj.rid].time_stats\n"
" self.rid_to_state[obj.rid].time_stats.set_tokenize_finish_time()\n"
"\n"
" return tokenized_obj\n",
" state = self.rid_to_state[obj.rid]\n"
" tokenized_obj.time_stats = state.time_stats\n"
" state.time_stats.set_tokenize_finish_time()\n"
" if isinstance(input_ids, list) and input_ids:\n"
" if isinstance(input_ids[0], int):\n"
" state.input_token_ids = list(input_ids)\n"
" elif isinstance(input_ids[0], list) and input_ids[0]:\n"
" state.input_token_ids = list(input_ids[0])\n"
"\n"
" return tokenized_obj\n",
label=str(tokenizer_manager_path),
)

meta_info_base = (
" # Build meta_info and return value\n"
" meta_info = {\n"
" \"id\": rid,\n"
Expand All @@ -192,14 +211,38 @@ tokenizer_manager_text = replace_once(
" \"weight_version\": self.server_args.weight_version,\n"
" \"total_retractions\": recv_obj.retraction_counts[i],\n"
" }\n"
" _obj_input_ids = getattr(state.obj, \"input_ids\", None)\n"
" if isinstance(_obj_input_ids, list) and _obj_input_ids:\n"
" if isinstance(_obj_input_ids[0], int):\n"
" meta_info[\"input_token_ids\"] = list(_obj_input_ids)\n"
" elif isinstance(_obj_input_ids[0], list) and _obj_input_ids[0]:\n"
" meta_info[\"input_token_ids\"] = list(_obj_input_ids[0])\n",
label=str(tokenizer_manager_path),
)
meta_info_old_patch = (
meta_info_base
+ " _obj_input_ids = getattr(state.obj, \"input_ids\", None)\n"
+ " if isinstance(_obj_input_ids, list) and _obj_input_ids:\n"
+ " if isinstance(_obj_input_ids[0], int):\n"
+ " meta_info[\"input_token_ids\"] = list(_obj_input_ids)\n"
+ " elif isinstance(_obj_input_ids[0], list) and _obj_input_ids[0]:\n"
+ " meta_info[\"input_token_ids\"] = list(_obj_input_ids[0])\n"
)
meta_info_new_patch = (
meta_info_base
+ " _state_input_ids = getattr(state, \"input_token_ids\", None)\n"
+ " _obj_input_ids = _state_input_ids or getattr(state.obj, \"input_ids\", None)\n"
+ " if isinstance(_obj_input_ids, list) and _obj_input_ids:\n"
+ " if isinstance(_obj_input_ids[0], int):\n"
+ " meta_info[\"input_token_ids\"] = list(_obj_input_ids)\n"
+ " elif isinstance(_obj_input_ids[0], list) and _obj_input_ids[0]:\n"
+ " meta_info[\"input_token_ids\"] = list(_obj_input_ids[0])\n"
)
if meta_info_new_patch not in tokenizer_manager_text:
if meta_info_old_patch in tokenizer_manager_text:
tokenizer_manager_text = tokenizer_manager_text.replace(
meta_info_old_patch, meta_info_new_patch, 1
)
else:
tokenizer_manager_text = replace_once(
tokenizer_manager_text,
meta_info_base,
meta_info_new_patch,
label=str(tokenizer_manager_path),
)
tokenizer_manager_path.write_text(tokenizer_manager_text)

# ---------------------------------------------------------------------------
Expand Down
47 changes: 41 additions & 6 deletions src/polar/runtime/apptainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __init__(self, spec: RuntimeSpec, session_id: str, session_dir: Path) -> Non
safe_name = session_id.replace("/", "-")[:30]
self._instance_name = f"polar-{safe_name}-{short_hash}"
self._binary = self._resolve_binary()
self._direct_exec = bool(os.environ.get("POLAR_APPTAINER_DIRECT_EXEC")) or bool(
spec.kwargs.get("direct_exec", False)
)
Comment on lines +29 to +31
self._overlay_dir = self.session_dir / "overlay"

@property
def runtime_id(self) -> str:
Expand All @@ -44,8 +48,18 @@ async def start(self) -> None:
raise RuntimeError("apptainer runtime was already destroyed")
# Use a host-backed overlay directory instead of --writable-tmpfs
# (default tmpfs overlay is only 64 MB, too small for most workloads).
self._overlay_dir = self.session_dir / "overlay"
self._overlay_dir.mkdir(parents=True, exist_ok=True)
if self._direct_exec:
rc, _, stderr = await self._run_local_command(
*self._exec_prefix(),
"true",
capture=True,
)
if rc != 0:
raise RuntimeError(
f"{self._binary} direct exec failed with exit code {rc}: {stderr}"
)
return
args = [self._binary, "instance", "start",
"--overlay", str(self._overlay_dir)]
if self.spec.gpus > 0:
Expand Down Expand Up @@ -76,6 +90,8 @@ async def stop(self) -> None:
if self._destroyed:
return
self._destroyed = True
if self._direct_exec:
return
rc, _, stderr = await self._run_local_command(
self._binary, "instance", "stop", self._instance_name,
timeout=self._STOP_TIMEOUT, capture=True,
Expand Down Expand Up @@ -105,7 +121,7 @@ async def exec(
shell_exports.append(f"export {key}={shlex.quote(str(effective_env[key]))};")
if shell_exports:
wrapped_command = " ".join(shell_exports + [wrapped_command])
args = [self._binary, "exec", f"instance://{self._instance_name}"]
args = self._exec_prefix()
if effective_env:
args.append("env")
args.extend(f"{key}={value}" for key, value in effective_env.items())
Expand All @@ -128,7 +144,7 @@ async def upload_file(self, local_path: str, remote_path: str) -> None:
"bash",
"-c",
f"tar -cf - -C {shlex.quote(source_dir)} {shlex.quote(filename)} | "
f"{self._binary} exec instance://{self._instance_name} "
f"{shlex.join(self._exec_prefix())} "
f"tar -xf - -C {shlex.quote(parent)}",
capture=False,
)
Expand All @@ -147,7 +163,7 @@ async def upload_dir(self, local_path: str, remote_path: str) -> None:
"bash",
"-c",
f"tar -cf - -C {shlex.quote(local_path)} . | "
f"{self._binary} exec instance://{self._instance_name} "
f"{shlex.join(self._exec_prefix())} "
f"tar -xf - -C {shlex.quote(remote_path)}",
capture=False,
)
Expand All @@ -164,7 +180,7 @@ async def download_file(self, remote_path: str, local_path: str) -> None:
rc, _, _ = await self._run_local_command(
"bash",
"-c",
f"{self._binary} exec instance://{self._instance_name} "
f"{shlex.join(self._exec_prefix())} "
f"tar -cf - -C {shlex.quote(parent)} {shlex.quote(filename)} | "
f"tar -xf - -C {shlex.quote(local_dir)}",
capture=False,
Expand All @@ -181,7 +197,7 @@ async def download_dir(self, remote_path: str, local_path: str) -> None:
rc, _, _ = await self._run_local_command(
"bash",
"-c",
f"{self._binary} exec instance://{self._instance_name} "
f"{shlex.join(self._exec_prefix())} "
f"tar -cf - -C {shlex.quote(remote_path)} . | "
f"tar -xf - -C {shlex.quote(local_path)}",
capture=False,
Expand All @@ -191,6 +207,25 @@ async def download_dir(self, remote_path: str, local_path: str) -> None:
f"apptainer download_dir failed with exit code {rc}"
)

def _exec_prefix(self) -> list[str]:
if not self._direct_exec:
return [self._binary, "exec", f"instance://{self._instance_name}"]
args = [self._binary, "exec", "--overlay", str(self._overlay_dir)]
if self.spec.gpus > 0:
args.append("--nv")
network_name: str | None
if not self.spec.allow_internet:
network_name = "none"
else:
network_name = self.spec.network
if network_name and network_name != "host":
args.extend(["--net", "--network", network_name])
args.extend(["--bind", f"{self.session_dir}:{self.runtime_session_dir}"])
for volume in self.spec.kwargs.get("volumes", []):
args.extend(["--bind", str(volume)])
args.append(self.spec.image)
return args

@staticmethod
def _resolve_binary() -> str:
override = os.environ.get("POLAR_APPTAINER_BIN")
Expand Down