diff --git a/rock/actions/sandbox/sandbox_info.py b/rock/actions/sandbox/sandbox_info.py new file mode 100644 index 00000000..51079825 --- /dev/null +++ b/rock/actions/sandbox/sandbox_info.py @@ -0,0 +1,19 @@ +from typing import TypedDict + +from rock.deployments.status import PhaseStatus + + +class SandboxInfo(TypedDict, total=False): + host_ip: str + host_name: str + image: str + user_id: str + experiment_id: str + namespace: str + sandbox_id: str + auth_token: str + phases: dict[str, PhaseStatus] + port_mapping: dict[int, int] + create_user_gray_flag: bool + cpus: float + memory: str diff --git a/rock/sandbox/sandbox_actor.py b/rock/sandbox/sandbox_actor.py index bed58ff4..ef59dc5e 100644 --- a/rock/sandbox/sandbox_actor.py +++ b/rock/sandbox/sandbox_actor.py @@ -18,6 +18,7 @@ UploadResponse, WriteFileResponse, ) +from rock.actions.sandbox.sandbox_info import SandboxInfo from rock.admin.proto.request import SandboxBashAction as BashAction from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest from rock.admin.proto.request import SandboxCommand as Command @@ -42,6 +43,11 @@ class SandboxActor(GemActor): _clean_container_background_script = "rock/admin/scripts/clean_container_background.sh" _clean_container_background_process = None _metrics_monitor = None + _role = "test" + _env = "dev" + _user_id = "default" + _experiment_id = "default" + _namespace = "default" def __init__( self, @@ -238,6 +244,9 @@ async def set_user_id(self, user_id: str): async def set_experiment_id(self, experiment_id: str): self._experiment_id = experiment_id + async def set_namespace(self, namespace: str): + self._namespace = namespace + async def user_id(self) -> str | None: if isinstance(self._deployment, DockerDeployment): return self._user_id @@ -247,3 +256,23 @@ async def experiment_id(self) -> str | None: if isinstance(self._deployment, DockerDeployment): return self._experiment_id return None + + async def namespace(self) -> str | None: + if isinstance(self._deployment, DockerDeployment): + return self._namespace + return None + + async def sandbox_info(self) -> SandboxInfo: + if isinstance(self._deployment, DockerDeployment): + return { + "host_ip": await self.host_ip(), + "host_name": await self.host_name(), + "image": self._config.image, + "user_id": await self.user_id(), + "experiment_id": await self.experiment_id(), + "sandbox_id": self._config.container_name, + "namespace": await self.namespace(), + "cpus": self._config.cpus, + "memory": self._config.memory, + } + return {} diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index f2586410..c304710d 100644 --- a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -14,6 +14,7 @@ UploadResponse, WriteFileResponse, ) +from rock.actions.sandbox.sandbox_info import SandboxInfo from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key from rock.admin.metrics.decorator import monitor_sandbox_operation from rock.admin.proto.request import SandboxAction as Action @@ -90,9 +91,14 @@ async def start_async(self, config: DeploymentConfig, user_info: dict = {}) -> S deployment = docker_deployment_config.get_deployment() sandbox_actor: SandboxActor = await deployment.creator_actor(actor_name) + user_id = user_info.get("user_id", "default") + experiment_id = user_info.get("experiment_id", "default") + namespace = user_info.get("namespace", "default") sandbox_actor.start.remote() - sandbox_actor.set_user_id.remote(user_info.get("user_id", "default")) - sandbox_actor.set_experiment_id.remote(user_info.get("experiment_id", "default")) + sandbox_actor.set_user_id.remote(user_id) + sandbox_actor.set_experiment_id.remote(experiment_id) + sandbox_actor.set_namespace.remote(namespace) + self._sandbox_meta[sandbox_id] = {"image": docker_deployment_config.image} logger.info(f"sandbox {sandbox_id} is submitted") stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60) @@ -100,18 +106,17 @@ async def start_async(self, config: DeploymentConfig, user_info: dict = {}) -> S env_vars.ROCK_SANDBOX_AUTO_CLEAR_TIME_KEY: str(docker_deployment_config.auto_clear_time), env_vars.ROCK_SANDBOX_EXPIRE_TIME_KEY: stop_time, } + sandbox_info: SandboxInfo = await self.async_ray_get(sandbox_actor.sandbox_info.remote()) + sandbox_info["user_id"] = user_id + sandbox_info["experiment_id"] = experiment_id + sandbox_info["namespace"] = namespace if self._redis_provider: - sandbox_info = { - SANDBOX_ID: sandbox_id, - "user_id": await sandbox_actor.user_id.remote(), - "experiment_id": await sandbox_actor.experiment_id.remote(), - } await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) await self._redis_provider.json_set(timeout_sandbox_key(sandbox_id), "$", auto_clear_time_dict) return SandboxStartResponse( sandbox_id=sandbox_id, - host_name=await self.async_ray_get(sandbox_actor.host_name.remote()), - host_ip=await self.async_ray_get(sandbox_actor.host_ip.remote()), + host_name=sandbox_info.get("host_name"), + host_ip=sandbox_info.get("host_ip"), ) @monitor_sandbox_operation() @@ -186,6 +191,13 @@ async def _clear_redis_keys(self, sandbox_id): await self._redis_provider.json_delete(timeout_sandbox_key(sandbox_id)) logger.info(f"sandbox {sandbox_id} deleted from redis") + async def build_sandbox_from_redis(self, sandbox_id: str) -> SandboxInfo | None: + if self._redis_provider: + sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") + if sandbox_status and len(sandbox_status) > 0: + return sandbox_status[0] + return None + @monitor_sandbox_operation() async def get_status(self, sandbox_id) -> SandboxStatusResponse: sandbox_actor = await self.async_ray_get_actor(sandbox_id) @@ -194,40 +206,35 @@ async def get_status(self, sandbox_id) -> SandboxStatusResponse: else: remote_status: ServiceStatus = await self.async_ray_get(sandbox_actor.get_status.remote()) self.update_service_status(remote_status.phases) - host_name = await self.async_ray_get(sandbox_actor.host_name.remote()) - alive = await self.async_ray_get(sandbox_actor.is_alive.remote()) - host_ip = await self.async_ray_get(sandbox_actor.host_ip.remote()) - config = await self.async_ray_get(sandbox_actor.deployment_config.remote()) - image = getattr(config, "image", "local") - try: - user_id = await sandbox_actor.user_id.remote() - experiment_id = await sandbox_actor.experiment_id.remote() - except Exception as e: - logger.warning("get user_id and experiment_id failed", exc_info=e) - user_id = "default" - experiment_id = "default" - if alive and self._redis_provider: - status_dict = remote_status.to_dict() - status_dict["host_ip"] = host_ip - status_dict[SANDBOX_ID] = sandbox_id - status_dict["image"] = image - status_dict["user_id"] = user_id - status_dict["experiment_id"] = experiment_id - await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", status_dict) + sandbox_info: SandboxInfo = None + if self._redis_provider: + sandbox_info = await self.build_sandbox_from_redis(sandbox_id) + if sandbox_info is None: + # The start() method will write to redis on the first call to get_status() + sandbox_info = await self.async_ray_get(sandbox_actor.sandbox_info.remote()) + sandbox_info.update(remote_status.to_dict()) + await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) await self._update_expire_time(sandbox_id) - logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis") + self.logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis") + else: + sandbox_info = await self.async_ray_get(sandbox_actor.sandbox_info.remote()) + + alive = await self.async_ray_get(sandbox_actor.is_alive.remote()) return SandboxStatusResponse( sandbox_id=sandbox_id, status=self._service_status.phases, port_mapping=remote_status.get_port_mapping(), - host_name=host_name, - host_ip=host_ip, + host_name=sandbox_info.get("host_name"), + host_ip=sandbox_info.get("host_ip"), is_alive=alive.is_alive, - image=image, + image=sandbox_info.get("image"), swe_rex_version=swe_version, gateway_version=gateway_version, - user_id=user_id, - experiment_id=experiment_id, + user_id=sandbox_info.get("user_id"), + experiment_id=sandbox_info.get("experiment_id"), + namespace=sandbox_info.get("namespace"), + cpus=sandbox_info.get("cpus"), + memory=sandbox_info.get("memory"), ) async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: