Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,90 +22,90 @@
SandboxReadFileRequest,
SandboxWriteFileRequest,
)
from rock.sandbox.service.sandbox_read_service import SandboxReadService
from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService
from rock.utils import handle_exceptions

sandbox_read_router = APIRouter()
sandbox_read_service: SandboxReadService
sandbox_proxy_router = APIRouter()
sandbox_proxy_service: SandboxProxyService


def set_sandbox_read_service(service: SandboxReadService):
global sandbox_read_service
sandbox_read_service = service
def set_sandbox_proxy_service(service: SandboxProxyService):
global sandbox_proxy_service
sandbox_proxy_service = service


@sandbox_read_router.post("/execute")
@sandbox_proxy_router.post("/execute")
@handle_exceptions(error_message="execute command failed")
async def execute(command: SandboxCommand) -> RockResponse[CommandResponse]:
return RockResponse(result=await sandbox_read_service.execute(command.transform()))
return RockResponse(result=await sandbox_proxy_service.execute(command.transform()))


@sandbox_read_router.post("/create_session")
@sandbox_proxy_router.post("/create_session")
@handle_exceptions(error_message="create session failed")
async def create_session(request: SandboxCreateBashSessionRequest) -> RockResponse[CreateBashSessionResponse]:
return RockResponse(result=await sandbox_read_service.create_session(request.transform()))
return RockResponse(result=await sandbox_proxy_service.create_session(request.transform()))


@sandbox_read_router.post("/run_in_session")
@sandbox_proxy_router.post("/run_in_session")
@handle_exceptions(error_message="run in session failed")
async def run(action: SandboxBashAction) -> RockResponse[BashObservation]:
result = await sandbox_read_service.run_in_session(action.transform())
result = await sandbox_proxy_service.run_in_session(action.transform())
if result.exit_code is not None and result.exit_code == -1:
return RockResponse(status=ResponseStatus.FAILED, error=result.failure_reason)
return RockResponse(result=result)


@sandbox_read_router.post("/close_session")
@sandbox_proxy_router.post("/close_session")
@handle_exceptions(error_message="close session failed")
async def close_session(request: SandboxCloseBashSessionRequest) -> RockResponse[CloseBashSessionResponse]:
return RockResponse(result=await sandbox_read_service.close_session(request.transform()))
return RockResponse(result=await sandbox_proxy_service.close_session(request.transform()))


@sandbox_read_router.get("/is_alive")
@sandbox_proxy_router.get("/is_alive")
@handle_exceptions(error_message="get sandbox is alive failed")
async def is_alive(sandbox_id: str):
return RockResponse(result=await sandbox_read_service.is_alive(sandbox_id))
return RockResponse(result=await sandbox_proxy_service.is_alive(sandbox_id))


@sandbox_read_router.post("/read_file")
@sandbox_proxy_router.post("/read_file")
@handle_exceptions(error_message="read file failed")
async def read_file(request: SandboxReadFileRequest) -> RockResponse[ReadFileResponse]:
return RockResponse(result=await sandbox_read_service.read_file(request.transform()))
return RockResponse(result=await sandbox_proxy_service.read_file(request.transform()))


@sandbox_read_router.post("/write_file")
@sandbox_proxy_router.post("/write_file")
@handle_exceptions(error_message="write file failed")
async def write_file(request: SandboxWriteFileRequest) -> RockResponse[WriteFileResponse]:
return RockResponse(result=await sandbox_read_service.write_file(request.transform()))
return RockResponse(result=await sandbox_proxy_service.write_file(request.transform()))


@sandbox_read_router.post("/upload")
@sandbox_proxy_router.post("/upload")
@handle_exceptions(error_message="upload file failed")
async def upload(
file: UploadFile = File(...),
target_path: str = Form(...),
sandbox_id: str | None = Form(None),
) -> RockResponse[UploadResponse]:
return RockResponse(result=await sandbox_read_service.upload(file, target_path, sandbox_id))
return RockResponse(result=await sandbox_proxy_service.upload(file, target_path, sandbox_id))


