From 410481e92c8ee3b6bba16942132278eca1d065ac Mon Sep 17 00:00:00 2001 From: Divyanshi Awasthi Date: Mon, 8 Jun 2026 19:04:05 +0530 Subject: [PATCH] fix: add AI usage quotas and cost tracking --- README.md | 18 +- backend/app/main.py | 4 + backend/app/models.py | 43 ++- backend/app/routers/analyze.py | 32 +- backend/app/routers/quotas.py | 67 +++++ backend/app/routers/usage.py | 53 ++++ backend/app/schemas.py | 79 +++++ backend/app/security.py | 16 + backend/app/services/usage.py | 300 +++++++++++++++++++ backend/tests/test_usage_quotas.py | 462 +++++++++++++++++++++++++++++ 10 files changed, 1070 insertions(+), 4 deletions(-) create mode 100644 backend/app/routers/quotas.py create mode 100644 backend/app/routers/usage.py create mode 100644 backend/app/services/usage.py create mode 100644 backend/tests/test_usage_quotas.py diff --git a/README.md b/README.md index 720fb8ed..9caccac4 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ No account required. No API key needed. Works fully offline. Fully open source. 
| **Download Results** | Export full report as `.txt` | | **LLM-Ready** | Plug in OpenAI, Groq, Ollama, or any OpenAI-compatible provider via env vars | | **Rate Limiting** | 30 requests/minute per IP - configurable | +| **Usage Quotas** | Track estimated AI usage costs and enforce per-user or per-team quotas | | **Swagger Docs** | Interactive API docs at `/docs` | | **Gzip Compression** | Automatic response compression | @@ -252,6 +253,21 @@ Create a share link for a saved analysis, then load it back by ID for seven days --- +### Usage and Quotas + +Authenticated users can track estimated provider usage and configure quotas: + +| Endpoint | Detail | +|---|---| +| `GET /usage/summary` | Request, token, cost, and alert summary for the authenticated user or `team_id` | +| `GET /usage/costs` | Estimated cost breakdown grouped by provider and model | +| `POST /quotas` | Create or update user, team, or global quota limits | +| `GET /quotas` | Return the applicable quota configuration | + +When a configured quota would be exceeded, `/analyze/` returns `429` before running the analysis. 
def _utc_now() -> datetime:
    """Return the current timezone-aware UTC timestamp (column default)."""
    return datetime.now(UTC)


class UsageLog(Base):
    """One row per request, holding its estimated AI provider usage."""

    __tablename__ = "usage_logs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)

    # Scope of the request: either, both, or neither may be set.
    user_id: Mapped[int | None] = mapped_column(
        ForeignKey("users.id"),
        nullable=True,
        index=True,
    )
    team_id: Mapped[str | None] = mapped_column(
        String(120),
        nullable=True,
        index=True,
    )

    # What was called and which provider/model the estimate refers to.
    endpoint: Mapped[str] = mapped_column(String(80), index=True)
    provider: Mapped[str] = mapped_column(String(80), index=True)
    model: Mapped[str] = mapped_column(String(120), index=True)

    # Estimated token counts and cost.
    prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
    completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
    estimated_cost_usd: Mapped[float] = mapped_column(Float, default=0.0)

    # Indexed because usage aggregation filters on the period start.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utc_now,
        index=True,
    )


class QuotaConfig(Base):
    """Usage limits that apply to a user, a team, or (both NULL) globally."""

    __tablename__ = "quota_configs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)

    # Both NULL means the row is the global quota.
    user_id: Mapped[int | None] = mapped_column(
        ForeignKey("users.id"),
        nullable=True,
        index=True,
    )
    team_id: Mapped[str | None] = mapped_column(
        String(120),
        nullable=True,
        index=True,
    )

    period: Mapped[str] = mapped_column(String(20), default="monthly")

    # A NULL limit means that metric is not limited.
    max_requests: Mapped[int | None] = mapped_column(Integer, nullable=True)
    max_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    max_cost_usd: Mapped[float | None] = mapped_column(Float, nullable=True)

    # Comma-separated fractions, e.g. "0.8,1.0".
    alert_thresholds: Mapped[str] = mapped_column(String(120), default="0.8,1.0")

    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utc_now)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utc_now,
        onupdate=_utc_now,
    )
