diff --git a/rock/__init__.py b/rock/__init__.py index 87865c24..6002fb82 100644 --- a/rock/__init__.py +++ b/rock/__init__.py @@ -1,3 +1,20 @@ from rock.sdk.envs import make -__all__ = ["make"] +from ._codes import codes +from .sdk.common.exceptions import ( + BadRequestRockError, + CommandRockError, + InternalServerRockError, + RockException, + raise_for_code, +) + +__all__ = [ + "make", + "codes", + "RockException", + "BadRequestRockError", + "InternalServerRockError", + "CommandRockError", + "raise_for_code", +] diff --git a/rock/_codes.py b/rock/_codes.py new file mode 100644 index 00000000..832db96d --- /dev/null +++ b/rock/_codes.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from enum import IntEnum + +__all__ = ["codes"] + + +class codes(IntEnum): + """ + ROCK status codes enumeration. + + This class extends IntEnum to provide status codes with associated phrase descriptions. + Each enum member has both an integer value and a phrase attribute for human-readable descriptions. + + The class also provides utility methods to categorize codes and retrieve phrases. + """ + + _ignore_ = ["phrase"] + phrase: str = "" + + def __new__(cls, value: int, phrase: str = "") -> codes: + """ + Create a new codes enum member with both value and phrase. + + Args: + value: The integer status code value + phrase: Human-readable description of the status + + Returns: + A new codes enum member with the phrase attribute set + """ + obj = int.__new__(cls, value) + obj._value_ = value + obj.phrase = phrase # Add phrase as an instance attribute + return obj + + def __str__(self) -> str: + """Return string representation of the status code value.""" + return str(self.value) + + @classmethod + def get_reason_phrase(cls, value: int) -> str: + """ + Get the reason phrase for a given status code value. + + Args: + value: The integer status code value to look up + + Returns: + The reason phrase string, or empty string if code not found + + Example: + >>> codes.get_reason_phrase(2000) + 'OK' + >>> codes.get_reason_phrase(9999) + '' + """ + try: + return codes(value).phrase + except ValueError: + return "" + + @classmethod + def is_success(cls, value: int) -> bool: + """ + Check if a status code indicates success (2xxx range). + + Args: + value: The status code to check + + Returns: + True if the code is in the 2000-2999 range, False otherwise + """ + return 2000 <= value <= 2999 + + @classmethod + def is_client_error(cls, value: int) -> bool: + """ + Check if a status code indicates a client error (4xxx range). + + Args: + value: The status code to check + + Returns: + True if the code is in the 4000-4999 range, False otherwise + """ + return 4000 <= value <= 4999 + + @classmethod + def is_server_error(cls, value: int) -> bool: + """ + Check if a status code indicates a server error (5xxx range). + + Args: + value: The status code to check + + Returns: + True if the code is in the 5000-5999 range, False otherwise + """ + return 5000 <= value <= 5999 + + @classmethod + def is_command_error(cls, value: int) -> bool: + """ + Check if a status code indicates a command error (6xxx range). + + Args: + value: The status code to check + + Returns: + True if the code is in the 6000-6999 range, False otherwise + """ + return 6000 <= value <= 6999 + + @classmethod + def is_error(cls, value: int) -> bool: + """ + Check if a status code indicates any kind of error (4xxx or 5xxx range). + + Args: + value: The status code to check + + Returns: + True if the code is in the 4000-5999 range, False otherwise + """ + return 4000 <= value <= 6999 + + OK = 2000, "OK" + """ + Success codes (2xxx) + """ + + BAD_REQUEST = 4000, "Bad Request" + """ + Client error codes (4xxx): + + These errors indicate issues with the client request, + SDK will raise Exceptions for these errors. + """ + + INTERNAL_SERVER_ERROR = 5000, "Internal Server Error" + """ + Server error codes (5xxx): + + These errors indicate issues on the server side, + SDK will raise Exceptions for these errors. + """ + + COMMAND_ERROR = 6000, "Command Error" + """ + Command/execution error codes (6xxx): + + These errors are related to command execution and should be handled by the model, + SDK will NOT raise Exceptions for these errors. + """ + + +# Include lower-case styles for `requests` compatibility. +for code in codes: + setattr(codes, code._name_.lower(), int(code)) diff --git a/rock/sdk/common/exceptions.py b/rock/sdk/common/exceptions.py new file mode 100644 index 00000000..402d8e38 --- /dev/null +++ b/rock/sdk/common/exceptions.py @@ -0,0 +1,54 @@ +import rock +from rock.actions.response import RockResponse +from rock.utils.deprecated import deprecated + + +class RockException(Exception): + _code: rock.codes = None + + def __init__(self, message, code: rock.codes = None): + super().__init__(message) + self._code = code + + @property + def code(self): + return self._code + + +@deprecated("This exception is deprecated") +class InvalidParameterRockException(RockException): + def __init__(self, message): + super().__init__(message) + + +class BadRequestRockError(RockException): + def __init__(self, message, code: rock.codes = rock.codes.BAD_REQUEST): + super().__init__(message, code) + + +class InternalServerRockError(RockException): + def __init__(self, message, code: rock.codes = rock.codes.INTERNAL_SERVER_ERROR): + super().__init__(message, code) + + +class CommandRockError(RockException): + def __init__(self, message, code: rock.codes = rock.codes.COMMAND_ERROR): + super().__init__(message, code) + + +def raise_for_code(code: rock.codes, message: str): + if code is None or rock.codes.is_success(code): + return + + if rock.codes.is_client_error(code): + raise BadRequestRockError(message) + if rock.codes.is_server_error(code): + raise InternalServerRockError(message) + if rock.codes.is_command_error(code): + raise CommandRockError(message) + + raise RockException(message, code=code) + + +def from_rock_exception(e: RockException) -> RockResponse: + return RockResponse(code=e.code, failure_reason=str(e)) diff --git a/rock/sdk/sandbox/client.py b/rock/sdk/sandbox/client.py index 6423a9f1..71eeed6c 100644 --- a/rock/sdk/sandbox/client.py +++ b/rock/sdk/sandbox/client.py @@ -6,6 +6,7 @@ import uuid import warnings from datetime import datetime, timedelta, timezone +from enum import Enum from pathlib import Path import oss2 @@ -37,6 +38,7 @@ WriteFileResponse, ) from rock.sdk.common.constants import PID_PREFIX, PID_SUFFIX, RunModeType +from rock.sdk.common.exceptions import InvalidParameterRockException from rock.sdk.sandbox.agent.base import Agent from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig from rock.sdk.sandbox.remote_user import LinuxRemoteUser, RemoteUser @@ -45,6 +47,11 @@ logger = logging.getLogger(__name__) +class RunMode(str, Enum): + NORMAL = "normal" + NOHUP = "nohup" + + class Sandbox(AbstractSandbox): config: SandboxConfig _url: str @@ -143,7 +150,7 @@ async def start(self): except Exception as e: logging.warning(f"Failed to get status, {str(e)}") await asyncio.sleep(3) - raise Exception(f"Failed to start sandbox within {self.config.startup_timeout}s") + raise Exception(f"Failed to start sandbox within {self.config.startup_timeout}s, sandbox: {str(self)}") async def is_alive(self) -> IsAliveResponse: try: @@ -173,7 +180,7 @@ async def execute(self, command: Command) -> CommandResponse: "sandbox_id": self.sandbox_id, "timeout": command.timeout, "cwd": command.cwd, - "env": command.env + "env": command.env, } try: response = await HttpUtils.post(url, headers, data) @@ -315,11 +322,59 @@ async def arun( session: str = None, wait_timeout=300, wait_interval=10, - mode: RunModeType = "normal", + mode: RunModeType = RunMode.NORMAL, response_limited_bytes_in_nohup: int | None = None, ignore_output: bool = False, ) -> Observation: - if mode == "nohup": + """ + Asynchronously run a command in the sandbox environment. + This method supports two execution modes: + - NORMAL: Execute command synchronously and wait for completion + - NOHUP: Execute command in background using nohup, suitable for long-running tasks + Args: + cmd (str): The command to execute in the sandbox + session (str, optional): The session identifier to run the command in. + If None, a temporary session will be created for nohup mode. Defaults to None. + wait_timeout (int, optional): Maximum time in seconds to wait for nohup command completion. + Defaults to 300. + wait_interval (int, optional): Interval in seconds between process completion checks for nohup mode. + Minimum value is 5 seconds. Defaults to 10. + mode (RunModeType, optional): Execution mode - either "normal" or "nohup". + Defaults to RunMode.NORMAL. + response_limited_bytes_in_nohup (int | None, optional): Maximum bytes to read from nohup output file. + If None, reads entire output. Only applies to nohup mode. Defaults to None. + nohup_command_timeout (int, optional): Timeout in seconds for the nohup command submission itself. + Defaults to 60. + Returns: + Observation: Command execution result containing output, exit code, and failure reason if any. + - For normal mode: Returns immediate execution result + - For nohup mode: Returns result after process completion or timeout + Raises: + InvalidParameterRockException: If an unsupported run mode is provided + ReadTimeout: If command execution times out (nohup mode) + Exception: For other execution failures in nohup mode + Examples: + # Normal synchronous execution + result = await sandbox.arun("ls -la") + # Background execution with nohup + result = await sandbox.arun( + "python long_running_script.py", + mode="nohup", + wait_timeout=600 + ) + # Limited output reading in nohup mode + result = await sandbox.arun( + "generate_large_output.sh", + mode="nohup", + response_limited_bytes_in_nohup=1024 + ) + """ + if mode not in (RunMode.NORMAL, RunMode.NOHUP): + raise InvalidParameterRockException(f"Unsupported arun mode: {mode}") + + if mode == RunMode.NORMAL: + return await self._run_in_session(action=Action(command=cmd, session=session)) + if mode == RunMode.NOHUP: try: timestamp = str(time.time_ns()) if session is None: @@ -327,7 +382,9 @@ async def arun( await self.create_session(CreateBashSessionRequest(session=temp_session)) session = temp_session tmp_file = f"/tmp/tmp_{timestamp}.out" - nohup_command = f"nohup {cmd} < /dev/null > {tmp_file} 2>&1 & echo {PID_PREFIX}${{!}}{PID_SUFFIX};disown" + nohup_command = ( + f"nohup {cmd} < /dev/null > {tmp_file} 2>&1 & echo {PID_PREFIX}${{!}}{PID_SUFFIX};disown" + ) # todo: # Theoretically, the nohup command should return in a very short time, but the total time online is longer, # so time_out is set larger to avoid affecting online usage. It will be reduced after optimizing the read cluster time. @@ -354,7 +411,9 @@ async def arun( file_size = None try: size_result: Observation = await self._run_in_session( - BashAction(session=session, command=f"stat -c %s {tmp_file} 2>/dev/null || stat -f %z {tmp_file}") + BashAction( + session=session, command=f"stat -c %s {tmp_file} 2>/dev/null || stat -f %z {tmp_file}" + ) ) if size_result.exit_code == 0 and size_result.output.strip().isdigit(): file_size = int(size_result.output.strip()) @@ -382,10 +441,6 @@ async def arun( except Exception as e: error_msg = f"Failed to execute nohup command '{cmd}': {str(e)}" return Observation(output=error_msg, exit_code=1, failure_reason=error_msg) - elif mode == "normal": - return await self._run_in_session(action=BashAction(command=cmd, session=session)) - else: - return Observation(output="", exit_code=1, failure_reason="Unsupported arun mode") async def write_file(self, request: WriteFileRequest) -> WriteFileResponse: content = request.content diff --git a/rock/utils/deprecated.py b/rock/utils/deprecated.py new file mode 100644 index 00000000..9624ae34 --- /dev/null +++ b/rock/utils/deprecated.py @@ -0,0 +1,26 @@ +import warnings +from collections.abc import Callable +from functools import wraps +from typing import Any + + +def deprecated(reason: str = "") -> Callable: + """ + Decorator to mark a function or class as deprecated. + + Args: + reason: Optional reason for deprecation + + Returns: + Decorated function or class + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + warnings.warn(f"{func.__name__} is deprecated. {reason}", DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py new file mode 100644 index 00000000..a9f41f41 --- /dev/null +++ b/tests/unit/common/test_exceptions.py @@ -0,0 +1,25 @@ +from rock import RockException, codes + + +class TestRockException: + """Test cases for the RockException class.""" + + def test_rock_exception_basic_creation(self): + """Test basic creation of RockException with message only.""" + message = "Test error message" + exception = RockException(message) + + assert str(exception) == message + assert exception.code is None + assert isinstance(exception, Exception) + + def test_rock_exception_with_code(self): + """Test RockException creation with both message and code.""" + message = "Test error with code" + code = codes.BAD_REQUEST + exception = RockException(message, code) + + assert str(exception) == message + assert exception.code == code + assert exception.code == 4000 + assert exception.code.phrase == "Bad Request" diff --git a/tests/unit/test_codes.py b/tests/unit/test_codes.py new file mode 100644 index 00000000..5a41c19b --- /dev/null +++ b/tests/unit/test_codes.py @@ -0,0 +1,166 @@ +import logging + +import rock +from rock._codes import codes + +logger = logging.getLogger(__name__) + + +def test_codes_values(): + """测试基本状态码值""" + assert rock.codes.OK == 2000 + assert rock.codes.BAD_REQUEST == 4000 + assert rock.codes.INTERNAL_SERVER_ERROR == 5000 + assert rock.codes.COMMAND_ERROR == 6000 + logger.info(f"rock.codes.OK.phrase: {rock.codes.OK.phrase}") + logger.info(f"rock.codes.BAD_REQUEST: {rock.codes.BAD_REQUEST}") + + +def test_codes_phrases(): + """测试状态码的phrase属性""" + assert rock.codes.OK.phrase == "OK" + assert rock.codes.BAD_REQUEST.phrase == "Bad Request" + assert rock.codes.INTERNAL_SERVER_ERROR.phrase == "Internal Server Error" + assert rock.codes.COMMAND_ERROR.phrase == "Command Error" + + +def test_codes_string_representation(): + """测试状态码的字符串表示""" + assert str(rock.codes.OK) == "2000" + assert str(rock.codes.BAD_REQUEST) == "4000" + assert str(rock.codes.INTERNAL_SERVER_ERROR) == "5000" + assert str(rock.codes.COMMAND_ERROR) == "6000" + + +def test_get_reason_phrase(): + """测试get_reason_phrase方法""" + assert codes.get_reason_phrase(2000) == "OK" + assert codes.get_reason_phrase(4000) == "Bad Request" + assert codes.get_reason_phrase(5000) == "Internal Server Error" + assert codes.get_reason_phrase(6000) == "Command Error" + + # 测试不存在的状态码 + assert codes.get_reason_phrase(9999) == "" + assert codes.get_reason_phrase(1000) == "" + + +def test_is_success(): + """测试is_success方法""" + assert codes.is_success(2000) is True + assert codes.is_success(2001) is True + assert codes.is_success(2999) is True + + assert codes.is_success(1999) is False + assert codes.is_success(3000) is False + assert codes.is_success(4000) is False + assert codes.is_success(5000) is False + + +def test_is_client_error(): + """测试is_client_error方法""" + assert codes.is_client_error(4000) is True + assert codes.is_client_error(4001) is True + assert codes.is_client_error(4999) is True + + assert codes.is_client_error(3999) is False + assert codes.is_client_error(5000) is False + assert codes.is_client_error(2000) is False + + +def test_is_server_error(): + """测试is_server_error方法""" + assert codes.is_server_error(5000) is True + assert codes.is_server_error(5001) is True + assert codes.is_server_error(5999) is True + + assert codes.is_server_error(4999) is False + assert codes.is_server_error(6000) is False + assert codes.is_server_error(2000) is False + + +def test_is_command_error(): + """测试is_command_error方法""" + assert codes.is_command_error(6000) is True + assert codes.is_command_error(6001) is True + assert codes.is_command_error(6999) is True + + assert codes.is_command_error(5999) is False + assert codes.is_command_error(7000) is False + assert codes.is_command_error(2000) is False + + +def test_is_error(): + """测试is_error方法""" + # 客户端错误 + assert codes.is_error(4000) is True + assert codes.is_error(4999) is True + + # 服务器错误 + assert codes.is_error(5000) is True + assert codes.is_error(5999) is True + + # 命令错误 + assert codes.is_error(6000) is True + assert codes.is_error(6999) is True + + # 非错误状态 + assert codes.is_error(2000) is False + assert codes.is_error(3000) is False + assert codes.is_error(3999) is False + assert codes.is_error(7000) is False + + +def test_lowercase_compatibility(): + """测试小写属性兼容性""" + assert hasattr(codes, "ok") + assert hasattr(codes, "bad_request") + assert hasattr(codes, "internal_server_error") + assert hasattr(codes, "command_error") + + assert codes.ok == 2000 + assert codes.bad_request == 4000 + assert codes.internal_server_error == 5000 + assert codes.command_error == 6000 + + +def test_enum_behavior(): + """测试枚举行为""" + # 测试枚举成员比较 + assert rock.codes.OK == codes.OK + assert rock.codes.BAD_REQUEST != rock.codes.OK + + # 测试枚举成员类型 + assert isinstance(rock.codes.OK, codes) + assert isinstance(rock.codes.OK, int) + + # 测试枚举迭代 + all_codes = list(codes) + assert len(all_codes) == 4 + assert codes.OK in all_codes + assert codes.BAD_REQUEST in all_codes + assert codes.INTERNAL_SERVER_ERROR in all_codes + assert codes.COMMAND_ERROR in all_codes + + +def test_boundary_values(): + """测试边界值""" + # 测试范围边界 + assert codes.is_success(2000) is True + assert codes.is_success(2999) is True + assert codes.is_success(1999) is False + assert codes.is_success(3000) is False + + assert codes.is_client_error(4000) is True + assert codes.is_client_error(4999) is True + assert codes.is_client_error(3999) is False + assert codes.is_client_error(5000) is False + + assert codes.is_server_error(5000) is True + assert codes.is_server_error(5999) is True + assert codes.is_server_error(4999) is False + assert codes.is_server_error(6000) is False + + assert codes.is_command_error(6000) is True + assert codes.is_command_error(6999) is True + assert codes.is_command_error(5999) is False + assert codes.is_command_error(7000) is False