Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rock/actions/sandbox/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BashAction(BaseModel):
check: Literal["silent", "raise", "ignore"] = "raise"


Action = BashAction
Action = Annotated[BashAction, Field(discriminator="action_type")]


class WriteFileRequest(BaseModel):
Expand Down
21 changes: 11 additions & 10 deletions rock/admin/entrypoints/sandbox_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SandboxWriteFileRequest,
)
from rock.admin.proto.response import SandboxStartResponse
from rock.deployments.config import DockerDeploymentConfig
from rock.sandbox.sandbox_manager import SandboxManager
from rock.utils import handle_exceptions

Expand All @@ -37,20 +38,20 @@ def set_sandbox_manager(service: SandboxManager):

@sandbox_router.post("/start")
@handle_exceptions(error_message="start sandbox failed")
async def start(config: SandboxStartRequest) -> RockResponse[SandboxStartResponse]:
sandbox_start_response = await sandbox_manager.start(config.transform())
async def start(request: SandboxStartRequest) -> RockResponse[SandboxStartResponse]:
sandbox_start_response = await sandbox_manager.start(DockerDeploymentConfig.from_request(request))
return RockResponse(result=sandbox_start_response)


@sandbox_router.post("/start_async")
@handle_exceptions(error_message="async start sandbox failed")
async def start_async(
config: SandboxStartRequest,
request: SandboxStartRequest,
x_user_id: str | None = Header(default="default", alias="X-User-Id"),
x_experiment_id: str | None = Header(default="default", alias="X-Experiment-Id"),
) -> RockResponse[SandboxStartResponse]:
sandbox_start_response = await sandbox_manager.start_async(
config.transform(),
DockerDeploymentConfig.from_request(request),
user_info={"user_id": x_user_id, "experiment_id": x_experiment_id},
)
return RockResponse(result=sandbox_start_response)
Expand Down Expand Up @@ -83,37 +84,37 @@ async def get_status(sandbox_id: str):
@sandbox_router.post("/execute")
@handle_exceptions(error_message="execute command failed")
async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]:
return RockResponse(result=await sandbox_manager.execute(command.transform()))
return RockResponse(result=await sandbox_manager.execute(command))


@sandbox_router.post("/create_session")
@handle_exceptions(error_message="create session failed")
async def create_session(request: SandboxCreateBashSessionRequest) -> RockResponse[CreateBashSessionResponse]:
return RockResponse(result=await sandbox_manager.create_session(request.transform()))
return RockResponse(result=await sandbox_manager.create_session(request))


@sandbox_router.post("/run_in_session")
@handle_exceptions(error_message="run in session failed")
async def run(action: SandboxBashAction) -> RockResponse[BashObservation]:
return RockResponse(result=await sandbox_manager.run_in_session(action.transform()))
return RockResponse(result=await sandbox_manager.run_in_session(action))


@sandbox_router.post("/close_session")
@handle_exceptions(error_message="close session failed")
async def close_session(request: SandboxCloseBashSessionRequest) -> RockResponse[CloseBashSessionResponse]:
return RockResponse(result=await sandbox_manager.close_session(request.transform()))
return RockResponse(result=await sandbox_manager.close_session(request))


@sandbox_router.post("/read_file")
@handle_exceptions(error_message="read file failed")
async def read_file(request: SandboxReadFileRequest) -> RockResponse[ReadFileResponse]:
return RockResponse(result=await sandbox_manager.read_file(request.transform()))
return RockResponse(result=await sandbox_manager.read_file(request))


@sandbox_router.post("/write_file")
@handle_exceptions(error_message="write file failed")
async def write_file(request: SandboxWriteFileRequest) -> RockResponse[WriteFileResponse]:
return RockResponse(result=await sandbox_manager.write_file(request.transform()))
return RockResponse(result=await sandbox_manager.write_file(request))


@sandbox_router.post("/upload")
Expand Down
12 changes: 6 additions & 6 deletions rock/admin/entrypoints/sandbox_proxy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def set_sandbox_proxy_service(service: SandboxProxyService):
@sandbox_proxy_router.post("/execute")
@handle_exceptions(error_message="execute command failed")
async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]:
return RockResponse(result=await sandbox_proxy_service.execute(command.transform()))
return RockResponse(result=await sandbox_proxy_service.execute(command))


@sandbox_proxy_router.post("/create_session")
@handle_exceptions(error_message="create session failed")
async def create_session(request: SandboxCreateBashSessionRequest) -> RockResponse[CreateBashSessionResponse]:
return RockResponse(result=await sandbox_proxy_service.create_session(request.transform()))
return RockResponse(result=await sandbox_proxy_service.create_session(request))