StreamingResponse +from sqlalchemy.orm import Session +from ..database import get_db +from ..models import User from ..schemas import AnalyzeResponse, CodeRequest, ZipAnalyzeResponse +from ..security import get_optional_user from ..services.cache import cache from ..services.code_assistant import ( detect_language, @@ -20,6 +34,7 @@ run_explanation, run_suggestions, ) +from ..services.usage import enforce_quota, estimate_usage, log_usage from ..sanitize import sanitize_code_input, sanitize_language_hint router = APIRouter() @@ -192,15 +207,28 @@ async def analyze_stream_get( response_model=AnalyzeResponse, summary="Run full analysis (explain + debug + suggest)", ) -async def analyze(req: CodeRequest, response: Response): +async def analyze( + req: CodeRequest, + response: Response, + current_user: User | None = Depends(get_optional_user), + db: Session = Depends(get_db), + team_id: str | None = Header(default=None, alias="X-Team-Id"), +): + user_id = current_user.id if current_user else None + preflight_estimate = estimate_usage(req.code) + enforce_quota(db, preflight_estimate, user_id=user_id, team_id=team_id) + cache_input = f"{req.language or 'auto'}\n{req.code}" cached_payload = cache.get("analyze:v1", cache_input) if cached_payload is not None: response.headers["X-Cache"] = "HIT" + log_usage(db, "/analyze/", preflight_estimate, user_id=user_id, team_id=team_id) return cached_payload payload = full_analysis(req.code, req.language) + usage_estimate = estimate_usage(req.code, json.dumps(payload, sort_keys=True)) + log_usage(db, "/analyze/", usage_estimate, user_id=user_id, team_id=team_id) cache.set("analyze:v1", cache_input, payload) diff --git a/backend/app/routers/quotas.py b/backend/app/routers/quotas.py new file mode 100644 index 00000000..b2b090ba --- /dev/null +++ b/backend/app/routers/quotas.py @@ -0,0 +1,67 @@ +"""Quota management endpoints for usage enforcement.""" + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy 
# ── backend/app/routers/quotas.py ───────────────────────────────────────────

router = APIRouter(prefix="/quotas", tags=["Quotas"])


def _require_own_user(user_id: int | None, current_user: User, detail: str) -> None:
    """Raise 403 when *user_id* names a user other than the caller."""
    if user_id is not None and user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


