Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions rock/actions/sandbox/sandbox_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import TypedDict

from rock.deployments.status import PhaseStatus


class SandboxInfo(TypedDict, total=False):
host_ip: str
host_name: str
image: str
user_id: str
experiment_id: str
namespace: str
sandbox_id: str
auth_token: str
phases: dict[str, PhaseStatus]
port_mapping: dict[int, int]
create_user_gray_flag: bool
cpus: float
memory: str
29 changes: 29 additions & 0 deletions rock/sandbox/sandbox_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
UploadResponse,
WriteFileResponse,
)
from rock.actions.sandbox.sandbox_info import SandboxInfo
from rock.admin.proto.request import SandboxBashAction as BashAction
from rock.admin.proto.request import SandboxCloseBashSessionRequest as CloseBashSessionRequest
from rock.admin.proto.request import SandboxCommand as Command
Expand All @@ -42,6 +43,11 @@ class SandboxActor(GemActor):
_clean_container_background_script = "rock/admin/scripts/clean_container_background.sh"
_clean_container_background_process = None
_metrics_monitor = None
_role = "test"
_env = "dev"
_user_id = "default"
_experiment_id = "default"
_namespace = "default"

def __init__(
self,
Expand Down Expand Up @@ -238,6 +244,9 @@ async def set_user_id(self, user_id: str):
async def set_experiment_id(self, experiment_id: str):
self._experiment_id = experiment_id

async def set_namespace(self, namespace: str):
self._namespace = namespace

async def user_id(self) -> str | None:
if isinstance(self._deployment, DockerDeployment):
return self._user_id
Expand All @@ -247,3 +256,23 @@ async def experiment_id(self) -> str | None:
if isinstance(self._deployment, DockerDeployment):
return self._experiment_id
return None

async def namespace(self) -> str | None:
if isinstance(self._deployment, DockerDeployment):
return self._namespace
return None

async def sandbox_info(self) -> SandboxInfo:
if isinstance(self._deployment, DockerDeployment):
return {
"host_ip": await self.host_ip(),
"host_name": await self.host_name(),
"image": self._config.image,
"user_id": await self.user_id(),
"experiment_id": await self.experiment_id(),
"sandbox_id": self._config.container_name,
"namespace": await self.namespace(),
"cpus": self._config.cpus,
"memory": self._config.memory,
}
return {}
77 changes: 42 additions & 35 deletions rock/sandbox/sandbox_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UploadResponse,
WriteFileResponse,
)
from rock.actions.sandbox.sandbox_info import SandboxInfo
from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key
from rock.admin.metrics.decorator import monitor_sandbox_operation
from rock.admin.proto.request import SandboxAction as Action
Expand Down Expand Up @@ -90,28 +91,32 @@ async def start_async(self, config: DeploymentConfig, user_info: dict = {}) -> S
deployment = docker_deployment_config.get_deployment()

sandbox_actor: SandboxActor = await deployment.creator_actor(actor_name)
user_id = user_info.get("user_id", "default")
experiment_id = user_info.get("experiment_id", "default")
namespace = user_info.get("namespace", "default")
sandbox_actor.start.remote()
sandbox_actor.set_user_id.remote(user_info.get("user_id", "default"))
sandbox_actor.set_experiment_id.remote(user_info.get("experiment_id", "default"))
sandbox_actor.set_user_id.remote(user_id)
sandbox_actor.set_experiment_id.remote(experiment_id)
sandbox_actor.set_namespace.remote(namespace)