@sandbox_proxy_router.post("/run_in_session")
@handle_exceptions(error_message="run in session failed")
async def run(action: SandboxBashAction) -> RockResponse[BashObservation]:
result = await sandbox_proxy_service.run_in_session(action.transform())
result = await sandbox_proxy_service.run_in_session(action)
if result.exit_code is not None and result.exit_code == -1:
return RockResponse(status=ResponseStatus.FAILED, error=result.failure_reason)
return RockResponse(result=result)
Expand All @@ -58,7 +58,7 @@ async def run(action: SandboxBashAction) -> RockResponse[BashObservation]:
@sandbox_proxy_router.post("/close_session")
@handle_exceptions(error_message="close session failed")
async def close_session(request: SandboxCloseBashSessionRequest) -> RockResponse[CloseBashSessionResponse]:
return RockResponse(result=await sandbox_proxy_service.close_session(request.transform()))
return RockResponse(result=await sandbox_proxy_service.close_session(request))


@sandbox_proxy_router.get("/is_alive")
Expand All @@ -70,13 +70,13 @@ async def is_alive(sandbox_id: str):
@sandbox_proxy_router.post("/read_file")
@handle_exceptions(error_message="read file failed")
async def read_file(request: SandboxReadFileRequest) -> RockResponse[ReadFileResponse]:
return RockResponse(result=await sandbox_proxy_service.read_file(request.transform()))
return RockResponse(result=await sandbox_proxy_service.read_file(request))


@sandbox_proxy_router.post("/write_file")
@handle_exceptions(error_message="write file failed")
async def write_file(request: SandboxWriteFileRequest) -> RockResponse[WriteFileResponse]:
return RockResponse(result=await sandbox_proxy_service.write_file(request.transform()))
return RockResponse(result=await sandbox_proxy_service.write_file(request))


@sandbox_proxy_router.post("/upload")
Expand Down
49 changes: 5 additions & 44 deletions rock/admin/proto/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal
from typing import Annotated, Literal

from pydantic import BaseModel
from pydantic import BaseModel, Field

from rock import env_vars
from rock.actions import (
Expand All @@ -25,12 +25,6 @@ class SandboxStartRequest(BaseModel):
cpus: float = 2
"""The amount of CPUs to allocate for the container."""

def transform(self):
from rock.deployments.config import DockerDeploymentConfig

res = DockerDeploymentConfig(**self.model_dump())
return res


class SandboxCommand(Command):
timeout: float | None = 1200
Expand All @@ -52,25 +46,14 @@ class SandboxCommand(Command):
sandbox_id: str | None = None
"""The id of the sandbox."""

def transform(self):
from rock.rocklet.proto.request import InternalCommand

res = InternalCommand(**self.model_dump())
res.container_name = self.sandbox_id
return res


class SandboxCreateBashSessionRequest(CreateBashSessionRequest):
startup_timeout: float = 1.0
max_read_size: int = 2000
sandbox_id: str | None = None

def transform(self):
from rock.rocklet.proto.request import InternalCreateBashSessionRequest

res = InternalCreateBashSessionRequest(**self.model_dump())
res.container_name = self.sandbox_id
return res
SandboxCreateSessionRequest = Annotated[SandboxCreateBashSessionRequest, Field(discriminator="session_type")]


class SandboxBashAction(BashAction):
Expand All @@ -91,46 +74,24 @@ class SandboxBashAction(BashAction):
expect: list[str] = []
"""Outputs to expect in addition to the PS1"""

def transform(self):
from rock.rocklet.proto.request import InternalBashAction

res = InternalBashAction(**self.model_dump())
res.container_name = self.sandbox_id
return res
SandboxAction = Annotated[SandboxBashAction, Field(discriminator="action_type")]


class SandboxCloseBashSessionRequest(CloseBashSessionRequest):
sandbox_id: str | None = None

def transform(self):
from rock.rocklet.proto.request import InternalCloseBashSessionRequest

res = InternalCloseBashSessionRequest(**self.model_dump())
res.container_name = self.sandbox_id
return res
SandboxCloseSessionRequest = Annotated[SandboxCloseBashSessionRequest, Field(discriminator="session_type")]


class SandboxReadFileRequest(ReadFileRequest):
sandbox_id: str | None = None

def transform(self):
from rock.rocklet.proto.request import InternalReadFileRequest

res = InternalReadFileRequest(**self.model_dump())
res.container_name = self.sandbox_id
return res


class SandboxWriteFileRequest(WriteFileRequest):
sandbox_id: str | None = None

def transform(self):
from rock.rocklet.proto.request import InternalWriteFileRequest

res = InternalWriteFileRequest(**self.model_dump())
res.container_name = self.sandbox_id
return res


class WarmupRequest(BaseModel):
image: str = "python:3.11"
6 changes: 6 additions & 0 deletions rock/deployments/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pydantic import BaseModel, ConfigDict, Field, model_validator

from rock.admin.proto.request import SandboxStartRequest
from rock.config import RuntimeConfig
from rock.deployments.abstract import AbstractDeployment
from rock.utils import REQUEST_TIMEOUT_SECONDS
Expand Down Expand Up @@ -149,6 +150,11 @@ def get_deployment(self) -> AbstractDeployment:
def auto_clear_time(self) -> int:
return self.auto_clear_time_minutes

@classmethod
def from_request(cls, request: SandboxStartRequest) -> DeploymentConfig:
"""Create DockerDeploymentConfig from SandboxStartRequest"""
return cls(**request.model_dump())


class RayDeploymentConfig(DockerDeploymentConfig):
"""Configuration for Ray-based distributed deployment."""
Expand Down
12 changes: 6 additions & 6 deletions rock/rocklet/local_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
EnvStepResponse,
UploadResponse,
)
from rock.admin.proto.request import SandboxAction as Action
from rock.admin.proto.request import SandboxCloseSessionRequest as CloseSessionRequest
from rock.admin.proto.request import SandboxCommand as Command
from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest
from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest
from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest
from rock.rocklet.local_sandbox import LocalSandboxRuntime
from rock.rocklet.proto.request import InternalAction as Action
from rock.rocklet.proto.request import InternalCloseSessionRequest as CloseSessionRequest
from rock.rocklet.proto.request import InternalCommand as Command
from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest
from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest
from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest
from rock.utils import get_executor