@router.post("", response_model=QuotaResponse)
def upsert_quota(
    payload: QuotaUpsertRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create or update a user or team quota configuration.

    Without ``team_id`` the quota belongs to the caller. An explicit
    ``user_id`` must always be the caller's own id; previously this was only
    checked when ``team_id`` was absent, so a team payload could carry another
    user's id.

    Raises:
        HTTPException: 403 when ``payload.user_id`` names another user.
    """
    ensure_usage_tables(db)

    user_id = payload.user_id
    if payload.team_id is None and user_id is None:
        user_id = current_user.id
    _require_own_user(user_id, current_user, "Cannot manage another user's quota")

    user_match = (
        QuotaConfig.user_id.is_(None)
        if user_id is None
        else QuotaConfig.user_id == user_id
    )
    team_match = (
        QuotaConfig.team_id.is_(None)
        if payload.team_id is None
        else QuotaConfig.team_id == payload.team_id
    )
    # Take the newest matching row, as find_applicable_quota does, so that
    # duplicate rows cannot raise MultipleResultsFound.
    quota = db.execute(
        select(QuotaConfig)
        .where(user_match, team_match)
        .order_by(QuotaConfig.id.desc())
        .limit(1)
    ).scalar_one_or_none()
    if quota is None:
        quota = QuotaConfig(user_id=user_id, team_id=payload.team_id)
        db.add(quota)

    quota.period = payload.period
    quota.max_requests = payload.max_requests
    quota.max_tokens = payload.max_tokens
    quota.max_cost_usd = payload.max_cost_usd
    # Stored as a comma-separated string; parse_thresholds reverses this.
    quota.alert_thresholds = ",".join(str(value) for value in payload.alert_thresholds)
    db.commit()
    db.refresh(quota)
    return quota_to_dict(quota)


@router.get("", response_model=QuotaResponse | None)
def get_quota(
    user_id: int | None = Query(default=None),
    team_id: str | None = Query(default=None, max_length=120),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the applicable quota for the caller, a team, or the global scope.

    Raises:
        HTTPException: 403 when ``user_id`` names another user. The original
            accepted any ``user_id`` and so exposed other users' quotas.
    """
    _require_own_user(user_id, current_user, "Cannot view another user's quota")
    if user_id is None and team_id is None:
        user_id = current_user.id
    quota = find_applicable_quota(db, user_id=user_id, team_id=team_id)
    return quota_to_dict(quota) if quota is not None else None


# ── backend/app/routers/usage.py ────────────────────────────────────────────

router = APIRouter(prefix="/usage", tags=["Usage"])


@router.get("/summary", response_model=UsageSummaryResponse)
def usage_summary(
    period: str = Query("monthly", pattern="^(daily|monthly)$"),
    team_id: str | None = Query(default=None, max_length=120),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return current usage totals for the authenticated user or a team.

    When ``team_id`` is given the totals are for that team and ``user_id`` in
    the response is ``None``; otherwise they are for the caller.
    """
    # NOTE(review): no team-membership check is visible here; any
    # authenticated user can pass an arbitrary team_id. Confirm this is intended.
    user_id = None if team_id else current_user.id
    totals = aggregate_usage(db, period=period, user_id=user_id, team_id=team_id)
    quota = find_applicable_quota(db, user_id=user_id, team_id=team_id)
    return {
        "scope": "team" if team_id else "user",
        "user_id": user_id,
        "team_id": team_id,
        "period": period,
        **totals,
        "alerts": build_alerts(totals, quota),
    }


@router.get("/costs", response_model=UsageCostsResponse)
def usage_costs(
    period: str = Query("monthly", pattern="^(daily|monthly)$"),
    team_id: str | None = Query(default=None, max_length=120),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return estimated usage cost grouped by provider and model."""
    user_id = None if team_id else current_user.id
    providers = provider_costs(db, period=period, user_id=user_id, team_id=team_id)
    total_cost = sum(item["estimated_cost_usd"] for item in providers)
    return {
        "scope": "team" if team_id else "user",
        "user_id": user_id,
        "team_id": team_id,
        "period": period,
        "providers": providers,
        "total_cost_usd": round(total_cost, 6),
    }


# ── backend/app/schemas.py (usage / quota models) ───────────────────────────

class UsageAlert(BaseModel):
    """Alert emitted when usage reaches a configured quota threshold."""

    metric: str
    threshold: float  # fraction of the limit, e.g. 0.8
    percent_used: float  # percentage, e.g. 100.0
    message: str


class UsageSummaryResponse(BaseModel):
    """Aggregated usage totals for a user, team, or global scope."""

    scope: str
    user_id: int | None = None
    team_id: str | None = None
    period: str
    request_count: int
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    estimated_cost_usd: float
    alerts: list[UsageAlert] = Field(default_factory=list)


class UsageCostsResponse(BaseModel):
    """Provider-level usage and estimated cost breakdown."""

    scope: str
    user_id: int | None = None
    team_id: str | None = None
    period: str
    providers: list[dict[str, Any]]
    total_cost_usd: float


class QuotaUpsertRequest(BaseModel):
    """Create or update quota limits for a user, team, or global scope."""

    user_id: int | None = None
    team_id: str | None = Field(default=None, max_length=120)
    period: str = Field(default="monthly", pattern="^(daily|monthly)$")
    max_requests: int | None = Field(default=None, gt=0)
    max_tokens: int | None = Field(default=None, gt=0)
    max_cost_usd: float | None = Field(default=None, gt=0)
    alert_thresholds: list[float] = Field(default_factory=lambda: [0.8, 1.0])

    @field_validator("alert_thresholds")
    @classmethod
    def validate_alert_thresholds(cls, value: list[float]) -> list[float]:
        """Default an empty list; require 0 < t <= 1; return sorted unique values."""
        if not value:
            return [0.8, 1.0]
        if any(threshold <= 0 or threshold > 1 for threshold in value):
            raise ValueError("alert thresholds must be between 0 and 1")
        return sorted(set(value))

    @model_validator(mode="after")
    def ensure_limit_present(self) -> "QuotaUpsertRequest":
        """Reject payloads that set none of the three limits."""
        if (
            self.max_requests is None
            and self.max_tokens is None
            and self.max_cost_usd is None
        ):
            raise ValueError("at least one quota limit is required")
        return self


class QuotaResponse(BaseModel):
    """Stored quota configuration."""

    id: int
    user_id: int | None = None
    team_id: str | None = None
    period: str
    max_requests: int | None = None
    max_tokens: int | None = None
    max_cost_usd: float | None = None
    alert_thresholds: list[float]


# ── backend/app/security.py ─────────────────────────────────────────────────

def get_optional_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    db: Session = Depends(get_db),
) -> User | None:
    """Return the authenticated user when a valid bearer token is provided.

    Returns ``None`` when no credentials are supplied, when the token cannot
    be decoded, or when the decoded id matches no user row.
    """
    if credentials is None:
        return None

    try:
        user_id = decode_access_token(credentials.credentials)
    except Exception:
        # Deliberate best-effort: a bad token downgrades the request to
        # anonymous instead of failing it.
        # NOTE(review): narrow this to the exception type(s) that
        # decode_access_token raises once confirmed.
        return None

    return db.get(User, user_id)
import settings +from ..database import Base +from ..models import QuotaConfig, UsageLog + +PRICE_PER_1K_TOKENS_USD = { + "rule-based": 0.0, + "local": 0.0, + "ollama": 0.0, + "openai": 0.002, + "groq": 0.0005, + "together": 0.001, + "openai-compatible": 0.002, +} + + +def ensure_usage_tables(db: Session) -> None: + """Create usage tables when running without an external migration tool.""" + Base.metadata.create_all(bind=db.get_bind()) + + +@dataclass(frozen=True) +class UsageEstimate: + """Estimated usage metadata for a request.""" + + provider: str + model: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + estimated_cost_usd: float + + +def estimate_tokens(text: str) -> int: + """Estimate token count from text length without provider-specific libraries.""" + return max(1, ceil(len(text) / 4)) + + +def estimate_usage( + prompt_text: str, + completion_text: str = "", + provider: str | None = None, + model: str | None = None, +) -> UsageEstimate: + """Return a conservative usage and cost estimate for a provider request.""" + provider_name = (provider or settings.ai_provider or "rule-based").lower() + model_name = model or settings.ai_model + prompt_tokens = estimate_tokens(prompt_text) + completion_tokens = estimate_tokens(completion_text) if completion_text else 0 + total_tokens = prompt_tokens + completion_tokens + rate = PRICE_PER_1K_TOKENS_USD.get(provider_name, PRICE_PER_1K_TOKENS_USD["openai-compatible"]) + return UsageEstimate( + provider=provider_name, + model=model_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + estimated_cost_usd=round((total_tokens / 1000) * rate, 6), + ) + + +def period_start(period: str, now: datetime | None = None) -> datetime: + """Return the UTC start timestamp for a supported quota period.""" + current = now or datetime.now(UTC) + if period == "daily": + return current.replace(hour=0, minute=0, second=0, microsecond=0) + return current.replace(day=1, 
hour=0, minute=0, second=0, microsecond=0) + + +def parse_thresholds(raw_value: str) -> list[float]: + """Parse comma-separated alert threshold values from a quota row.""" + thresholds: list[float] = [] + for item in raw_value.split(","): + try: + value = float(item.strip()) + except ValueError: + continue + if 0 < value <= 1: + thresholds.append(value) + return sorted(set(thresholds)) or [0.8, 1.0] + + +def quota_to_dict(quota: QuotaConfig) -> dict: + """Serialize a quota model into an API-friendly dictionary.""" + return { + "id": quota.id, + "user_id": quota.user_id, + "team_id": quota.team_id, + "period": quota.period, + "max_requests": quota.max_requests, + "max_tokens": quota.max_tokens, + "max_cost_usd": quota.max_cost_usd, + "alert_thresholds": parse_thresholds(quota.alert_thresholds), + } + + +def find_applicable_quota( + db: Session, + user_id: int | None = None, + team_id: str | None = None, +) -> QuotaConfig | None: + """Find the most specific quota for a team, user, or global scope.""" + ensure_usage_tables(db) + if team_id: + quota = db.execute( + select(QuotaConfig) + .where(QuotaConfig.team_id == team_id) + .order_by(QuotaConfig.id.desc()) + .limit(1) + ).scalar_one_or_none() + if quota is not None: + return quota + + if user_id is not None: + quota = db.execute( + select(QuotaConfig) + .where(QuotaConfig.user_id == user_id, QuotaConfig.team_id.is_(None)) + .order_by(QuotaConfig.id.desc()) + .limit(1) + ).scalar_one_or_none() + if quota is not None: + return quota + + return db.execute( + select(QuotaConfig) + .where(QuotaConfig.user_id.is_(None), QuotaConfig.team_id.is_(None)) + .order_by(QuotaConfig.id.desc()) + .limit(1) + ).scalar_one_or_none() + + +def aggregate_usage( + db: Session, + period: str = "monthly", + user_id: int | None = None, + team_id: str | None = None, +) -> dict: + """Aggregate usage totals for the requested scope and period.""" + ensure_usage_tables(db) + query = select( + func.count(UsageLog.id), + 
func.coalesce(func.sum(UsageLog.prompt_tokens), 0), + func.coalesce(func.sum(UsageLog.completion_tokens), 0), + func.coalesce(func.sum(UsageLog.total_tokens), 0), + func.coalesce(func.sum(UsageLog.estimated_cost_usd), 0.0), + ).where(UsageLog.created_at >= period_start(period)) + + if team_id: + query = query.where(UsageLog.team_id == team_id) + elif user_id is not None: + query = query.where(UsageLog.user_id == user_id) + else: + query = query.where(UsageLog.user_id.is_(None), UsageLog.team_id.is_(None)) + + request_count, prompt_tokens, completion_tokens, total_tokens, cost = db.execute(query).one() + return { + "request_count": int(request_count or 0), + "prompt_tokens": int(prompt_tokens or 0), + "completion_tokens": int(completion_tokens or 0), + "total_tokens": int(total_tokens or 0), + "estimated_cost_usd": round(float(cost or 0.0), 6), + } + + +def build_alerts(totals: dict, quota: QuotaConfig | None) -> list[dict]: + """Build alert payloads for quota thresholds reached by current usage.""" + if quota is None: + return [] + + limits = { + "requests": (totals["request_count"], quota.max_requests), + "tokens": (totals["total_tokens"], quota.max_tokens), + "cost": (totals["estimated_cost_usd"], quota.max_cost_usd), + } + alerts: list[dict] = [] + for metric, (used, limit) in limits.items(): + if not limit: + continue + percent_used = float(used) / float(limit) + for threshold in parse_thresholds(quota.alert_thresholds): + if percent_used >= threshold: + alerts.append( + { + "metric": metric, + "threshold": threshold, + "percent_used": round(percent_used * 100, 2), + "message": f"{metric} usage reached {round(threshold * 100)}% of quota", + } + ) + return alerts + + +def enforce_quota( + db: Session, + estimate: UsageEstimate, + user_id: int | None = None, + team_id: str | None = None, +) -> None: + """Raise 429 when the applicable quota would be exceeded by a request.""" + quota = find_applicable_quota(db, user_id=user_id, team_id=team_id) + if quota is None: 
+ return + + totals = aggregate_usage(db, quota.period, user_id=user_id, team_id=team_id) + projected = { + "request_count": totals["request_count"] + 1, + "total_tokens": totals["total_tokens"] + estimate.total_tokens, + "estimated_cost_usd": totals["estimated_cost_usd"] + estimate.estimated_cost_usd, + } + + exceeded = [] + if quota.max_requests and projected["request_count"] > quota.max_requests: + exceeded.append("requests") + if quota.max_tokens and projected["total_tokens"] > quota.max_tokens: + exceeded.append("tokens") + if quota.max_cost_usd and projected["estimated_cost_usd"] > quota.max_cost_usd: + exceeded.append("cost") + + if exceeded: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "message": "Usage quota exceeded", + "exceeded": exceeded, + "period": quota.period, + }, + ) + + +def log_usage( + db: Session, + endpoint: str, + estimate: UsageEstimate, + user_id: int | None = None, + team_id: str | None = None, +) -> UsageLog: + """Persist an estimated usage log entry.""" + ensure_usage_tables(db) + record = UsageLog( + user_id=user_id, + team_id=team_id, + endpoint=endpoint, + provider=estimate.provider, + model=estimate.model, + prompt_tokens=estimate.prompt_tokens, + completion_tokens=estimate.completion_tokens, + total_tokens=estimate.total_tokens, + estimated_cost_usd=estimate.estimated_cost_usd, + ) + db.add(record) + db.commit() + db.refresh(record) + return record + + +def provider_costs( + db: Session, + period: str = "monthly", + user_id: int | None = None, + team_id: str | None = None, +) -> list[dict]: + """Return usage totals grouped by AI provider and model.""" + ensure_usage_tables(db) + query = select( + UsageLog.provider, + UsageLog.model, + func.count(UsageLog.id), + func.coalesce(func.sum(UsageLog.total_tokens), 0), + func.coalesce(func.sum(UsageLog.estimated_cost_usd), 0.0), + ).where(UsageLog.created_at >= period_start(period)) + + if team_id: + query = query.where(UsageLog.team_id == team_id) 
+ elif user_id is not None: + query = query.where(UsageLog.user_id == user_id) + else: + query = query.where(UsageLog.user_id.is_(None), UsageLog.team_id.is_(None)) + + rows = db.execute(query.group_by(UsageLog.provider, UsageLog.model)).all() + return [ + { + "provider": provider, + "model": model, + "request_count": int(request_count or 0), + "total_tokens": int(total_tokens or 0), + "estimated_cost_usd": round(float(cost or 0.0), 6), + } + for provider, model, request_count, total_tokens, cost in rows + ] diff --git a/backend/tests/test_usage_quotas.py b/backend/tests/test_usage_quotas.py new file mode 100644 index 00000000..bf4027a9 --- /dev/null +++ b/backend/tests/test_usage_quotas.py @@ -0,0 +1,462 @@ +"""Tests for AI usage tracking, cost summaries, and quota enforcement.""" + +from datetime import UTC, datetime, timedelta +import os +import sys + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from app.database import Base, get_db +from app.main import app as fastapi_app +from app.main import _request_counts +from app.models import QuotaConfig, UsageLog +from app.services.cache import cache +from app.services.usage import ( + UsageEstimate, + aggregate_usage, + build_alerts, + enforce_quota, + estimate_tokens, + estimate_usage, + parse_thresholds, + provider_costs, +) + + +TEST_ENGINE = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +TEST_SESSION_LOCAL = sessionmaker(bind=TEST_ENGINE) + + +def _override_db(): + db = TEST_SESSION_LOCAL() + try: + yield db + finally: + db.close() + + +@pytest.fixture +def client(): + previous_override = fastapi_app.dependency_overrides.get(get_db) + fastapi_app.dependency_overrides[get_db] = _override_db + cache.clear_memory() + _request_counts.clear() + with 
TestClient(fastapi_app) as test_client: + yield test_client + cache.clear_memory() + _request_counts.clear() + if previous_override is None: + fastapi_app.dependency_overrides.pop(get_db, None) + else: + fastapi_app.dependency_overrides[get_db] = previous_override + + +@pytest.fixture(autouse=True) +def _recreate_tables(): + Base.metadata.create_all(bind=TEST_ENGINE) + yield + Base.metadata.drop_all(bind=TEST_ENGINE) + + +def _auth_headers(client: TestClient) -> dict[str, str]: + response = client.post( + "/auth/signup", + json={"email": "usage.user@example.com", "password": "StrongPass123!"}, + ) + assert response.status_code == 200 + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +def test_usage_summary_and_costs_track_analyze_requests(client): + headers = _auth_headers(client) + quota_response = client.post( + "/quotas", + headers=headers, + json={"max_requests": 5, "alert_thresholds": [0.8, 1.0]}, + ) + assert quota_response.status_code == 200 + assert quota_response.json()["user_id"] is not None + + analyze_response = client.post( + "/analyze/", + headers=headers, + json={"code": "print('hello')", "language": "python"}, + ) + assert analyze_response.status_code == 200 + + summary_response = client.get("/usage/summary", headers=headers) + assert summary_response.status_code == 200 + summary = summary_response.json() + assert summary["scope"] == "user" + assert summary["request_count"] == 1 + assert summary["total_tokens"] > 0 + assert summary["estimated_cost_usd"] == 0 + + costs_response = client.get("/usage/costs", headers=headers) + assert costs_response.status_code == 200 + costs = costs_response.json() + assert costs["providers"][0]["provider"] == "rule-based" + assert costs["providers"][0]["request_count"] == 1 + + +def test_analyze_returns_429_when_user_quota_is_exceeded(client): + headers = _auth_headers(client) + quota_response = client.post( + "/quotas", + headers=headers, + json={"period": "monthly", 
"max_requests": 1},
+    )
+    assert quota_response.status_code == 200
+
+    first_response = client.post(
+        "/analyze/",
+        headers=headers,
+        json={"code": "x = 1", "language": "python"},
+    )
+    assert first_response.status_code == 200
+
+    second_response = client.post(
+        "/analyze/",
+        headers=headers,
+        json={"code": "x = 2", "language": "python"},
+    )
+    assert second_response.status_code == 429
+    assert second_response.json()["detail"]["message"] == "Usage quota exceeded"
+    assert "requests" in second_response.json()["detail"]["exceeded"]
+
+
+def test_usage_summary_emits_alerts_at_configured_thresholds(client):
+    """Check that the usage summary lists one alert per configured threshold.
+
+    A one-request quota with thresholds 0.8 and 1.0 is created, a single
+    analysis is run, and the summary is expected to report both thresholds
+    for the ``requests`` metric.
+    """
+    headers = _auth_headers(client)
+    quota_response = client.post(
+        "/quotas",
+        headers=headers,
+        json={"max_requests": 1, "alert_thresholds": [0.8, 1.0]},
+    )
+    assert quota_response.status_code == 200
+
+    analyze_response = client.post(
+        "/analyze/",
+        headers=headers,
+        json={"code": "print('alert')", "language": "python"},
+    )
+    assert analyze_response.status_code == 200
+
+    summary_response = client.get("/usage/summary", headers=headers)
+    assert summary_response.status_code == 200
+    alerts = summary_response.json()["alerts"]
+    assert [alert["threshold"] for alert in alerts] == [0.8, 1.0]
+    assert all(alert["metric"] == "requests" for alert in alerts)
+
+
+def test_usage_and_quota_endpoints_require_authentication(client):
+    """Check that usage and quota endpoints reject requests without credentials."""
+    for method, url in [
+        ("get", "/usage/summary"),
+        ("get", "/usage/costs"),
+        ("get", "/quotas"),
+        ("post", "/quotas"),
+    ]:
+        if method == "post":
+            response = client.post(url, json={})
+        else:
+            response = client.get(url)
+        # The message names the endpoint so a failure inside the loop is
+        # attributable without re-running each case by hand.
+        assert response.status_code in (401, 422), (
+            f"{method.upper()} {url} returned {response.status_code}"
+        )
+
+
+def test_quota_payload_validation_rejects_empty_limits_and_bad_thresholds(client):
+    """Check that invalid quota payloads and an invalid period return 422."""
+    headers = _auth_headers(client)
+
+    empty_limits = client.post("/quotas", headers=headers, json={})
+    assert empty_limits.status_code == 422
+
+    bad_threshold = client.post(
+        "/quotas",
+        headers=headers,
+        json={"max_requests": 5, "alert_thresholds": [0, 1.2]},
+    )
+    assert bad_threshold.status_code == 422
+
+    bad_period = client.get("/usage/summary?period=yearly", headers=headers)
+    assert bad_period.status_code == 422
+
+
+def test_user_cannot_manage_another_users_quota(client):
+    """Check that setting a quota for a different user's id returns 403."""
+    first_headers = _auth_headers(client)
+    second_signup = client.post(
+        "/auth/signup",
+        json={"email": "other.user@example.com", "password": "StrongPass123!"},
+    )
+    assert second_signup.status_code == 200
+
+    response = client.post(
+        "/quotas",
+        headers=first_headers,
+        json={"user_id": second_signup.json()["user_id"], "max_requests": 5},
+    )
+    assert response.status_code == 403
+
+
+def test_quota_upsert_updates_existing_user_quota(client):
+    """Check that posting a quota twice updates the same record."""
+    headers = _auth_headers(client)
+    first = client.post("/quotas", headers=headers, json={"max_requests": 3})
+    second = client.post("/quotas", headers=headers, json={"max_requests": 7})
+
+    assert first.status_code == 200
+    assert second.status_code == 200
+    # Same id on both responses means the second call updated, not inserted.
+    assert second.json()["id"] == first.json()["id"]
+    assert second.json()["max_requests"] == 7
+
+    fetched = client.get("/quotas", headers=headers)
+    assert fetched.status_code == 200
+    assert fetched.json()["max_requests"] == 7
+
+
+def test_team_quota_takes_precedence_over_user_quota(client):
+    """Check that a team quota governs requests sent with ``X-Team-Id``.
+
+    The user quota allows one request and the team quota allows two; two
+    team-scoped requests must succeed and the third must be rejected.
+    """
+    headers = _auth_headers(client)
+    team_headers = {**headers, "X-Team-Id": "core-platform"}
+
+    user_quota = client.post("/quotas", headers=headers, json={"max_requests": 1})
+    team_quota = client.post(
+        "/quotas",
+        headers=headers,
+        json={"team_id": "core-platform", "max_requests": 2},
+    )
+    assert user_quota.status_code == 200
+    assert team_quota.status_code == 200
+
+    for code in ["x = 1", "x = 2"]:
+        response = client.post(
+            "/analyze/",
+            headers=team_headers,
+            json={"code": code, "language": "python"},
+        )
+        assert response.status_code == 200
+
+    blocked = client.post(
+        "/analyze/",
+        headers=team_headers,
+        json={"code": "x = 3", "language": "python"},
+    )
+    assert blocked.status_code == 429
+    assert "requests" in blocked.json()["detail"]["exceeded"]
+
+
+def test_token_quota_blocks_before_analysis_runs(client):
+    """Check that a one-token quota rejects an analysis request with 429."""
+    headers = _auth_headers(client)
+    quota_response = client.post("/quotas", headers=headers, json={"max_tokens": 1})
+    assert quota_response.status_code == 200
+
+    response = client.post(
+        "/analyze/",
+        headers=headers,
+        json={"code": "print('more than one estimated token')", "language": "python"},
+    )
+    assert response.status_code == 429
+    assert "tokens" in response.json()["detail"]["exceeded"]
+
+
+def test_cached_analyze_requests_are_still_logged(client):
+    """Check that a cache hit on ``/analyze/`` still counts as a request."""
+    headers = _auth_headers(client)
+    payload = {"code": "print('cached')", "language": "python"}
+
+    first = client.post("/analyze/", headers=headers, json=payload)
+    second = client.post("/analyze/", headers=headers, json=payload)
+
+    assert first.status_code == 200
+    assert second.status_code == 200
+    assert first.headers["X-Cache"] == "MISS"
+    assert second.headers["X-Cache"] == "HIT"
+
+    summary_response = client.get("/usage/summary", headers=headers)
+    # Check the status first so an error body does not surface as a KeyError.
+    assert summary_response.status_code == 200
+    assert summary_response.json()["request_count"] == 2
+
+
+def test_team_usage_summary_and_costs_are_isolated(client):
+    """Check that team-scoped usage excludes requests sent without a team id."""
+    headers = _auth_headers(client)
+
+    team_response = client.post(
+        "/analyze/",
+        headers={**headers, "X-Team-Id": "team-a"},
+        json={"code": "print('team')", "language": "python"},
+    )
+    user_response = client.post(
+        "/analyze/",
+        headers=headers,
+        json={"code": "print('user')", "language": "python"},
+    )
+    assert team_response.status_code == 200
+    assert user_response.status_code == 200
+
+    team_summary_response = client.get(
+        "/usage/summary?team_id=team-a",
+        headers=headers,
+    )
+    team_costs_response = client.get(
+        "/usage/costs?team_id=team-a",
+        headers=headers,
+    )
+    # Check the statuses first so an error body does not surface as a KeyError.
+    assert team_summary_response.status_code == 200
+    assert team_costs_response.status_code == 200
+    team_summary = team_summary_response.json()
+    team_costs = team_costs_response.json()
+
+    assert team_summary["scope"] == "team"
+    assert team_summary["team_id"] == "team-a"
+    assert team_summary["request_count"] == 1
+    assert team_costs["providers"][0]["request_count"] == 1
+
+
+def test_estimate_helpers_cover_empty_text_unknown_provider_and_thresholds():
+    """Check the estimation helpers on empty text, unknown providers, and bad thresholds."""
+    assert estimate_tokens("") == 1
+
+    estimate = estimate_usage("abcd", provider="mystery-provider", model="custom")
+    assert estimate.provider == "mystery-provider"
+    assert estimate.model == "custom"
+    assert estimate.total_tokens == 1
+    # Compare the float cost with a tolerance rather than exact equality.
+    assert estimate.estimated_cost_usd == pytest.approx(0.000002)
+
+    assert parse_thresholds("bad, 0, 0.8, 1.5, 1") == [0.8, 1.0]
+    assert parse_thresholds("bad") == [0.8, 1.0]
+
+
+def test_daily_aggregation_excludes_old_usage_rows():
+    """Check that daily aggregation skips a two-day-old row that monthly includes."""
+    db = TEST_SESSION_LOCAL()
+    try:
+        now = datetime.now(UTC)
+        db.add_all(
+            [
+                # Two days old: expected in the monthly totals only.
+                UsageLog(
+                    user_id=1,
+                    endpoint="/analyze/",
+                    provider="rule-based",
+                    model="qyverix-engine-v3",
+                    prompt_tokens=10,
+                    completion_tokens=0,
+                    total_tokens=10,
+                    estimated_cost_usd=0,
+                    created_at=now - timedelta(days=2),
+                ),
+                # Created now: expected in both the daily and monthly totals.
+                UsageLog(
+                    user_id=1,
+                    endpoint="/analyze/",
+                    provider="rule-based",
+                    model="qyverix-engine-v3",
+                    prompt_tokens=4,
+                    completion_tokens=1,
+                    total_tokens=5,
+                    estimated_cost_usd=0,
+                    created_at=now,
+                ),
+            ]
+        )
+        db.commit()
+
+        daily = aggregate_usage(db, period="daily", user_id=1)
+        monthly = aggregate_usage(db, period="monthly", user_id=1)
+    finally:
+        db.close()
+
+    assert daily["request_count"] == 1
+    assert daily["total_tokens"] == 5
+    assert monthly["request_count"] == 2
+    assert monthly["total_tokens"] == 15
+
+
+def test_provider_costs_groups_by_provider_and_model():
+    """Check that provider costs are grouped per (provider, model) pair."""
+    db = TEST_SESSION_LOCAL()
+    try:
+        db.add_all(
+            [
+                UsageLog(
+                    user_id=2,
+                    endpoint="/analyze/",
+                    provider="openai",
+                    model="gpt-4o-mini",
+                    prompt_tokens=10,
+                    completion_tokens=5,
+                    total_tokens=15,
+                    estimated_cost_usd=0.00003,
+                ),
+                UsageLog(
+                    user_id=2,
+                    endpoint="/chat/message",
+                    provider="openai",
+                    model="gpt-4o-mini",
+                    prompt_tokens=20,
+                    completion_tokens=5,
+                    total_tokens=25,
+                    estimated_cost_usd=0.00005,
+                ),
+                UsageLog(
+                    user_id=2,
+                    endpoint="/chat/message",
+                    provider="groq",
+                    model="llama3",
+                    prompt_tokens=10,
+                    completion_tokens=0,
+                    total_tokens=10,
+                    estimated_cost_usd=0.000005,
+                ),
+            ]
+        )
+        db.commit()
+        costs = provider_costs(db, user_id=2)
+    finally:
+        db.close()
+
+    by_provider = {(item["provider"], item["model"]): item for item in costs}
+    # The two openai/gpt-4o-mini rows (15 + 25 tokens) share one group.
+    assert by_provider[("openai", "gpt-4o-mini")]["request_count"] == 2
+    assert by_provider[("openai", "gpt-4o-mini")]["total_tokens"] == 40
+    assert by_provider[("groq", "llama3")]["request_count"] == 1
+
+
+def test_enforce_quota_blocks_cost_and_combined_limits():
+    """Check that an estimate above ``max_cost_usd`` is rejected with status 429."""
+    db = TEST_SESSION_LOCAL()
+    try:
+        db.add(
+            QuotaConfig(
+                user_id=9,
+                period="monthly",
+                max_requests=10,
+                max_tokens=100,
+                max_cost_usd=0.01,
+            )
+        )
+        db.commit()
+
+        # Only the cost (0.02 > 0.01) is over its limit; tokens are within 100.
+        estimate = UsageEstimate(
+            provider="openai",
+            model="gpt-4o-mini",
+            prompt_tokens=10,
+            completion_tokens=10,
+            total_tokens=20,
+            estimated_cost_usd=0.02,
+        )
+        # NOTE(review): Exception is broad; narrow it to the concrete HTTP
+        # exception type raised by enforce_quota once that is confirmed.
+        with pytest.raises(Exception) as exc_info:
+            enforce_quota(db, estimate, user_id=9)
+    finally:
+        db.close()
+
+    assert exc_info.value.status_code == 429
+    assert "cost" in exc_info.value.detail["exceeded"]
+
+
+def test_build_alerts_returns_empty_without_quota_and_reports_multiple_metrics():
+    """Check ``build_alerts`` with no quota and with several metrics at thresholds."""
+    usage_without_quota = {
+        "request_count": 1,
+        "total_tokens": 1,
+        "estimated_cost_usd": 0,
+    }
+    assert build_alerts(usage_without_quota, None) == []
+
+    quota = QuotaConfig(
+        max_requests=2,
+        max_tokens=10,
+        max_cost_usd=1,
+        alert_thresholds="0.5,1",
+    )
+    alerts = build_alerts(
+        {"request_count": 2, "total_tokens": 5, "estimated_cost_usd": 0.5},
+        quota,
+    )
+
+    metrics = [alert["metric"] for alert in alerts]
+    # Requests are at 2/2, so two alerts are expected (0.5 and 1). Tokens
+    # (5/10) and cost (0.5/1) are at one half, so one alert each.
+    assert metrics.count("requests") == 2
+    assert metrics.count("tokens") == 1
+    assert metrics.count("cost") == 1