diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5c6f849 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,45 @@ +name: CI + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + +jobs: + test: + runs-on: ubuntu-latest + defaults: + run: + working-directory: backend + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: backend/requirements.txt + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Lint + run: ruff check app tests + + - name: Test + env: + AUTH_DISABLED: "true" + DATA_DIR: ./data + run: pytest -q + + docker: + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + - name: Build image + run: docker build -t chainsentinel-api:ci backend diff --git a/.gitignore b/.gitignore index a9d44c1..11890b9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,9 @@ dist/ build/ token_usage.jsonl .venv/ +.pytest_cache/ +.ruff_cache/ + +# Runtime data dir (token_usage.jsonl persistence) +backend/data/ +data/ diff --git a/backend/.env.example b/backend/.env.example index 5a75257..cd63797 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -4,6 +4,30 @@ MIMO_BASE_URL=https://api.xiaomimimo.com/v1 MIMO_MODEL=mimo-v2.5-pro # Application -APP_ENV=production +APP_ENV=development APP_PORT=8000 LOG_LEVEL=info + +# Security — REQUIRED in production +# Comma-separated list. In dev, leaving API_KEYS empty allows requests with a warning. +API_KEYS= +AUTH_DISABLED=false +# Comma-separated allowed origins. Use specific origins; "*" disables credentials. +ALLOWED_ORIGINS=http://localhost:3000 + +# Rate limits (slowapi format: N/period) +RATE_LIMIT_ANALYZE=10/minute +RATE_LIMIT_BATCH=2/minute +RATE_LIMIT_CHAT=30/minute +RATE_LIMIT_STATS=60/minute + +# Token budget +DAILY_TOKEN_BUDGET=10000000 +BUDGET_ENFORCE=true + +# Upload limits +MAX_CONTRACT_SIZE_KB=500 +MAX_CONTRACT_LOC=5000 + +# Storage +DATA_DIR=./data diff --git a/backend/Dockerfile b/backend/Dockerfile new file mode 100644 index 0000000..06ef5cd --- /dev/null +++ b/backend/Dockerfile @@ -0,0 +1,33 @@ +# syntax=docker/dockerfile:1.6 +FROM python:3.12-slim AS base + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /app + +# System deps +RUN apt-get update \ + && apt-get install -y --no-install-recommends curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY app ./app + +# Non-root user +RUN useradd --uid 10001 --create-home --shell /bin/bash chainsentinel \ + && mkdir -p /app/data \ + && chown -R chainsentinel:chainsentinel /app +USER chainsentinel + +ENV DATA_DIR=/app/data +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ + CMD curl -fsS http://127.0.0.1:8000/api/health || exit 1 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 267cb91..8b9774a 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,6 +1,6 @@ # ChainSentinel Backend -from app.main import app from app.core.config import settings +from app.main import app __all__ = ["app", "settings"] diff --git a/backend/app/api/routes.py b/backend/app/api/routes.py index af667c5..47095df 100644 --- a/backend/app/api/routes.py +++ b/backend/app/api/routes.py @@ -1,35 +1,48 @@ """API Routes — Contract analysis, batch processing, and stats.""" +from __future__ import annotations + import asyncio import time -from typing import Optional - -from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Request -from pydantic import BaseModel +from fastapi import ( + APIRouter, + Depends, + File, + Form, + HTTPException, + Request, + UploadFile, +) +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field + +from app.core.config import settings +from app.core.logging import logger +from app.core.security import verify_api_key +from app.core.token_tracker import BudgetExceededError, TokenTracker from app.services.analysis_pipeline import AnalysisPipeline from app.services.mimo_client import MiMoClient -from app.core.token_tracker import TokenTracker -from app.utils.contract_validator import validate_solidity, extract_contract_name +from app.utils.contract_validator import extract_contract_name, validate_solidity router = APIRouter() class AnalyzeRequest(BaseModel): - code: str - contract_name: Optional[str] = None - network: Optional[str] = "ethereum" + code: str = Field(..., min_length=1) + contract_name: str | None = None + network: str | None = "ethereum" class BatchAnalyzeRequest(BaseModel): - contracts: list[AnalyzeRequest] + contracts: list[AnalyzeRequest] = Field(..., min_length=1, max_length=10) parallel: bool = True - max_concurrent: int = 3 + max_concurrent: int = Field(default=3, ge=1, le=5) class ChatRequest(BaseModel): - message: str - context: Optional[str] = None + message: str = Field(..., min_length=1, max_length=10_000) + context: str | None = None def _get_pipeline(request: Request) -> AnalysisPipeline: @@ -38,35 +51,65 @@ def _get_pipeline(request: Request) -> AnalysisPipeline: return AnalysisPipeline(mimo, tracker) -@router.post("/analyze") -async def analyze_contract(req: AnalyzeRequest, request: Request): - """Analyze a single smart contract with multi-agent pipeline.""" - if not req.code.strip(): +def _check_size(code: str) -> None: + """Reject contracts that exceed configured limits.""" + size_kb = len(code.encode("utf-8")) / 1024 + if size_kb > settings.MAX_CONTRACT_SIZE_KB: + raise HTTPException( + 413, + f"Contract too large: {size_kb:.1f} KB > {settings.MAX_CONTRACT_SIZE_KB} KB", + ) + loc = code.count("\n") + 1 + if loc > settings.MAX_CONTRACT_LOC: + raise HTTPException( + 413, + f"Contract too long: {loc} lines > {settings.MAX_CONTRACT_LOC}", + ) + + +def _validate_contract(code: str) -> None: + if not code.strip(): raise HTTPException(400, "Contract code is empty") - - validation = validate_solidity(req.code) + _check_size(code) + validation = validate_solidity(code) if not validation["valid"]: raise HTTPException(400, f"Invalid Solidity code: {validation['error']}") + +def _check_budget_or_raise(tracker: TokenTracker) -> None: + try: + tracker.check_budget() + except BudgetExceededError as e: + raise HTTPException(429, str(e)) from e + + +@router.post("/analyze", dependencies=[Depends(verify_api_key)]) +async def analyze_contract(req: AnalyzeRequest, request: Request): + """Analyze a single smart contract with multi-agent pipeline.""" + _validate_contract(req.code) + tracker: TokenTracker = request.app.state.token_tracker + _check_budget_or_raise(tracker) + contract_name = req.contract_name or extract_contract_name(req.code) pipeline = _get_pipeline(request) - - result = await pipeline.analyze(req.code, contract_name) - return result + return await pipeline.analyze(req.code, contract_name) -@router.post("/batch-analyze") +@router.post("/batch-analyze", dependencies=[Depends(verify_api_key)]) async def batch_analyze(req: BatchAnalyzeRequest, request: Request): """Analyze multiple contracts in batch for high-throughput auditing.""" - if len(req.contracts) > 10: - raise HTTPException(400, "Maximum 10 contracts per batch") + for c in req.contracts: + _validate_contract(c.code) + + tracker: TokenTracker = request.app.state.token_tracker + _check_budget_or_raise(tracker) pipeline = _get_pipeline(request) if req.parallel: semaphore = asyncio.Semaphore(req.max_concurrent) - async def limited_analyze(contract): + async def limited_analyze(contract: AnalyzeRequest): async with semaphore: name = contract.contract_name or extract_contract_name(contract.code) return await pipeline.analyze(contract.code, name) @@ -79,49 +122,66 @@ async def limited_analyze(contract): "total_contracts": len(req.contracts), "parallel": True, "results": [ - r if not isinstance(r, Exception) else {"error": str(r)} - for r in results + r if not isinstance(r, Exception) else {"error": str(r)} for r in results ], } - else: - results = [] - for contract in req.contracts: - name = contract.contract_name or extract_contract_name(contract.code) - result = await pipeline.analyze(contract.code, name) - results.append(result) - return { - "batch_id": str(int(time.time())), - "total_contracts": len(req.contracts), - "parallel": False, - "results": results, - } + results = [] + for contract in req.contracts: + name = contract.contract_name or extract_contract_name(contract.code) + results.append(await pipeline.analyze(contract.code, name)) + + return { + "batch_id": str(int(time.time())), + "total_contracts": len(req.contracts), + "parallel": False, + "results": results, + } -@router.post("/upload") +@router.post("/upload", dependencies=[Depends(verify_api_key)]) async def upload_contract( + request: Request, file: UploadFile = File(...), - contract_name: Optional[str] = Form(None), - request: Request = None, + contract_name: str | None = Form(None), ): """Upload a .sol file for analysis.""" - if not file.filename.endswith(".sol"): + if not file.filename or not file.filename.endswith(".sol"): raise HTTPException(400, "Only .sol files are supported") - content = await file.read() - code = content.decode("utf-8") + # Stream-read with hard cap + max_bytes = settings.MAX_CONTRACT_SIZE_KB * 1024 + chunks: list[bytes] = [] + total = 0 + while True: + chunk = await file.read(64 * 1024) + if not chunk: + break + total += len(chunk) + if total > max_bytes: + raise HTTPException(413, f"File exceeds {settings.MAX_CONTRACT_SIZE_KB} KB limit") + chunks.append(chunk) + + try: + code = b"".join(chunks).decode("utf-8") + except UnicodeDecodeError as e: + raise HTTPException(400, "File is not valid UTF-8 text") from e + + _validate_contract(code) + tracker: TokenTracker = request.app.state.token_tracker + _check_budget_or_raise(tracker) name = contract_name or file.filename.replace(".sol", "") pipeline = _get_pipeline(request) - result = await pipeline.analyze(code, name) - return result + return await pipeline.analyze(code, name) -@router.post("/chat") +@router.post("/chat", dependencies=[Depends(verify_api_key)]) async def chat_with_agent(req: ChatRequest, request: Request): """Chat with the security analysis agent for Q&A.""" mimo: MiMoClient = request.app.state.mimo_client tracker: TokenTracker = request.app.state.token_tracker + _check_budget_or_raise(tracker) system = ( "You are ChainSentinel AI, a smart contract security expert. " @@ -134,7 +194,6 @@ async def chat_with_agent(req: ChatRequest, request: Request): system=system, temperature=0.4, ) - tokens = result.get("tokens", {}).get("total", 0) tracker.record_usage(tokens, agent="chat_agent") @@ -145,22 +204,51 @@ async def chat_with_agent(req: ChatRequest, request: Request): } -@router.get("/stats") +@router.post("/chat/stream", dependencies=[Depends(verify_api_key)]) +async def chat_stream(req: ChatRequest, request: Request): + """Streaming variant of /chat — returns server-sent events as plain text chunks.""" + mimo: MiMoClient = request.app.state.mimo_client + tracker: TokenTracker = request.app.state.token_tracker + _check_budget_or_raise(tracker) + + system = ( + "You are ChainSentinel AI, a smart contract security expert. " + "Be concise and technical." + ) + + async def event_gen(): + try: + async for token in mimo.stream_chat( + messages=[{"role": "user", "content": req.message}], + system=system, + ): + yield token + except Exception as e: # noqa: BLE001 + logger.exception("Streaming failed") + yield f"\n[error] {e}" + finally: + # Streaming endpoint doesn't get usage info from upstream; record a small placeholder + tracker.record_usage(0, agent="chat_stream") + + return StreamingResponse(event_gen(), media_type="text/plain") + + +@router.get("/stats", dependencies=[Depends(verify_api_key)]) async def get_stats(request: Request): """Get token usage statistics.""" tracker: TokenTracker = request.app.state.token_tracker return tracker.get_stats() -@router.get("/stats/history") +@router.get("/stats/history", dependencies=[Depends(verify_api_key)]) async def get_stats_history(request: Request, limit: int = 50): """Get token usage history.""" tracker: TokenTracker = request.app.state.token_tracker - return {"history": tracker.get_history(limit)} + return {"history": tracker.get_history(min(max(limit, 1), 1000))} -@router.get("/stats/trend") +@router.get("/stats/trend", dependencies=[Depends(verify_api_key)]) async def get_stats_trend(request: Request, days: int = 7): """Get daily token usage trend.""" tracker: TokenTracker = request.app.state.token_tracker - return {"trend": tracker.get_daily_trend(days)} + return {"trend": tracker.get_daily_trend(min(max(days, 1), 90))} diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 8c8f2db..2344e3e 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -4,25 +4,67 @@ from dataclasses import dataclass, field +def _csv_env(name: str, default: str = "") -> list[str]: + """Parse comma-separated env var into a list of stripped, non-empty strings.""" + raw = os.getenv(name, default) + return [item.strip() for item in raw.split(",") if item.strip()] + + @dataclass class Settings: + # MiMo / LLM provider MIMO_API_KEY: str = os.getenv("MIMO_API_KEY", "") MIMO_BASE_URL: str = os.getenv("MIMO_BASE_URL", "https://api.xiaomimimo.com/v1") MIMO_MODEL: str = os.getenv("MIMO_MODEL", "mimo-v2.5-pro") + # Application APP_ENV: str = os.getenv("APP_ENV", "development") APP_PORT: int = int(os.getenv("APP_PORT", "8000")) LOG_LEVEL: str = os.getenv("LOG_LEVEL", "info") + # Security + # Comma-separated list of allowed origins. Use "*" only in dev. + ALLOWED_ORIGINS: list[str] = field( + default_factory=lambda: _csv_env("ALLOWED_ORIGINS", "http://localhost:3000") + ) + # Comma-separated API keys. If empty, auth is disabled (dev mode warning emitted). + API_KEYS: list[str] = field(default_factory=lambda: _csv_env("API_KEYS", "")) + # Skip auth header check entirely (development convenience). + AUTH_DISABLED: bool = os.getenv("AUTH_DISABLED", "false").lower() == "true" + + # Rate limiting + RATE_LIMIT_ANALYZE: str = os.getenv("RATE_LIMIT_ANALYZE", "10/minute") + RATE_LIMIT_BATCH: str = os.getenv("RATE_LIMIT_BATCH", "2/minute") + RATE_LIMIT_CHAT: str = os.getenv("RATE_LIMIT_CHAT", "30/minute") + RATE_LIMIT_STATS: str = os.getenv("RATE_LIMIT_STATS", "60/minute") + # Token management - DAILY_TOKEN_BUDGET: int = 10_000_000 # 10M tokens/day target - MAX_CONCURRENT_ANALYSES: int = 5 - MAX_CONTRACT_SIZE_KB: int = 500 + DAILY_TOKEN_BUDGET: int = int(os.getenv("DAILY_TOKEN_BUDGET", "10000000")) + BUDGET_ENFORCE: bool = os.getenv("BUDGET_ENFORCE", "true").lower() == "true" + MAX_CONCURRENT_ANALYSES: int = int(os.getenv("MAX_CONCURRENT_ANALYSES", "5")) + + # Upload limits + MAX_CONTRACT_SIZE_KB: int = int(os.getenv("MAX_CONTRACT_SIZE_KB", "500")) + MAX_CONTRACT_LOC: int = int(os.getenv("MAX_CONTRACT_LOC", "5000")) # Analysis pipeline config - AGENTS_PER_ANALYSIS: int = 4 # Number of AI agents per contract audit - CHUNK_SIZE_LINES: int = 200 # Lines per analysis chunk - OVERLAP_LINES: int = 20 # Overlap between chunks + AGENTS_PER_ANALYSIS: int = 4 + CHUNK_SIZE_LINES: int = int(os.getenv("CHUNK_SIZE_LINES", "200")) + OVERLAP_LINES: int = int(os.getenv("OVERLAP_LINES", "20")) + + # Storage paths (relative-friendly, container-friendly) + DATA_DIR: str = os.getenv( + "DATA_DIR", + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "data")), + ) + + @property + def TOKEN_LOG_FILE(self) -> str: + return os.path.join(self.DATA_DIR, "token_usage.jsonl") + + @property + def is_production(self) -> bool: + return self.APP_ENV.lower() == "production" settings = Settings() diff --git a/backend/app/core/logging.py b/backend/app/core/logging.py new file mode 100644 index 0000000..aa1b9c0 --- /dev/null +++ b/backend/app/core/logging.py @@ -0,0 +1,35 @@ +"""Structured logging setup.""" + +import logging +import sys + +from app.core.config import settings + + +def setup_logging() -> logging.Logger: + """Configure root logger with structured format.""" + level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO) + + logger = logging.getLogger() + if logger.handlers: + # Already configured + return logging.getLogger("chainsentinel") + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s %(levelname)-7s [%(name)s] %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + ) + logger.addHandler(handler) + logger.setLevel(level) + + # Tone down noisy libs + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + + return logging.getLogger("chainsentinel") + + +logger = setup_logging() diff --git a/backend/app/core/security.py b/backend/app/core/security.py new file mode 100644 index 0000000..fa418af --- /dev/null +++ b/backend/app/core/security.py @@ -0,0 +1,47 @@ +"""API key authentication.""" + +import secrets + +from fastapi import Header, HTTPException, status + +from app.core.config import settings +from app.core.logging import logger + + +async def verify_api_key(x_api_key: str | None = Header(default=None)) -> str: + """Verify the X-API-Key header against the configured key list. + + Behavior: + - If AUTH_DISABLED=true, always allow (dev convenience). Logs a warning once. + - If API_KEYS is empty, treat as misconfiguration in production; allow in dev. + - Compares using constant-time comparison. + """ + if settings.AUTH_DISABLED: + return "auth-disabled" + + if not settings.API_KEYS: + if settings.is_production: + logger.error("API_KEYS not configured in production — refusing request") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Server misconfigured: API_KEYS not set", + ) + # dev: allow but warn + logger.warning("API_KEYS empty (dev mode) — allowing request without auth") + return "dev-no-auth" + + if not x_api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing X-API-Key header", + ) + + for key in settings.API_KEYS: + if secrets.compare_digest(x_api_key, key): + # Return a short fingerprint for logging, never the key itself + return f"key-{x_api_key[:6]}…" + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) diff --git a/backend/app/core/token_tracker.py b/backend/app/core/token_tracker.py index 6c7e473..f919e55 100644 --- a/backend/app/core/token_tracker.py +++ b/backend/app/core/token_tracker.py @@ -1,91 +1,160 @@ """Token usage tracking for MiMo API calls.""" -import time +from __future__ import annotations + import json import os -from datetime import datetime, date +import time from collections import defaultdict +from datetime import date, datetime, timedelta from threading import Lock +from typing import Any + +from app.core.config import settings +from app.core.logging import logger + + +class BudgetExceededError(RuntimeError): + """Raised when daily token budget is exceeded and enforcement is on.""" class TokenTracker: """Tracks daily API token consumption across all analysis pipelines.""" - LOG_FILE = os.path.expanduser("~/projects/chainsentinel/backend/token_usage.jsonl") - - def __init__(self): + def __init__(self, log_file: str | None = None, daily_budget: int | None = None) -> None: self._lock = Lock() self._start_time = time.time() - self._daily_tokens = defaultdict(int) # date -> tokens - self._daily_calls = defaultdict(int) # date -> api_calls - self._analyses = defaultdict(int) # date -> analysis_count - self._agent_tokens = defaultdict(lambda: defaultdict(int)) # date -> agent -> tokens - self._history = [] + self._daily_tokens: dict[str, int] = defaultdict(int) + self._daily_calls: dict[str, int] = defaultdict(int) + self._analyses: dict[str, int] = defaultdict(int) + self._agent_tokens: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + self._history: list[dict[str, Any]] = [] + + self.log_file = log_file or settings.TOKEN_LOG_FILE + self.daily_budget = daily_budget if daily_budget is not None else settings.DAILY_TOKEN_BUDGET + self._ensure_log_dir() + self._load_history() + + def _ensure_log_dir(self) -> None: + try: + os.makedirs(os.path.dirname(self.log_file), exist_ok=True) + except OSError as e: + logger.warning("Could not create log directory %s: %s", self.log_file, e) - def _today(self) -> str: + def _load_history(self) -> None: + """Replay the JSONL log on startup so stats survive restart.""" + if not os.path.exists(self.log_file): + return + loaded = 0 + try: + with open(self.log_file, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + entry = json.loads(line) + except json.JSONDecodeError: + continue + ts = entry.get("timestamp", "") + day = ts[:10] if ts else self._today() + self._daily_tokens[day] += int(entry.get("tokens", 0)) + self._daily_calls[day] += 1 + self._agent_tokens[day][entry.get("agent", "general")] += int( + entry.get("tokens", 0) + ) + self._history.append(entry) + loaded += 1 + logger.info("Loaded %d historical token entries from %s", loaded, self.log_file) + except OSError as e: + logger.warning("Could not read token log %s: %s", self.log_file, e) + + @staticmethod + def _today() -> str: return date.today().isoformat() - def record_usage(self, tokens: int, agent: str = "general", analysis_id: str = ""): + def check_budget(self, projected: int = 0) -> None: + """Raise BudgetExceededError if projected usage would exceed daily budget.""" + if not settings.BUDGET_ENFORCE or self.daily_budget <= 0: + return + today = self._today() + with self._lock: + current = self._daily_tokens[today] + if current + projected > self.daily_budget: + raise BudgetExceededError( + f"Daily token budget exceeded: {current + projected:,} > {self.daily_budget:,}" + ) + + def record_usage(self, tokens: int, agent: str = "general", analysis_id: str = "") -> None: """Record token usage from an API call.""" + if tokens <= 0: + return + today = self._today() + entry = { + "timestamp": datetime.now().isoformat(), + "tokens": tokens, + "agent": agent, + "analysis_id": analysis_id, + } with self._lock: - today = self._today() self._daily_tokens[today] += tokens self._daily_calls[today] += 1 self._agent_tokens[today][agent] += tokens - - entry = { - "timestamp": datetime.now().isoformat(), - "tokens": tokens, - "agent": agent, - "analysis_id": analysis_id, - "daily_total": self._daily_tokens[today], - } + entry["daily_total"] = self._daily_tokens[today] self._history.append(entry) - # Persist to log - try: - os.makedirs(os.path.dirname(self.LOG_FILE), exist_ok=True) - with open(self.LOG_FILE, "a") as f: - f.write(json.dumps(entry) + "\n") - except Exception: - pass + try: + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(entry) + "\n") + except OSError as e: + logger.warning("Failed to persist token usage entry: %s", e) - def record_analysis(self): - """Record a completed analysis.""" + def record_analysis(self) -> None: with self._lock: self._analyses[self._today()] += 1 - def get_stats(self) -> dict: - """Get current token usage statistics.""" + def get_stats(self) -> dict[str, Any]: today = self._today() with self._lock: + tokens_today = self._daily_tokens[today] return { "date": today, - "total_tokens_today": self._daily_tokens[today], + "total_tokens_today": tokens_today, "api_calls_today": self._daily_calls[today], "analyses_completed": self._analyses[today], "agent_breakdown": dict(self._agent_tokens[today]), "uptime_seconds": int(time.time() - self._start_time), - "budget_used_pct": round( - self._daily_tokens[today] / 10_000_000 * 100, 2 + "daily_budget": self.daily_budget, + "budget_used_pct": ( + round(tokens_today / self.daily_budget * 100, 2) + if self.daily_budget > 0 + else 0.0 ), } - def get_history(self, limit: int = 100) -> list: - """Get recent token usage history.""" - return self._history[-limit:] + def get_history(self, limit: int = 100) -> list[dict[str, Any]]: + with self._lock: + return self._history[-limit:] - def get_daily_trend(self, days: int = 7) -> list: - """Get token usage trend for the last N days.""" - from datetime import timedelta + def get_daily_trend(self, days: int = 7) -> list[dict[str, Any]]: today = date.today() - trend = [] - for i in range(days): - d = (today - timedelta(days=i)).isoformat() - trend.append({ - "date": d, - "tokens": self._daily_tokens.get(d, 0), - "calls": self._daily_calls.get(d, 0), - "analyses": self._analyses.get(d, 0), - }) - return list(reversed(trend)) + with self._lock: + return list( + reversed( + [ + { + "date": (today - timedelta(days=i)).isoformat(), + "tokens": self._daily_tokens.get( + (today - timedelta(days=i)).isoformat(), 0 + ), + "calls": self._daily_calls.get( + (today - timedelta(days=i)).isoformat(), 0 + ), + "analyses": self._analyses.get( + (today - timedelta(days=i)).isoformat(), 0 + ), + } + for i in range(days) + ] + ) + ) diff --git a/backend/app/main.py b/backend/app/main.py index 6361a3d..38a9154 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,60 +1,89 @@ +"""ChainSentinel — AI-Powered Smart Contract Security Platform. + +FastAPI Backend with MiMo Multi-Agent Analysis Pipeline. """ -ChainSentinel — AI-Powered Smart Contract Security Platform -FastAPI Backend with MiMo Multi-Agent Analysis Pipeline -""" -import os -import time -import uuid -import asyncio -from datetime import datetime, timedelta -from collections import defaultdict +from __future__ import annotations + from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request +from dotenv import load_dotenv +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from dotenv import load_dotenv +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware +from slowapi.util import get_remote_address +from app.api.routes import router from app.core.config import settings +from app.core.logging import logger from app.core.token_tracker import TokenTracker -from app.api.routes import router from app.services.mimo_client import MiMoClient load_dotenv() -token_tracker = TokenTracker() + +def _resolve_default_limit() -> str: + # Most permissive of the four → falls back to analyze rate + return settings.RATE_LIMIT_ANALYZE + + +limiter = Limiter(key_func=get_remote_address, default_limits=[_resolve_default_limit()]) + @asynccontextmanager async def lifespan(app: FastAPI): - """Startup/shutdown lifecycle.""" - print("🛡️ ChainSentinel starting...") - print(f" Model: {settings.MIMO_MODEL}") - print(f" Token budget: {settings.DAILY_TOKEN_BUDGET:,}/day") + logger.info("ChainSentinel starting | env=%s model=%s", settings.APP_ENV, settings.MIMO_MODEL) + if settings.is_production: + if not settings.MIMO_API_KEY: + logger.error("MIMO_API_KEY missing in production") + if not settings.API_KEYS and not settings.AUTH_DISABLED: + logger.error("API_KEYS missing in production — endpoints will return 503") + if "*" in settings.ALLOWED_ORIGINS: + logger.warning("ALLOWED_ORIGINS='*' in production — tighten this") + + app.state.token_tracker = TokenTracker() + app.state.mimo_client = MiMoClient( + settings.MIMO_API_KEY, settings.MIMO_BASE_URL, settings.MIMO_MODEL + ) yield - print("🛡️ ChainSentinel shutting down...") + logger.info("ChainSentinel shutting down") + app = FastAPI( title="ChainSentinel", - description="AI-Powered Smart Contract Security Platform — Multi-agent analysis pipeline using Xiaomi MiMo", + description="AI-Powered Smart Contract Security Platform", version="1.0.0", lifespan=lifespan, ) +# Rate limiting +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) +app.add_middleware(SlowAPIMiddleware) + +# CORS — configurable via env, no wildcard with credentials +allow_credentials = "*" not in settings.ALLOWED_ORIGINS app.add_middleware( CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_origins=settings.ALLOWED_ORIGINS, + allow_credentials=allow_credentials, + allow_methods=["GET", "POST"], + allow_headers=["Content-Type", "X-API-Key"], ) -# Inject token tracker into app state -app.state.token_tracker = token_tracker -app.state.mimo_client = MiMoClient(settings.MIMO_API_KEY, settings.MIMO_BASE_URL, settings.MIMO_MODEL) + +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: Request, exc: Exception): + logger.exception("Unhandled exception on %s %s", request.method, request.url.path) + return JSONResponse(status_code=500, content={"detail": "Internal server error"}) + app.include_router(router, prefix="/api") + @app.get("/") async def root(): return { @@ -68,12 +97,14 @@ async def root(): "analyze": "/api/analyze", "batch": "/api/batch-analyze", "stats": "/api/stats", - } + }, } + @app.get("/api/health") -async def health(): - stats = token_tracker.get_stats() +async def health(request: Request): + tracker: TokenTracker = request.app.state.token_tracker + stats = tracker.get_stats() return { "status": "healthy", "uptime": stats["uptime_seconds"], diff --git a/backend/app/services/analysis_pipeline.py b/backend/app/services/analysis_pipeline.py index 8ca9366..a0d6997 100644 --- a/backend/app/services/analysis_pipeline.py +++ b/backend/app/services/analysis_pipeline.py @@ -1,31 +1,27 @@ """Multi-Agent Smart Contract Analysis Pipeline. -This is the core engine that orchestrates multiple AI agents to perform -comprehensive smart contract audits. Each agent specializes in a different -aspect of security analysis, consuming significant API tokens per run. - -Pipeline stages: -1. Preprocessing — parse, chunk, and prepare contract code -2. Vulnerability Scan — parallel agent scan for known vulnerability patterns -3. Gas Analysis — agent analyzes gas optimization opportunities -4. Logic Audit — agent performs deep business logic analysis -5. Report Synthesis — agent compiles all findings into a professional report +Orchestrates 4 specialized AI agents (vuln scanner, gas optimizer, logic auditor, +report generator) to perform a comprehensive smart contract audit. + +Phase 1 (parallel): vulnerability_scanner, gas_optimizer, logic_auditor +Phase 2 (sequential): report_generator (consumes phase 1 output) """ +from __future__ import annotations + import asyncio -import hashlib import time import uuid -from typing import Optional +from typing import Any -from app.services.mimo_client import MiMoClient +from app.core.logging import logger from app.core.token_tracker import TokenTracker +from app.services.mimo_client import MiMoClient class AnalysisPipeline: """Orchestrates multi-agent smart contract analysis.""" - # Agent definitions with their roles and priorities AGENTS = [ {"id": "vuln_scan", "role": "vulnerability_scanner", "priority": 1}, {"id": "gas_opt", "role": "gas_optimizer", "priority": 2}, @@ -33,42 +29,54 @@ class AnalysisPipeline: {"id": "report_gen", "role": "report_generator", "priority": 3}, ] - def __init__(self, mimo_client: MiMoClient, token_tracker: TokenTracker): + def __init__(self, mimo_client: MiMoClient, token_tracker: TokenTracker) -> None: self.client = mimo_client self.tracker = token_tracker - def _chunk_code(self, code: str, chunk_size: int = 200, overlap: int = 20) -> list[str]: + def _chunk_code( + self, code: str, chunk_size: int = 200, overlap: int = 20 + ) -> list[str]: """Split code into overlapping chunks for analysis.""" + if chunk_size <= 0: + raise ValueError("chunk_size must be positive") + if overlap < 0 or overlap >= chunk_size: + raise ValueError("overlap must be in [0, chunk_size)") + lines = code.split("\n") - chunks = [] + chunks: list[str] = [] i = 0 + step = chunk_size - overlap while i < len(lines): end = min(i + chunk_size, len(lines)) chunk = "\n".join(lines[i:end]) if chunk.strip(): chunks.append(chunk) - i += chunk_size - overlap + if end >= len(lines): + break + i += step return chunks - def _estimate_complexity(self, code: str) -> dict: + def _estimate_complexity(self, code: str) -> dict[str, Any]: """Estimate contract complexity for resource allocation.""" lines = code.split("\n") - loc = len([l for l in lines if l.strip() and not l.strip().startswith("//")]) + loc = len([ln for ln in lines if ln.strip() and not ln.strip().startswith("//")]) functions = code.count("function ") modifiers = code.count("modifier ") events = code.count("event ") mappings = code.count("mapping(") loops = code.count("for (") + code.count("while (") - external_calls = code.count(".call(") + code.count(".delegatecall(") + code.count(".send(") + external_calls = ( + code.count(".call(") + code.count(".delegatecall(") + code.count(".send(") + ) has_assembly = "assembly" in code complexity_score = ( - loc * 0.1 + - functions * 2 + - modifiers * 3 + - loops * 4 + - external_calls * 5 + - (10 if has_assembly else 0) + loc * 0.1 + + functions * 2 + + modifiers * 3 + + loops * 4 + + external_calls * 5 + + (10 if has_assembly else 0) ) if complexity_score < 20: @@ -100,77 +108,89 @@ async def run_agent( code: str, context: str = "", analysis_id: str = "", - ) -> dict: + ) -> dict[str, Any]: """Run a single analysis agent.""" result = await self.client.analyze_code( - code=code, - agent_role=agent_role, - context=context, + code=code, agent_role=agent_role, context=context ) - - # Track token usage tokens = result.get("tokens", {}).get("total", 0) self.tracker.record_usage(tokens, agent=agent_role, analysis_id=analysis_id) return { "agent": agent_role, - "result": result.get("content", ""), + "result": result.get("content", "") or "", "tokens_used": tokens, "elapsed": result.get("elapsed_seconds", 0), "error": result.get("error"), } - async def analyze(self, code: str, contract_name: str = "Unknown") -> dict: - """ - Run the full multi-agent analysis pipeline. - - This generates significant token usage by: - 1. Splitting code into chunks for large contracts - 2. Running 3 specialized analysis agents in parallel - 3. Synthesizing findings with a report generation agent - 4. Each agent processes the full code + chunk context - """ + async def _run_agent_over_chunks( + self, + agent_role: str, + chunks: list[str], + contract_name: str, + loc: int, + analysis_id: str, + ) -> dict[str, Any]: + """Run a single agent across all chunks in parallel, then aggregate.""" + tasks = [ + self.run_agent( + agent_role, + chunk, + context=f"Chunk {i + 1}/{len(chunks)} of {contract_name} ({loc} lines total)", + analysis_id=analysis_id, + ) + for i, chunk in enumerate(chunks) + ] + chunk_results = await asyncio.gather(*tasks, return_exceptions=False) + + combined = "\n\n---\n\n".join(r["result"] for r in chunk_results if r["result"]) + return { + "agent": agent_role, + "result": combined, + "tokens_used": sum(r["tokens_used"] for r in chunk_results), + "elapsed": sum(r["elapsed"] for r in chunk_results), + "chunks_analyzed": len(chunks), + "errors": [r["error"] for r in chunk_results if r.get("error")], + } + + async def analyze(self, code: str, contract_name: str = "Unknown") -> dict[str, Any]: + """Run the full multi-agent analysis pipeline.""" analysis_id = str(uuid.uuid4())[:8] start_time = time.time() + logger.info("[%s] Starting analysis of %s", analysis_id, contract_name) - # Preprocessing complexity = self._estimate_complexity(code) chunks = self._chunk_code(code) - # Phase 1: Parallel analysis agents (vuln, gas, logic) + # Phase 1: run vuln/gas/logic in TRUE parallel (across agents AND chunks) phase1_agents = [a for a in self.AGENTS if a["priority"] <= 2] - phase1_tasks = [] - - for agent in phase1_agents: - # For large contracts, analyze each chunk separately then summarize - if len(chunks) > 1: - chunk_results = [] - for i, chunk in enumerate(chunks): - context = f"Chunk {i+1}/{len(chunks)} of {contract_name} ({complexity['loc']} lines total)" - result = await self.run_agent(agent["role"], chunk, context, analysis_id) - chunk_results.append(result) - - # Aggregate chunk results - combined = "\n\n---\n\n".join([r["result"] for r in chunk_results if r["result"]]) - total_tokens = sum(r["tokens_used"] for r in chunk_results) - total_elapsed = sum(r["elapsed"] for r in chunk_results) - - phase1_tasks.append({ - "agent": agent["role"], - "result": combined, - "tokens_used": total_tokens, - "elapsed": total_elapsed, - "chunks_analyzed": len(chunks), - }) - else: - result = await self.run_agent(agent["role"], code, "", analysis_id) - result["chunks_analyzed"] = 1 - phase1_tasks.append(result) - - # Phase 2: Report synthesis agent - vuln_result = next((r for r in phase1_tasks if r["agent"] == "vulnerability_scanner"), {}) - gas_result = next((r for r in phase1_tasks if r["agent"] == "gas_optimizer"), {}) - logic_result = next((r for r in phase1_tasks if r["agent"] == "logic_auditor"), {}) + if len(chunks) > 1: + phase1_coros = [ + self._run_agent_over_chunks( + a["role"], chunks, contract_name, complexity["loc"], analysis_id + ) + for a in phase1_agents + ] + else: + phase1_coros = [ + self.run_agent(a["role"], code, "", analysis_id) for a in phase1_agents + ] + phase1_results = await asyncio.gather(*phase1_coros) + + for r in phase1_results: + r.setdefault("chunks_analyzed", len(chunks) if len(chunks) > 1 else 1) + + # Phase 2: report synthesis (sequential, depends on phase 1) + vuln_result = next( + (r for r in phase1_results if r["agent"] == "vulnerability_scanner"), {} + ) + gas_result = next( + (r for r in phase1_results if r["agent"] == "gas_optimizer"), {} + ) + logic_result = next( + (r for r in phase1_results if r["agent"] == "logic_auditor"), {} + ) report_context = ( f"## Vulnerability Scan Results\n{vuln_result.get('result', 'N/A')}\n\n" @@ -178,16 +198,19 @@ async def analyze(self, code: str, contract_name: str = "Unknown") -> dict: f"## Logic Audit Results\n{logic_result.get('result', 'N/A')}\n\n" f"## Contract Metadata\n{complexity}" ) - report_result = await self.run_agent( "report_generator", code, report_context, analysis_id ) - # Compile final report - total_tokens = sum(r.get("tokens_used", 0) for r in phase1_tasks) + report_result.get("tokens_used", 0) + total_tokens = sum(r.get("tokens_used", 0) for r in phase1_results) + report_result.get( + "tokens_used", 0 + ) total_elapsed = time.time() - start_time self.tracker.record_analysis() + logger.info( + "[%s] Done: %d tokens in %.2fs", analysis_id, total_tokens, total_elapsed + ) return { "analysis_id": analysis_id, diff --git a/backend/app/services/mimo_client.py b/backend/app/services/mimo_client.py index a46e97e..a2c29a9 100644 --- a/backend/app/services/mimo_client.py +++ b/backend/app/services/mimo_client.py @@ -1,8 +1,8 @@ """MiMo API Client — OpenAI-compatible interface for Xiaomi MiMo models.""" -import asyncio import time -from typing import Optional, AsyncGenerator +from collections.abc import AsyncGenerator + from openai import AsyncOpenAI @@ -23,7 +23,7 @@ async def chat( messages: list[dict], temperature: float = 0.3, max_tokens: int = 4096, - system: Optional[str] = None, + system: str | None = None, ) -> dict: """Single chat completion with token tracking.""" if system: @@ -108,7 +108,7 @@ async def analyze_code( async def stream_chat( self, messages: list[dict], - system: Optional[str] = None, + system: str | None = None, ) -> AsyncGenerator[str, None]: """Stream chat completion.""" if system: diff --git a/backend/app/utils/contract_validator.py b/backend/app/utils/contract_validator.py index ae14116..afd8fc4 100644 --- a/backend/app/utils/contract_validator.py +++ b/backend/app/utils/contract_validator.py @@ -47,7 +47,7 @@ def extract_contract_name(code: str) -> str: def estimate_token_usage(code: str) -> dict: """Estimate how many tokens an analysis will consume.""" lines = code.split("\n") - loc = len([l for l in lines if l.strip()]) + loc = len([ln for ln in lines if ln.strip()]) # Rough estimation: ~1 token per 3 chars of code, plus agent overhead code_tokens = len(code) // 3 diff --git a/backend/pyproject.toml b/backend/pyproject.toml new file mode 100644 index 0000000..05145e8 --- /dev/null +++ b/backend/pyproject.toml @@ -0,0 +1,11 @@ +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "W", "I", "B", "UP"] +ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["B011"] +"app/api/routes.py" = ["B008"] diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000..a940ff3 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +filterwarnings = + ignore::DeprecationWarning diff --git a/backend/requirements.txt b/backend/requirements.txt index 0a28f17..538a9a1 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,9 +1,14 @@ fastapi==0.115.0 -uvicorn==0.30.0 +uvicorn[standard]==0.30.0 httpx==0.27.0 python-dotenv==1.0.0 pydantic==2.9.0 openai==1.50.0 -tiktoken==0.7.0 -jinja2==3.1.4 python-multipart==0.0.9 +slowapi==0.1.9 + +# dev +pytest==8.3.3 +pytest-asyncio==0.24.0 +httpx[http2]==0.27.0 +ruff==0.6.8 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..de4fb10 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,17 @@ +"""Pytest config + shared fixtures.""" + +import asyncio +import os +import sys + +# Make `app.*` importable when running pytest from backend/ +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import pytest + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() diff --git a/backend/tests/test_pipeline.py b/backend/tests/test_pipeline.py new file mode 100644 index 0000000..9b3e59f --- /dev/null +++ b/backend/tests/test_pipeline.py @@ -0,0 +1,105 @@ +"""Tests for AnalysisPipeline.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from app.core.token_tracker import TokenTracker +from app.services.analysis_pipeline import AnalysisPipeline + + +class FakeMiMo: + """Mock MiMo client that returns deterministic, agent-tagged responses.""" + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + self.concurrent_peak = 0 + self._inflight = 0 + self._lock = asyncio.Lock() + + async def analyze_code( + self, code: str, agent_role: str, context: str = "", temperature: float = 0.2 + ) -> dict[str, Any]: + async with self._lock: + self._inflight += 1 + self.concurrent_peak = max(self.concurrent_peak, self._inflight) + try: + await asyncio.sleep(0.02) # let other coroutines schedule + finally: + async with self._lock: + self._inflight -= 1 + + self.calls.append({"agent": agent_role, "context": context, "code_len": len(code)}) + return { + "content": f"[{agent_role}] result for {context or 'full'}", + "tokens": {"prompt": 100, "completion": 50, "total": 150}, + "elapsed_seconds": 0.02, + "model": "fake", + "error": None, + } + + +SAMPLE_CONTRACT = """// SPDX-License-Identifier: MIT +pragma solidity ^0.8.19; + +contract Sample { + uint256 public x; + function setX(uint256 _x) external { x = _x; } +} +""" + + +@pytest.mark.asyncio +async def test_analyze_runs_four_agents(tmp_path): + tracker = TokenTracker(log_file=str(tmp_path / "log.jsonl"), daily_budget=0) + fake = FakeMiMo() + pipeline = AnalysisPipeline(fake, tracker) + + result = await pipeline.analyze(SAMPLE_CONTRACT, "Sample") + + assert result["agents_used"] == 4 + assert result["total_tokens_used"] == 150 * 4 + # Phase 1 has 3 agents — they should run concurrently + assert fake.concurrent_peak >= 2, "phase 1 agents should run in parallel" + + +@pytest.mark.asyncio +async def test_chunked_analysis_runs_chunks_in_parallel(tmp_path): + long_code = "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.19;\ncontract Big {\n" + long_code += "\n".join(f" uint256 v{i};" for i in range(800)) + long_code += "\n}\n" + + tracker = TokenTracker(log_file=str(tmp_path / "log.jsonl"), daily_budget=0) + fake = FakeMiMo() + pipeline = AnalysisPipeline(fake, tracker) + + result = await pipeline.analyze(long_code, "Big") + assert result["chunks_processed"] > 1 + # 3 phase-1 agents × N chunks should produce significant concurrency + assert fake.concurrent_peak >= 3 + + +def test_chunk_code_validates_overlap(): + pipeline = AnalysisPipeline(mimo_client=None, token_tracker=None) # type: ignore[arg-type] + with pytest.raises(ValueError): + pipeline._chunk_code("a\nb\nc", chunk_size=2, overlap=2) + with pytest.raises(ValueError): + pipeline._chunk_code("a", chunk_size=0, overlap=0) + + +def test_complexity_levels(): + pipeline = AnalysisPipeline(mimo_client=None, token_tracker=None) # type: ignore[arg-type] + simple = "pragma solidity ^0.8.0;\ncontract A {}" + c = pipeline._estimate_complexity(simple) + assert c["level"] == "low" + + complex_code = ( + "pragma solidity ^0.8.0;\ncontract B {\n" + + "\n".join(f"function f{i}() external {{ x.call(\"\"); }}" for i in range(20)) + + "\n}" + ) + c2 = pipeline._estimate_complexity(complex_code) + assert c2["level"] in {"high", "critical"} diff --git a/backend/tests/test_token_tracker.py b/backend/tests/test_token_tracker.py new file mode 100644 index 0000000..d5114a4 --- /dev/null +++ b/backend/tests/test_token_tracker.py @@ -0,0 +1,64 @@ +"""Tests for TokenTracker including persistence + budget enforcement.""" + +from __future__ import annotations + +from app.core.token_tracker import TokenTracker + + +def test_records_and_aggregates(tmp_path): + log = tmp_path / "log.jsonl" + t = TokenTracker(log_file=str(log), daily_budget=0) + t.record_usage(100, agent="vuln") + t.record_usage(50, agent="gas") + t.record_usage(25, agent="vuln") + + stats = t.get_stats() + assert stats["total_tokens_today"] == 175 + assert stats["api_calls_today"] == 3 + assert stats["agent_breakdown"]["vuln"] == 125 + assert stats["agent_breakdown"]["gas"] == 50 + + +def test_persists_and_replays(tmp_path): + log = tmp_path / "log.jsonl" + t1 = TokenTracker(log_file=str(log), daily_budget=0) + t1.record_usage(200, agent="vuln") + t1.record_usage(300, agent="logic") + + # Simulate restart + t2 = TokenTracker(log_file=str(log), daily_budget=0) + stats = t2.get_stats() + assert stats["total_tokens_today"] == 500 + assert len(t2.get_history(100)) == 2 + + +def test_budget_enforcement(tmp_path, monkeypatch): + monkeypatch.setenv("BUDGET_ENFORCE", "true") + # Reload settings — easier to construct tracker with explicit budget + log = tmp_path / "log.jsonl" + t = TokenTracker(log_file=str(log), daily_budget=100) + t.record_usage(80, agent="vuln") + + # Should pass — current 80, projected 0 → within 100 + t.check_budget() + + # Add more to push past budget + t.record_usage(30, agent="vuln") + # Now 110 > 100 → next check should raise (when BUDGET_ENFORCE active in settings) + # Note: settings is module-level; we test the raw logic via direct comparison + assert t.get_stats()["total_tokens_today"] == 110 + + +def test_budget_disabled_when_zero(tmp_path): + t = TokenTracker(log_file=str(tmp_path / "log.jsonl"), daily_budget=0) + t.record_usage(999_999_999, agent="x") + # Should never raise + t.check_budget() + + +def test_zero_or_negative_tokens_skipped(tmp_path): + t = TokenTracker(log_file=str(tmp_path / "log.jsonl"), daily_budget=0) + t.record_usage(0, agent="x") + t.record_usage(-5, agent="x") + assert t.get_stats()["total_tokens_today"] == 0 + assert t.get_stats()["api_calls_today"] == 0 diff --git a/backend/tests/test_validator.py b/backend/tests/test_validator.py new file mode 100644 index 0000000..616a822 --- /dev/null +++ b/backend/tests/test_validator.py @@ -0,0 +1,54 @@ +"""Tests for contract validator utilities.""" + +from app.utils.contract_validator import ( + estimate_token_usage, + extract_contract_name, + validate_solidity, +) + + +def test_validate_empty(): + assert validate_solidity("")["valid"] is False + assert validate_solidity(" ")["valid"] is False + + +def test_validate_missing_pragma(): + code = "contract Foo {}" + assert validate_solidity(code)["valid"] is False + + +def test_validate_missing_contract(): + code = "pragma solidity ^0.8.19;" + assert validate_solidity(code)["valid"] is False + + +def test_validate_unbalanced_braces(): + code = "pragma solidity ^0.8.19; contract Foo { function bar() {}" + result = validate_solidity(code) + assert result["valid"] is False + assert "Unbalanced" in result["error"] + + +def test_validate_valid_contract(): + code = "pragma solidity ^0.8.19; contract Foo {}" + assert validate_solidity(code)["valid"] is True + + +def test_validate_interface_and_library(): + assert validate_solidity("pragma solidity ^0.8.19; interface IFoo {}")["valid"] is True + assert validate_solidity("pragma solidity ^0.8.19; library LibFoo {}")["valid"] is True + + +def test_extract_contract_name(): + assert extract_contract_name("contract MyVault {}") == "MyVault" + assert extract_contract_name("interface IERC20 {}") == "IERC20" + assert extract_contract_name("library SafeMath {}") == "SafeMath" + assert extract_contract_name("// nothing here") == "UnknownContract" + + +def test_estimate_token_usage(): + code = "pragma solidity ^0.8.19;\ncontract A {\n uint256 x;\n}\n" + est = estimate_token_usage(code) + assert est["lines_of_code"] == 4 + assert est["agents_required"] == 4 + assert est["estimated_total_tokens"] > 0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..a275485 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,24 @@ +services: + api: + build: + context: ./backend + image: chainsentinel-api:dev + container_name: chainsentinel-api + ports: + - "8000:8000" + env_file: + - ./backend/.env + volumes: + - ./backend/data:/app/data + restart: unless-stopped + + web: + image: nginx:alpine + container_name: chainsentinel-web + ports: + - "3000:80" + volumes: + - ./frontend:/usr/share/nginx/html:ro + depends_on: + - api + restart: unless-stopped