diff --git a/rock/actions/sandbox/request.py b/rock/actions/sandbox/request.py index 3ad9b696..bf0ae8f1 100644 --- a/rock/actions/sandbox/request.py +++ b/rock/actions/sandbox/request.py @@ -28,7 +28,7 @@ class BashAction(BaseModel): check: Literal["silent", "raise", "ignore"] = "raise" -Action = BashAction +Action = Annotated[BashAction, Field(discriminator="action_type")] class WriteFileRequest(BaseModel): diff --git a/rock/admin/entrypoints/sandbox_api.py b/rock/admin/entrypoints/sandbox_api.py index 8e6da7ef..c55d31c4 100644 --- a/rock/admin/entrypoints/sandbox_api.py +++ b/rock/admin/entrypoints/sandbox_api.py @@ -23,6 +23,7 @@ SandboxWriteFileRequest, ) from rock.admin.proto.response import SandboxStartResponse +from rock.deployments.config import DockerDeploymentConfig from rock.sandbox.sandbox_manager import SandboxManager from rock.utils import handle_exceptions @@ -37,20 +38,20 @@ def set_sandbox_manager(service: SandboxManager): @sandbox_router.post("/start") @handle_exceptions(error_message="start sandbox failed") -async def start(config: SandboxStartRequest) -> RockResponse[SandboxStartResponse]: - sandbox_start_response = await sandbox_manager.start(config.transform()) +async def start(request: SandboxStartRequest) -> RockResponse[SandboxStartResponse]: + sandbox_start_response = await sandbox_manager.start(DockerDeploymentConfig.from_request(request)) return RockResponse(result=sandbox_start_response) @sandbox_router.post("/start_async") @handle_exceptions(error_message="async start sandbox failed") async def start_async( - config: SandboxStartRequest, + request: SandboxStartRequest, x_user_id: str | None = Header(default="default", alias="X-User-Id"), x_experiment_id: str | None = Header(default="default", alias="X-Experiment-Id"), ) -> RockResponse[SandboxStartResponse]: sandbox_start_response = await sandbox_manager.start_async( - config.transform(), + DockerDeploymentConfig.from_request(request), user_info={"user_id": x_user_id, "experiment_id": x_experiment_id}, ) return RockResponse(result=sandbox_start_response) @@ -83,37 +84,37 @@ async def get_status(sandbox_id: str): @sandbox_router.post("/execute") @handle_exceptions(error_message="execute command failed") async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]: - return RockResponse(result=await sandbox_manager.execute(command.transform())) + return RockResponse(result=await sandbox_manager.execute(command)) @sandbox_router.post("/create_session") @handle_exceptions(error_message="create session failed") async def create_session(request: SandboxCreateBashSessionRequest) -> RockResponse[CreateBashSessionResponse]: - return RockResponse(result=await sandbox_manager.create_session(request.transform())) + return RockResponse(result=await sandbox_manager.create_session(request)) @sandbox_router.post("/run_in_session") @handle_exceptions(error_message="run in session failed") async def run(action: SandboxBashAction) -> RockResponse[BashObservation]: - return RockResponse(result=await sandbox_manager.run_in_session(action.transform())) + return RockResponse(result=await sandbox_manager.run_in_session(action)) @sandbox_router.post("/close_session") @handle_exceptions(error_message="close session failed") async def close_session(request: SandboxCloseBashSessionRequest) -> RockResponse[CloseBashSessionResponse]: - return RockResponse(result=await sandbox_manager.close_session(request.transform())) + return RockResponse(result=await sandbox_manager.close_session(request)) @sandbox_router.post("/read_file") @handle_exceptions(error_message="read file failed") async def read_file(request: SandboxReadFileRequest) -> RockResponse[ReadFileResponse]: - return RockResponse(result=await sandbox_manager.read_file(request.transform())) + return RockResponse(result=await sandbox_manager.read_file(request)) @sandbox_router.post("/write_file") @handle_exceptions(error_message="write file failed") async def write_file(request: SandboxWriteFileRequest) -> RockResponse[WriteFileResponse]: - return RockResponse(result=await sandbox_manager.write_file(request.transform())) + return RockResponse(result=await sandbox_manager.write_file(request)) @sandbox_router.post("/upload") diff --git a/rock/admin/entrypoints/sandbox_proxy_api.py b/rock/admin/entrypoints/sandbox_proxy_api.py index 360736d1..dcec35d3 100644 --- a/rock/admin/entrypoints/sandbox_proxy_api.py +++ b/rock/admin/entrypoints/sandbox_proxy_api.py @@ -37,19 +37,19 @@ def set_sandbox_proxy_service(service: SandboxProxyService): @sandbox_proxy_router.post("/execute") @handle_exceptions(error_message="execute command failed") async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]: - return RockResponse(result=await sandbox_proxy_service.execute(command.transform())) + return RockResponse(result=await sandbox_proxy_service.execute(command)) @sandbox_proxy_router.post("/create_session") @handle_exceptions(error_message="create session failed") async def create_session(request: SandboxCreateBashSessionRequest) -> RockResponse[CreateBashSessionResponse]: - return RockResponse(result=await sandbox_proxy_service.create_session(request.transform())) + return RockResponse(result=await sandbox_proxy_service.create_session(request)) @sandbox_proxy_router.post("/run_in_session") @handle_exceptions(error_message="run in session failed") async def run(action: SandboxBashAction) -> RockResponse[BashObservation]: - result = await sandbox_proxy_service.run_in_session(action.transform()) + result = await sandbox_proxy_service.run_in_session(action) if result.exit_code is not None and result.exit_code == -1: return RockResponse(status=ResponseStatus.FAILED, error=result.failure_reason) return RockResponse(result=result) @@ -58,7 +58,7 @@ async def run(action: SandboxBashAction) -> RockResponse[BashObservation]: @sandbox_proxy_router.post("/close_session") @handle_exceptions(error_message="close session failed") async def close_session(request: SandboxCloseBashSessionRequest) -> RockResponse[CloseBashSessionResponse]: - return RockResponse(result=await sandbox_proxy_service.close_session(request.transform())) + return RockResponse(result=await sandbox_proxy_service.close_session(request)) @sandbox_proxy_router.get("/is_alive") @@ -70,13 +70,13 @@ async def is_alive(sandbox_id: str): @sandbox_proxy_router.post("/read_file") @handle_exceptions(error_message="read file failed") async def read_file(request: SandboxReadFileRequest) -> RockResponse[ReadFileResponse]: - return RockResponse(result=await sandbox_proxy_service.read_file(request.transform())) + return RockResponse(result=await sandbox_proxy_service.read_file(request)) @sandbox_proxy_router.post("/write_file") @handle_exceptions(error_message="write file failed") async def write_file(request: SandboxWriteFileRequest) -> RockResponse[WriteFileResponse]: - return RockResponse(result=await sandbox_proxy_service.write_file(request.transform())) + return RockResponse(result=await sandbox_proxy_service.write_file(request)) @sandbox_proxy_router.post("/upload") diff --git a/rock/admin/proto/request.py b/rock/admin/proto/request.py index 385520f5..82802a99 100644 --- a/rock/admin/proto/request.py +++ b/rock/admin/proto/request.py @@ -1,6 +1,6 @@ -from typing import Literal +from typing import Annotated, Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field from rock import env_vars from rock.actions import ( @@ -25,12 +25,6 @@ class SandboxStartRequest(BaseModel): cpus: float = 2 """The amount of CPUs to allocate for the container.""" - def transform(self): - from rock.deployments.config import DockerDeploymentConfig - - res = DockerDeploymentConfig(**self.model_dump()) - return res - class SandboxCommand(Command): timeout: float | None = 1200 @@ -52,25 +46,14 @@ class SandboxCommand(Command): sandbox_id: str | None = None """The id of the sandbox.""" - def transform(self): - from rock.rocklet.proto.request import InternalCommand - - res = InternalCommand(**self.model_dump()) - res.container_name = self.sandbox_id - return res - class SandboxCreateBashSessionRequest(CreateBashSessionRequest): startup_timeout: float = 1.0 max_read_size: int = 2000 sandbox_id: str | None = None - def transform(self): - from rock.rocklet.proto.request import InternalCreateBashSessionRequest - res = InternalCreateBashSessionRequest(**self.model_dump()) - res.container_name = self.sandbox_id - return res +SandboxCreateSessionRequest = Annotated[SandboxCreateBashSessionRequest, Field(discriminator="session_type")] class SandboxBashAction(BashAction): @@ -91,46 +74,24 @@ class SandboxBashAction(BashAction): expect: list[str] = [] """Outputs to expect in addition to the PS1""" - def transform(self): - from rock.rocklet.proto.request import InternalBashAction - res = InternalBashAction(**self.model_dump()) - res.container_name = self.sandbox_id - return res +SandboxAction = Annotated[SandboxBashAction, Field(discriminator="action_type")] class SandboxCloseBashSessionRequest(CloseBashSessionRequest): sandbox_id: str | None = None - def transform(self): - from rock.rocklet.proto.request import InternalCloseBashSessionRequest - res = InternalCloseBashSessionRequest(**self.model_dump()) - res.container_name = self.sandbox_id - return res +SandboxCloseSessionRequest = Annotated[SandboxCloseBashSessionRequest, Field(discriminator="session_type")] class SandboxReadFileRequest(ReadFileRequest): sandbox_id: str | None = None - def transform(self): - from rock.rocklet.proto.request import InternalReadFileRequest - - res = InternalReadFileRequest(**self.model_dump()) - res.container_name = self.sandbox_id - return res - class SandboxWriteFileRequest(WriteFileRequest): sandbox_id: str | None = None - def transform(self): - from rock.rocklet.proto.request import InternalWriteFileRequest - - res = InternalWriteFileRequest(**self.model_dump()) - res.container_name = self.sandbox_id - return res - class WarmupRequest(BaseModel): image: str = "python:3.11" diff --git a/rock/deployments/config.py b/rock/deployments/config.py index 29b55c04..1a88edf3 100644 --- a/rock/deployments/config.py +++ b/rock/deployments/config.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator +from rock.admin.proto.request import SandboxStartRequest from rock.config import RuntimeConfig from rock.deployments.abstract import AbstractDeployment from rock.utils import REQUEST_TIMEOUT_SECONDS @@ -149,6 +150,11 @@ def get_deployment(self) -> AbstractDeployment: def auto_clear_time(self) -> int: return self.auto_clear_time_minutes + @classmethod + def from_request(cls, request: SandboxStartRequest) -> DeploymentConfig: + """Create DockerDeploymentConfig from SandboxStartRequest""" + return cls(**request.model_dump()) + class RayDeploymentConfig(DockerDeploymentConfig): """Configuration for Ray-based distributed deployment.""" diff --git a/rock/rocklet/local_api.py b/rock/rocklet/local_api.py index ed119f79..7513aad7 100644 --- a/rock/rocklet/local_api.py +++ b/rock/rocklet/local_api.py @@ -18,13 +18,13 @@ EnvStepResponse, UploadResponse, ) +from rock.admin.proto.request import SandboxAction as Action +from rock.admin.proto.request import SandboxCloseSessionRequest as CloseSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.rocklet.local_sandbox import LocalSandboxRuntime -from rock.rocklet.proto.request import InternalAction as Action -from rock.rocklet.proto.request import InternalCloseSessionRequest as CloseSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.utils import get_executor local_router = APIRouter() diff --git a/rock/rocklet/local_sandbox.py b/rock/rocklet/local_sandbox.py index aec3ebf6..1400467f 100644 --- a/rock/rocklet/local_sandbox.py +++ b/rock/rocklet/local_sandbox.py @@ -34,10 +34,21 @@ IsAliveResponse, LocalSandboxRuntimeConfig, Observation, + ReadFileRequest, ReadFileResponse, + UploadRequest, UploadResponse, + WriteFileRequest, WriteFileResponse, ) +from rock.admin.proto.request import SandboxAction as Action +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCloseSessionRequest as CloseSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest +from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.logger import init_logger from rock.rocklet.exceptions import ( BashIncorrectSyntaxError, @@ -48,16 +59,6 @@ SessionExistsError, SessionNotInitializedError, ) -from rock.rocklet.proto.request import BashInterruptAction -from rock.rocklet.proto.request import InternalAction as Action -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCloseSessionRequest as CloseSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest -from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalUploadRequest as UploadRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.utils import get_executor __all__ = ["LocalSandboxRuntime", "BashSession"] @@ -207,43 +208,7 @@ def _eat_following_output(self, timeout: float = 0.5) -> str: return "" return _strip_control_chars(output) - async def async_interrupt(self, action: BashInterruptAction) -> BashObservation: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self._executor, self.interrupt, action) - - def interrupt(self, action: BashInterruptAction) -> BashObservation: - """Interrupt the session.""" - output = "" - for _ in range(action.n_retry): - self.shell.sendintr() - expect_strings = action.expect + [self._ps1] - try: - expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore - matched_expect_string = expect_strings[expect_index] - except Exception: - time.sleep(0.2) - continue - output += _strip_control_chars(self.shell.before) # type: ignore - output += self._eat_following_output() - output = output.strip() - return BashObservation(output=output, exit_code=0, expect_string=matched_expect_string) - # Fall back to putting job to background and killing it there: - try: - self.shell.sendcontrol("z") - self.shell.expect(expect_strings, timeout=action.timeout) - output += self.shell.before - self.shell.sendline("kill -9 %1") - expect_index = self.shell.expect(expect_strings, timeout=action.timeout) # type: ignore - matched_expect_string = expect_strings[expect_index] - output += self.shell.before - output += self._eat_following_output() - output = output.strip() - return BashObservation(output=output, exit_code=0, expect_string=matched_expect_string) - except pexpect.TIMEOUT: - msg = "Failed to interrupt session" - raise pexpect.TIMEOUT(msg) - - async def run(self, action: BashAction | BashInterruptAction) -> BashObservation: + async def run(self, action: BashAction) -> BashObservation: """Run a bash action. Raises: @@ -258,8 +223,6 @@ async def run(self, action: BashAction | BashInterruptAction) -> BashObservation if self.shell is None: msg = "shell not initialized" raise SessionNotInitializedError(msg) - if isinstance(action, BashInterruptAction): - return await self.async_interrupt(action) if action.is_interactive_command or action.is_interactive_quit: return await self._aync_run_interactive(action) r = await self._async_run_normal(action) diff --git a/rock/rocklet/proto/request.py b/rock/rocklet/proto/request.py deleted file mode 100644 index a1d250d3..00000000 --- a/rock/rocklet/proto/request.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Annotated, Literal - -from pydantic import BaseModel, Field - -from rock.actions import UploadRequest -from rock.admin.proto.request import ( - SandboxBashAction, - SandboxCloseBashSessionRequest, - SandboxCommand, - SandboxCreateBashSessionRequest, - SandboxReadFileRequest, - SandboxWriteFileRequest, -) - - -class InternalCommand(SandboxCommand): - container_name: str | None = None - - -class InternalCreateBashSessionRequest(SandboxCreateBashSessionRequest): - container_name: str | None = None - - -InternalCreateSessionRequest = Annotated[InternalCreateBashSessionRequest, Field(discriminator="session_type")] - - -class InternalBashAction(SandboxBashAction): - container_name: str | None = None - - -InternalAction = InternalBashAction - - -class InternalCloseBashSessionRequest(SandboxCloseBashSessionRequest): - container_name: str | None = None - - -InternalCloseSessionRequest = Annotated[InternalCloseBashSessionRequest, Field(discriminator="session_type")] - - -class InternalReadFileRequest(SandboxReadFileRequest): - container_name: str | None = None - - -class InternalWriteFileRequest(SandboxWriteFileRequest): - container_name: str | None = None - - -InternalUploadRequest = UploadRequest - - -class BashInterruptAction(BaseModel): - command: str = "interrupt" - - session: str = "default" - - timeout: float = 0.2 - """The timeout for the command. None means no timeout.""" - - n_retry: int = 3 - """How many times to retry quitting.""" - - expect: list[str] = [] - """Outputs to expect in addition to the PS1""" - - action_type: Literal["bash_interrupt"] = "bash_interrupt" diff --git a/rock/sandbox/remote_sandbox.py b/rock/sandbox/remote_sandbox.py index 293159e3..11968e7a 100644 --- a/rock/sandbox/remote_sandbox.py +++ b/rock/sandbox/remote_sandbox.py @@ -16,7 +16,6 @@ AbstractSandbox, CloseResponse, CloseSessionResponse, - Command, CommandResponse, CreateSessionResponse, EnvCloseRequest, @@ -32,19 +31,19 @@ Observation, ReadFileResponse, RemoteSandboxRuntimeConfig, + UploadRequest, UploadResponse, WriteFileResponse, _ExceptionTransfer, ) +from rock.admin.proto.request import SandboxAction as Action +from rock.admin.proto.request import SandboxCloseSessionRequest as CloseSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.logger import init_logger from rock.rocklet.exceptions import RockletException -from rock.rocklet.proto.request import InternalAction as Action -from rock.rocklet.proto.request import InternalCloseSessionRequest as CloseSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalUploadRequest as UploadRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.utils import SANDBOX_ID, sandbox_id_ctx_var, wait_until_alive __all__ = ["RemoteSandboxRuntime", "RemoteSandboxRuntimeConfig"] diff --git a/rock/sandbox/sandbox_actor.py b/rock/sandbox/sandbox_actor.py index 3fbad6a1..bed58ff4 100644 --- a/rock/sandbox/sandbox_actor.py +++ b/rock/sandbox/sandbox_actor.py @@ -18,19 +18,18 @@ UploadResponse, WriteFileResponse, ) +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.deployments.abstract import AbstractDeployment from rock.deployments.config import DeploymentConfig from rock.deployments.constants import Status from rock.deployments.docker import DockerDeployment from rock.deployments.status import ServiceStatus from rock.logger import init_logger -from rock.rocklet.proto.request import BashInterruptAction -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCloseBashSessionRequest as CloseBashSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.sandbox.gem_actor import GemActor logger = init_logger(__name__) @@ -185,7 +184,7 @@ async def create_session(self, request: CreateBashSessionRequest) -> CreateBashS await self._refresh_stop_time() return await self._deployment.runtime.create_session(request) - async def run_in_session(self, action: BashAction | BashInterruptAction) -> BashObservation: + async def run_in_session(self, action: BashAction) -> BashObservation: await self._refresh_stop_time() return await self._deployment.runtime.run_in_session(action) diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index ba7295a7..f2586410 100644 --- a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -14,22 +14,21 @@ UploadResponse, WriteFileResponse, ) -from rock.actions.sandbox.request import Action from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key from rock.admin.metrics.decorator import monitor_sandbox_operation +from rock.admin.proto.request import SandboxAction as Action +from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.admin.proto.response import SandboxStartResponse, SandboxStatusResponse from rock.config import RockConfig -from rock.deployments.config import DeploymentConfig +from rock.deployments.config import DeploymentConfig, DockerDeploymentConfig from rock.deployments.constants import Status from rock.deployments.status import PhaseStatus, ServiceStatus from rock.logger import init_logger from rock.rocklet import __version__ as swe_version -from rock.rocklet.proto.request import BashInterruptAction -from rock.rocklet.proto.request import InternalCloseBashSessionRequest as CloseBashSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.sandbox import __version__ as gateway_version from rock.sandbox.base_manager import BaseManager from rock.sandbox.sandbox_actor import SandboxActor @@ -83,7 +82,7 @@ async def async_ray_get_actor(self, sandbox_id: str): @monitor_sandbox_operation() async def start_async(self, config: DeploymentConfig, user_info: dict = {}) -> SandboxStartResponse: - docker_deployment_config = await self.deployment_manager.init_config(config) + docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config) sandbox_id = docker_deployment_config.container_name actor_name = self.deployment_manager.get_actor_name(sandbox_id) @@ -117,7 +116,7 @@ async def start_async(self, config: DeploymentConfig, user_info: dict = {}) -> S @monitor_sandbox_operation() async def start(self, config: DeploymentConfig) -> SandboxStartResponse: - docker_deployment_config = await self.deployment_manager.init_config(config) + docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config) sandbox_id = docker_deployment_config.container_name actor_name = self.deployment_manager.get_actor_name(sandbox_id) @@ -232,47 +231,47 @@ async def get_status(self, sandbox_id) -> SandboxStatusResponse: ) async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: - sandbox_actor = await self.async_ray_get_actor(request.container_name) + sandbox_actor = await self.async_ray_get_actor(request.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {request.container_name} not found to create session") - await self._update_expire_time(request.container_name) + raise Exception(f"sandbox {request.sandbox_id} not found to create session") + await self._update_expire_time(request.sandbox_id) return await self.async_ray_get(sandbox_actor.create_session.remote(request)) @monitor_sandbox_operation() - async def run_in_session(self, action: Action | BashInterruptAction) -> BashObservation: - sandbox_actor = await self.async_ray_get_actor(action.container_name) + async def run_in_session(self, action: Action) -> BashObservation: + sandbox_actor = await self.async_ray_get_actor(action.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {action.container_name} not found to run in session") - await self._update_expire_time(action.container_name) + raise Exception(f"sandbox {action.sandbox_id} not found to run in session") + await self._update_expire_time(action.sandbox_id) return await self.async_ray_get(sandbox_actor.run_in_session.remote(action)) async def close_session(self, request: CloseBashSessionRequest) -> CloseBashSessionResponse: - sandbox_actor = await self.async_ray_get_actor(request.container_name) + sandbox_actor = await self.async_ray_get_actor(request.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {request.container_name} not found to close session") - await self._update_expire_time(request.container_name) + raise Exception(f"sandbox {request.sandbox_id} not found to close session") + await self._update_expire_time(request.sandbox_id) return await self.async_ray_get(sandbox_actor.close_session.remote(request)) async def execute(self, command: Command) -> CommandResponse: - sandbox_actor = await self.async_ray_get_actor(command.container_name) + sandbox_actor = await self.async_ray_get_actor(command.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {command.container_name} not found to execute") - await self._update_expire_time(command.container_name) + raise Exception(f"sandbox {command.sandbox_id} not found to execute") + await self._update_expire_time(command.sandbox_id) return await self.async_ray_get(sandbox_actor.execute.remote(command)) async def read_file(self, request: ReadFileRequest) -> ReadFileResponse: - sandbox_actor = await self.async_ray_get_actor(request.container_name) + sandbox_actor = await self.async_ray_get_actor(request.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {request.container_name} not found to read file") - await self._update_expire_time(request.container_name) + raise Exception(f"sandbox {request.sandbox_id} not found to read file") + await self._update_expire_time(request.sandbox_id) return await self.async_ray_get(sandbox_actor.read_file.remote(request)) @monitor_sandbox_operation() async def write_file(self, request: WriteFileRequest) -> WriteFileResponse: - sandbox_actor = await self.async_ray_get_actor(request.container_name) + sandbox_actor = await self.async_ray_get_actor(request.sandbox_id) if sandbox_actor is None: - raise Exception(f"sandbox {request.container_name} not found to write file") - await self._update_expire_time(request.container_name) + raise Exception(f"sandbox {request.sandbox_id} not found to write file") + await self._update_expire_time(request.sandbox_id) return await self.async_ray_get(sandbox_actor.write_file.remote(request)) @monitor_sandbox_operation() diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index e8ad8aca..bdab17e4 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -23,17 +23,16 @@ from rock.admin.core.redis_key import alive_sandbox_key, timeout_sandbox_key from rock.admin.metrics.decorator import monitor_sandbox_operation from rock.admin.metrics.monitor import MetricsMonitor +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest +from rock.admin.proto.request import SandboxCommand as Command +from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest +from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.config import OssConfig, ProxyServiceConfig, RockConfig from rock.deployments.constants import Port from rock.deployments.status import ServiceStatus from rock.logger import init_logger -from rock.rocklet.proto.request import BashInterruptAction -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCloseBashSessionRequest as CloseBashSessionRequest -from rock.rocklet.proto.request import InternalCommand as Command -from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalWriteFileRequest as WriteFileRequest from rock.utils.providers import RedisProvider logger = init_logger(__name__) @@ -66,7 +65,7 @@ def __init__(self, rock_config: RockConfig, redis_provider: RedisProvider | None @monitor_sandbox_operation() async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: - sandbox_id = request.container_name + sandbox_id = request.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( @@ -75,8 +74,8 @@ async def create_session(self, request: CreateSessionRequest) -> CreateBashSessi return CreateBashSessionResponse(**response) @monitor_sandbox_operation() - async def run_in_session(self, action: BashAction | BashInterruptAction) -> BashObservation: - sandbox_id = action.container_name + async def run_in_session(self, action: BashAction) -> BashObservation: + sandbox_id = action.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( @@ -86,7 +85,7 @@ async def run_in_session(self, action: BashAction | BashInterruptAction) -> Bash @monitor_sandbox_operation() async def close_session(self, request: CloseBashSessionRequest) -> CloseBashSessionResponse: - sandbox_id = request.container_name + sandbox_id = request.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( @@ -102,7 +101,7 @@ async def is_alive(self, sandbox_id: str) -> IsAliveResponse: @monitor_sandbox_operation() async def read_file(self, request: ReadFileRequest) -> ReadFileResponse: - sandbox_id = request.container_name + sandbox_id = request.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( @@ -112,7 +111,7 @@ async def read_file(self, request: ReadFileRequest) -> ReadFileResponse: @monitor_sandbox_operation() async def write_file(self, request: WriteFileRequest) -> WriteFileResponse: - sandbox_id = request.container_name + sandbox_id = request.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( @@ -131,7 +130,7 @@ async def upload(self, file: UploadFile, target_path: str, sandbox_id: str) -> U @monitor_sandbox_operation() async def execute(self, command: Command) -> CommandResponse: - sandbox_id = command.container_name + sandbox_id = command.sandbox_id await self._update_expire_time(sandbox_id) sandbox_status_dicts = await self.get_service_status(sandbox_id) response = await self._send_request( diff --git a/rock/sdk/sandbox/client.py b/rock/sdk/sandbox/client.py index 949f0f13..14d2378a 100644 --- a/rock/sdk/sandbox/client.py +++ b/rock/sdk/sandbox/client.py @@ -13,7 +13,9 @@ from rock import env_vars from rock.actions import ( + AbstractSandbox, Action, + BashAction, CloseResponse, CloseSessionRequest, CloseSessionResponse, @@ -33,7 +35,6 @@ WriteFileRequest, WriteFileResponse, ) -from rock.actions.sandbox.base import AbstractSandbox from rock.sdk.common.constants import RunModeType from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig from rock.utils import HttpUtils, extract_nohup_pid, retry_async @@ -264,7 +265,7 @@ async def run_nohup_and_wait( # Build and execute nohup command nohup_command = f"nohup {cmd} < /dev/null > {redirect_file_path} 2>&1 & echo $!;disown" - action = Action(command=nohup_command, session=temp_session) + action = BashAction(command=nohup_command, session=temp_session) response = await self._run_in_session(action) # Parse @@ -298,7 +299,7 @@ async def arun( session = temp_session tmp_file = f"/tmp/tmp_{timestamp}.out" nohup_command = f"nohup {cmd} < /dev/null > {tmp_file} 2>&1 & echo $!;disown" - action = Action(command=nohup_command, session=session) + action = BashAction(command=nohup_command, session=session) response: Observation = await self._run_in_session(action) pid = extract_nohup_pid(response.output) @@ -313,7 +314,7 @@ async def arun( pid=pid, session=session, wait_timeout=wait_timeout, wait_interval=wait_interval ) exec_result: Observation = await self._run_in_session( - Action(session=session, command=f"cat {tmp_file}") + BashAction(session=session, command=f"cat {tmp_file}") ) if success: return Observation(output=exec_result.output, exit_code=0) @@ -324,7 +325,7 @@ async def arun( error_msg = f"Failed to execute nohup command '{cmd}': {str(e)}" return Observation(output="", exit_code=1, failure_reason=error_msg) elif mode == "normal": - return await self._run_in_session(action=Action(command=cmd, session=session)) + return await self._run_in_session(action=BashAction(command=cmd, session=session)) else: return Observation(output="", exit_code=1, failure_reason="Unsupported arun mode") @@ -371,7 +372,7 @@ async def _wait_for_process_completion( try: # Check if process still exists await asyncio.wait_for( - self.run_in_session(Action(session=session, command=check_alive_cmd)), + self.run_in_session(BashAction(session=session, command=check_alive_cmd)), timeout=check_alive_timeout, ) @@ -549,7 +550,7 @@ async def _upload_via_oss(self, file_path: str | Path, target_path: str): await self.create_session(CreateBashSessionRequest(session=check_file_session)) check_file_cmd = f"test -f {target_path}" check_response: Observation = await self.run_in_session( - action=Action(command=check_file_cmd, session=check_file_session) + action=BashAction(command=check_file_cmd, session=check_file_session) ) if not check_response.exit_code == 0: return UploadResponse( diff --git a/tests/deployments/test_local_deployment.py b/tests/deployments/test_local_deployment.py index cf07d921..cc75df71 100644 --- a/tests/deployments/test_local_deployment.py +++ b/tests/deployments/test_local_deployment.py @@ -2,9 +2,9 @@ import pytest +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest from rock.deployments.local import LocalDeployment -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest @pytest.mark.asyncio diff --git a/tests/rocklet/test_local_sandbox_runtime.py b/tests/rocklet/test_local_sandbox_runtime.py index c0f7c80c..b2e7d73e 100644 --- a/tests/rocklet/test_local_sandbox_runtime.py +++ b/tests/rocklet/test_local_sandbox_runtime.py @@ -4,16 +4,12 @@ import pytest from gem.envs.game_env.sokoban import SokobanEnv -from rock.actions import ( - EnvMakeResponse, - EnvStepResponse, -) +from rock.actions import EnvMakeResponse, EnvStepResponse, UploadRequest +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest +from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest +from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest from rock.rocklet.local_sandbox import LocalSandboxRuntime -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCloseBashSessionRequest as CloseBashSessionRequest -from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest -from rock.rocklet.proto.request import InternalReadFileRequest as ReadFileRequest -from rock.rocklet.proto.request import InternalUploadRequest as UploadRequest @pytest.fixture diff --git a/tests/utils/test_shell_util.py b/tests/utils/test_shell_util.py index 1356e3b2..67ea8863 100644 --- a/tests/utils/test_shell_util.py +++ b/tests/utils/test_shell_util.py @@ -1,6 +1,6 @@ +from rock.admin.proto.request import SandboxBashAction as BashAction +from rock.admin.proto.request import SandboxCreateBashSessionRequest as CreateBashSessionRequest from rock.deployments.local import LocalDeployment -from rock.rocklet.proto.request import InternalBashAction as BashAction -from rock.rocklet.proto.request import InternalCreateBashSessionRequest as CreateBashSessionRequest from rock.utils import extract_nohup_pid diff --git a/uv.lock b/uv.lock index dcd54697..8603143f 100644 --- a/uv.lock +++ b/uv.lock @@ -3995,7 +3995,7 @@ wheels = [ [[package]] name = "rl-rock" -version = "0.2.0" +version = "0.2.1" source = { editable = "." } dependencies = [ { name = "anyio" },