@sandbox_read_router.websocket("/sandboxes/{id}/proxy/ws")
@sandbox_read_router.websocket("/sandboxes/{id}/proxy/ws/{path:path}")
@sandbox_proxy_router.websocket("/sandboxes/{id}/proxy/ws")
@sandbox_proxy_router.websocket("/sandboxes/{id}/proxy/ws/{path:path}")
async def websocket_proxy(websocket: WebSocket, id: str, path: str = ""):
await websocket.accept()
sandbox_id = id
logging.info(f"Client connected to WebSocket proxy: {sandbox_id}, path: {path}")
try:
await sandbox_read_service.websocket_proxy(websocket, sandbox_id, path)
await sandbox_proxy_service.websocket_proxy(websocket, sandbox_id, path)
except WebSocketDisconnect:
logging.info(f"Client disconnected from WebSocket proxy: {sandbox_id}")
except Exception as e:
logging.error(f"WebSocket proxy error: {e}")
await websocket.close(code=1011, reason=f"Proxy error: {str(e)}")


@sandbox_read_router.get("/get_token")
@sandbox_proxy_router.get("/get_token")
@handle_exceptions(error_message="get oss sts token failed")
async def get_token():
result = await asyncio.to_thread(sandbox_read_service.gen_oss_sts_token)
result = await asyncio.to_thread(sandbox_proxy_service.gen_oss_sts_token)
return RockResponse(result=result)
16 changes: 8 additions & 8 deletions rock/admin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
from rock import env_vars
from rock.admin.core.ray_service import RayService
from rock.admin.entrypoints.sandbox_api import sandbox_router, set_sandbox_manager
from rock.admin.entrypoints.sandbox_read_api import sandbox_read_router, set_sandbox_read_service
from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service
from rock.admin.entrypoints.warmup_api import set_warmup_service, warmup_router
from rock.admin.gem.api import gem_router, set_env_service
from rock.config import RockConfig
from rock.logger import init_logger
from rock.sandbox.gem_manager import GemManager
from rock.sandbox.service.sandbox_read_service import SandboxReadService
from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService
from rock.sandbox.service.warmup_service import WarmupService
from rock.utils import sandbox_id_ctx_var
from rock.utils.providers import RedisProvider

parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="local")
parser.add_argument("--role", type=str, default="write", choices=["write", "read"])
parser.add_argument("--role", type=str, default="admin", choices=["admin", "proxy"])
parser.add_argument("--port", type=int, default=8080)

args = parser.parse_args()
Expand Down Expand Up @@ -58,7 +58,7 @@ async def lifespan(app: FastAPI):
await redis_provider.init_pool()

# init sandbox service
if args.role == "write":
if args.role == "admin":
# init service
if rock_config.runtime.enable_auto_clear:
sandbox_manager = GemManager(
Expand All @@ -82,8 +82,8 @@ async def lifespan(app: FastAPI):

RayService(rock_config.ray).init()
else:
sandbox_manager = SandboxReadService(rock_config=rock_config, redis_provider=redis_provider)
set_sandbox_read_service(sandbox_manager)
sandbox_manager = SandboxProxyService(rock_config=rock_config, redis_provider=redis_provider)
set_sandbox_proxy_service(sandbox_manager)

logger.info("rock-admin start")

Expand Down Expand Up @@ -163,10 +163,10 @@ async def log_requests_and_responses(request: Request, call_next):

def main():
# config router
if args.role == "write":
if args.role == "admin":
app.include_router(sandbox_router, prefix="/apis/envs/sandbox/v1", tags=["sandbox"])
else:
app.include_router(sandbox_read_router, prefix="/apis/envs/sandbox/v1", tags=["sandbox"])
app.include_router(sandbox_proxy_router, prefix="/apis/envs/sandbox/v1", tags=["sandbox"])
app.include_router(warmup_router, prefix="/apis/envs/sandbox/v1", tags=["warmup"])
app.include_router(gem_router, prefix="/apis/v1/envs/gem", tags=["gem"])

Expand Down
10 changes: 10 additions & 0 deletions rock/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ class OssConfig:
role_arn: str = ""


@dataclass
class ProxyServiceConfig:
timeout: float = 180.0
max_connections: int = 500
max_keepalive_connections: int = 100


@dataclass
class DatabaseConfig:
url: str = ""
Expand Down Expand Up @@ -82,6 +89,7 @@ class RockConfig:
sandbox_config: SandboxConfig = field(default_factory=SandboxConfig)
oss: OssConfig = field(default_factory=OssConfig)
runtime: RuntimeConfig = field(default_factory=RuntimeConfig)
proxy_service: ProxyServiceConfig = field(default_factory=ProxyServiceConfig)
nacos_provider: NacosConfigProvider | None = None

@classmethod
Expand Down Expand Up @@ -117,6 +125,8 @@ def from_env(cls, config_path: str | None = None):
kwargs["oss"] = OssConfig(**config["oss"])
if "runtime" in config:
kwargs["runtime"] = RuntimeConfig(**config["runtime"])
if "proxy_service" in config:
kwargs["proxy_service"] = ProxyServiceConfig(**config["proxy_service"])

return cls(**kwargs)

Expand Down
18 changes: 9 additions & 9 deletions rock/sandbox/sandbox_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UploadResponse,
WriteFileResponse,
)
from rock.actions.sandbox.request import Action
from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key
from rock.admin.metrics.decorator import monitor_sandbox_operation
from rock.admin.proto.response import SandboxStartResponse, SandboxStatusResponse
Expand All @@ -24,7 +25,6 @@
from rock.logger import init_logger
from rock.rocklet import __version__ as swe_version
from rock.rocklet.proto.request import BashInterruptAction
from rock.rocklet.proto.request import InternalBashAction as BashAction
from rock.rocklet.proto.request import InternalCloseBashSessionRequest as CloseBashSessionRequest
from rock.rocklet.proto.request import InternalCommand as Command
from rock.rocklet.proto.request import InternalCreateSessionRequest as CreateSessionRequest
Expand Down Expand Up @@ -231,13 +231,6 @@ async def get_status(self, sandbox_id) -> SandboxStatusResponse:
experiment_id=experiment_id,
)