self._sandbox_meta[sandbox_id] = {"image": docker_deployment_config.image}
logger.info(f"sandbox {sandbox_id} is submitted")
stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60)
auto_clear_time_dict = {
env_vars.ROCK_SANDBOX_AUTO_CLEAR_TIME_KEY: str(docker_deployment_config.auto_clear_time),
env_vars.ROCK_SANDBOX_EXPIRE_TIME_KEY: stop_time,
}
sandbox_info: SandboxInfo = await self.async_ray_get(sandbox_actor.sandbox_info.remote())
sandbox_info["user_id"] = user_id
sandbox_info["experiment_id"] = experiment_id
sandbox_info["namespace"] = namespace
if self._redis_provider:
sandbox_info = {
SANDBOX_ID: sandbox_id,
"user_id": await sandbox_actor.user_id.remote(),
"experiment_id": await sandbox_actor.experiment_id.remote(),
}
await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info)
await self._redis_provider.json_set(timeout_sandbox_key(sandbox_id), "$", auto_clear_time_dict)
return SandboxStartResponse(
sandbox_id=sandbox_id,
host_name=await self.async_ray_get(sandbox_actor.host_name.remote()),
host_ip=await self.async_ray_get(sandbox_actor.host_ip.remote()),
host_name=sandbox_info.get("host_name"),
host_ip=sandbox_info.get("host_ip"),
)

@monitor_sandbox_operation()
Expand Down Expand Up @@ -186,6 +191,13 @@ async def _clear_redis_keys(self, sandbox_id):
await self._redis_provider.json_delete(timeout_sandbox_key(sandbox_id))
logger.info(f"sandbox {sandbox_id} deleted from redis")

async def build_sandbox_from_redis(self, sandbox_id: str) -> SandboxInfo | None:
if self._redis_provider:
sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$")
if sandbox_status and len(sandbox_status) > 0:
return sandbox_status[0]
return None

@monitor_sandbox_operation()
async def get_status(self, sandbox_id) -> SandboxStatusResponse:
sandbox_actor = await self.async_ray_get_actor(sandbox_id)
Expand All @@ -194,40 +206,35 @@ async def get_status(self, sandbox_id) -> SandboxStatusResponse:
else:
remote_status: ServiceStatus = await self.async_ray_get(sandbox_actor.get_status.remote())
self.update_service_status(remote_status.phases)
host_name = await self.async_ray_get(sandbox_actor.host_name.remote())
alive = await self.async_ray_get(sandbox_actor.is_alive.remote())
host_ip = await self.async_ray_get(sandbox_actor.host_ip.remote())
config = await self.async_ray_get(sandbox_actor.deployment_config.remote())
image = getattr(config, "image", "local")
try:
user_id = await sandbox_actor.user_id.remote()
experiment_id = await sandbox_actor.experiment_id.remote()
except Exception as e:
logger.warning("get user_id and experiment_id failed", exc_info=e)
user_id = "default"
experiment_id = "default"
if alive and self._redis_provider:
status_dict = remote_status.to_dict()
status_dict["host_ip"] = host_ip
status_dict[SANDBOX_ID] = sandbox_id
status_dict["image"] = image
status_dict["user_id"] = user_id
status_dict["experiment_id"] = experiment_id
await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", status_dict)
sandbox_info: SandboxInfo = None
if self._redis_provider:
sandbox_info = await self.build_sandbox_from_redis(sandbox_id)
if sandbox_info is None:
# The start() method will write to redis on the first call to get_status()
sandbox_info = await self.async_ray_get(sandbox_actor.sandbox_info.remote())
sandbox_info.update(remote_status.to_dict())
await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info)
await self._update_expire_time(sandbox_id)
logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis")
self.logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis")
else:
sandbox_info = await self.async_ray_get(sandbox_actor.sandbox_info.remote())

alive = await self.async_ray_get(sandbox_actor.is_alive.remote())
return SandboxStatusResponse(
sandbox_id=sandbox_id,
status=self._service_status.phases,
port_mapping=remote_status.get_port_mapping(),
host_name=host_name,
host_ip=host_ip,
host_name=sandbox_info.get("host_name"),
host_ip=sandbox_info.get("host_ip"),
is_alive=alive.is_alive,
image=image,
image=sandbox_info.get("image"),
swe_rex_version=swe_version,
gateway_version=gateway_version,
user_id=user_id,
experiment_id=experiment_id,
user_id=sandbox_info.get("user_id"),
experiment_id=sandbox_info.get("experiment_id"),
namespace=sandbox_info.get("namespace"),
cpus=sandbox_info.get("cpus"),
memory=sandbox_info.get("memory"),
)

async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse:
Expand Down
Loading