local_router = APIRouter()
Expand Down
61 changes: 12 additions & 49 deletions rock/rocklet/local_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@
IsAliveResponse,
LocalSandboxRuntimeConfig,
Observation,
ReadFileRequest,
ReadFileResponse,
UploadRequest,
UploadResponse,
WriteFileRequest,
WriteFileResponse,
)
from rock.admin.proto.request import SandboxAction as Action
from rock.admin.proto.request import SandboxBashAction as BashAction
from rock.admin.proto.request import SandboxCloseSessionRequest as CloseSessionRequest
from rock.admin.proto.request import SandboxCommand as Command
from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest
from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest
from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest
from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest
from rock.logger import init_logger
from rock.rocklet.exceptions import (
BashIncorrectSyntaxError,
Expand All @@ -48,16 +59,6 @@
SessionExistsError,
SessionNotInitializedError,
)
from rock.rocklet.proto.request import BashInterruptAction
from rock.rocklet.proto.request import InternalAction as Action
from rock.rocklet.proto.request import InternalBashAction as BashAction
from rock.rocklet.proto.request import InternalCloseSessionRequest as CloseSessionRequest
from rock.rocklet.proto.request import InternalCommand as Command
from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest
from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest
from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest
from rock.rocklet.proto.request import InternalUploadRequest as UploadRequest
from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest
from rock.utils import get_executor

__all__ = ["LocalSandboxRuntime", "BashSession"]
Expand Down Expand Up @@ -207,43 +208,7 @@ def _eat_following_output(self, timeout: float = 0.5) -> str:
return ""
return _strip_control_chars(output)

async def async_interrupt(self, action: BashInterruptAction) -> BashObservation:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._executor, self.interrupt, action)

def interrupt(self, action: BashInterruptAction) -> BashObservation:
"""Interrupt the session."""
output = ""
for _ in range(action.n_retry):
self.shell.sendintr()
expect_strings = action.expect + [self._ps1]
try:
expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore
matched_expect_string = expect_strings[expect_index]
except Exception:
time.sleep(0.2)
continue
output += _strip_control_chars(self.shell.before) # type: ignore
output += self._eat_following_output()
output = output.strip()
return BashObservation(output=output, exit_code=0, expect_string=matched_expect_string)
# Fall back to putting job to background and killing it there:
try:
self.shell.sendcontrol("z")
self.shell.expect(expect_strings, timeout=action.timeout)
output += self.shell.before
self.shell.sendline("kill -9 %1")
expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore
matched_expect_string = expect_strings[expect_index]
output += self.shell.before
output += self._eat_following_output()
output = output.strip()
return BashObservation(output=output, exit_code=0, expect_string=matched_expect_string)
except pexpect.TIMEOUT:
msg = "Failed to interrupt session"
raise pexpect.TIMEOUT(msg)

async def run(self, action: BashAction | BashInterruptAction) -> BashObservation:
async def run(self, action: BashAction) -> BashObservation:
"""Run a bash action.

Raises:
Expand All @@ -258,8 +223,6 @@ async def run(self, action: BashAction | BashInterruptAction) -> BashObservation
if self.shell is None:
msg = "shell not initialized"
raise SessionNotInitializedError(msg)
if isinstance(action, BashInterruptAction):
return await self.async_interrupt(action)
if action.is_interactive_command or action.is_interactive_quit:
return await self._aync_run_interactive(action)
r = await self._async_run_normal(action)
Expand Down
Loading
Loading