async def execute(self, command: Command) -> CommandResponse:
sandbox_actor = await self.async_ray_get_actor(command.container_name)
if sandbox_actor is None:
raise Exception(f"sandbox {command.container_name} not found to execute")
await self._update_expire_time(command.container_name)
return await self.async_ray_get(sandbox_actor.execute.remote(command))

async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse:
sandbox_actor = await self.async_ray_get_actor(request.container_name)
if sandbox_actor is None:
Expand All @@ -246,7 +239,7 @@ async def create_session(self, request: CreateSessionRequest) -> CreateBashSessi
return await self.async_ray_get(sandbox_actor.create_session.remote(request))

@monitor_sandbox_operation()
async def run_in_session(self, action: BashAction | BashInterruptAction) -> BashObservation:
async def run_in_session(self, action: Action | BashInterruptAction) -> BashObservation:
sandbox_actor = await self.async_ray_get_actor(action.container_name)
if sandbox_actor is None:
raise Exception(f"sandbox {action.container_name} not found to run in session")
Expand All @@ -260,6 +253,13 @@ async def close_session(self, request: CloseBashSessionRequest) -> CloseBashSess
await self._update_expire_time(request.container_name)
return await self.async_ray_get(sandbox_actor.close_session.remote(request))

async def execute(self, command: Command) -> CommandResponse:
sandbox_actor = await self.async_ray_get_actor(command.container_name)
if sandbox_actor is None:
raise Exception(f"sandbox {command.container_name} not found to execute")
await self._update_expire_time(command.container_name)
return await self.async_ray_get(sandbox_actor.execute.remote(command))

async def read_file(self, request: ReadFileRequest) -> ReadFileResponse:
sandbox_actor = await self.async_ray_get_actor(request.container_name)
if sandbox_actor is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from rock.admin.core.redis_key import alive_sandbox_key, timeout_sandbox_key
from rock.admin.metrics.decorator import monitor_sandbox_operation
from rock.admin.metrics.monitor import MetricsMonitor
from rock.config import OssConfig, RockConfig
from rock.config import OssConfig, ProxyServiceConfig, RockConfig
from rock.deployments.constants import Port
from rock.deployments.status import ServiceStatus
from rock.logger import init_logger
Expand All @@ -39,14 +39,25 @@
logger = init_logger(__name__)


class SandboxReadService:
class SandboxProxyService:
_redis_provider: RedisProvider = None
_httpx_client = httpx.AsyncClient(timeout=180.0)
_httpx_client = None

def __init__(self, rock_config: RockConfig, redis_provider: RedisProvider | None = None):
self._redis_provider = redis_provider
self.metrics_monitor = MetricsMonitor.create()
self.oss_config: OssConfig = rock_config.oss
self.proxy_config: ProxyServiceConfig = rock_config.proxy_service
logger.info(f"proxy config: {self.proxy_config}")
# Initialize httpx client with configuration
self._httpx_client = httpx.AsyncClient(
timeout=self.proxy_config.timeout,
limits=httpx.Limits(
max_connections=self.proxy_config.max_connections,
max_keepalive_connections=self.proxy_config.max_keepalive_connections,
),
)

self.sts_client = client.AcsClient(
self.oss_config.access_key_id,
self.oss_config.access_key_secret,
Expand Down
Loading