Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def _bool_env(name: str, default: bool) -> bool:
return raw_value.strip().lower() in {"1", "true", "yes", "on"}


def _required_env(name: str) -> str:
"""Get a required environment variable. Raise error if not set."""
value = os.getenv(name)
if not value or not value.strip():
raise ValueError(
f"Required environment variable '{name}' is not set. "
f"Please set it before starting the application."
)
return value


class Settings:
"""Application settings loaded from environment variables."""

Expand All @@ -50,6 +61,7 @@ class Settings:
max_request_bytes: int = _int_env("MAX_REQUEST_BYTES", 1048576)
rate_limit_requests: int = _int_env("RATE_LIMIT_REQUESTS", 120)
rate_limit_window_seconds: int = _int_env("RATE_LIMIT_WINDOW_SECONDS", 60)
trust_proxy_headers: bool = _bool_env("TRUST_PROXY_HEADERS", False)
cache_enabled: bool = _bool_env("CACHE_ENABLED", True)
cache_ttl_seconds: int = _int_env("CACHE_TTL_SECONDS", 300)
cache_max_entries: int = _int_env("CACHE_MAX_ENTRIES", 100)
Expand All @@ -59,7 +71,7 @@ class Settings:
enable_docs: bool = _bool_env("ENABLE_DOCS", False)
public_root_info: bool = _bool_env("PUBLIC_ROOT_INFO", False)
database_url: str = os.getenv("DATABASE_URL", "sqlite:///./assistant.db")
jwt_secret: str = os.getenv("JWT_SECRET", "change-this-in-production-min-32-bytes")
jwt_secret: str = _required_env("JWT_SECRET")
jwt_algorithm: str = os.getenv("JWT_ALGORITHM", "HS256")
access_token_minutes: int = _int_env("ACCESS_TOKEN_MINUTES", 720)
llm_enabled: bool = _bool_env("LLM_ENABLED", False)
Expand Down
12 changes: 9 additions & 3 deletions backend/app/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@


def get_client_key(request: Request) -> str:
xff = request.headers.get("x-forwarded-for", "").split(",")[0].strip()
if xff:
return xff
"""Extract client IP for rate limiting.

Only uses X-Forwarded-For if TRUST_PROXY_HEADERS is enabled.
Falls back to direct connection IP if proxy headers are not trusted.
"""
if settings.trust_proxy_headers:
xff = request.headers.get("x-forwarded-for", "").split(",")[-1].strip()
if xff and xff != "unknown":
return xff
if request.client and request.client.host:
return request.client.host
return "unknown"
Expand Down
15 changes: 12 additions & 3 deletions backend/app/routers/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ async def analyze_zip(request: Request, file: UploadFile = File(...)):
results: list[dict] = []
skipped_files: list[str] = []
total_size = 0
MAX_PER_FILE_BYTES = 2 * 1024 * 1024 # 2MB per file

with archive:
members = [
Expand Down Expand Up @@ -307,14 +308,22 @@ async def analyze_zip(request: Request, file: UploadFile = File(...)):
)
continue

if total_size + info.file_size > MAX_ZIP_TOTAL_BYTES:
raw = archive.read(info)
decompressed_size = len(raw)

if decompressed_size > MAX_PER_FILE_BYTES:
raise HTTPException(
status_code=400,
detail=f"File '{safe_name}' exceeds 2MB limit after decompression",
)

if total_size + decompressed_size > MAX_ZIP_TOTAL_BYTES:
raise HTTPException(
status_code=400,
detail="ZIP source files exceed the 5MB total limit",
)

raw = archive.read(info)
total_size += len(raw)
total_size += decompressed_size

try:
code = raw.decode("utf-8")
Expand Down
31 changes: 25 additions & 6 deletions backend/app/routers/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""

from __future__ import annotations
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field

from ..security import get_current_user
from ..models import User
from ..services import database

router = APIRouter()
Expand All @@ -29,8 +31,12 @@ class HistoryEntry(BaseModel):


@router.post("/", response_model=dict, status_code=201)
async def save_history(body: HistorySaveRequest):
async def save_history(
body: HistorySaveRequest,
current_user: User = Depends(get_current_user),
):
entry_id = await database.save_entry(
user_id=current_user.id,
code=body.code,
language=body.language,
score=body.score,
Expand All @@ -43,21 +49,34 @@ async def save_history(body: HistorySaveRequest):
async def get_history(
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
current_user: User = Depends(get_current_user),
):
return await database.get_entries(limit=limit, offset=offset)
return await database.get_entries(
user_id=current_user.id,
limit=limit,
offset=offset,
)


@router.get("/search", response_model=list[HistoryEntry])
async def search_history(
q: str = Query(..., min_length=1),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
):
return await database.search_entries(q=q, limit=limit)
return await database.search_entries(
user_id=current_user.id,
q=q,
limit=limit,
)


@router.delete("/{entry_id}", response_model=dict)
async def delete_history(entry_id: int):
deleted = await database.delete_entry(entry_id)
async def delete_history(
entry_id: int,
current_user: User = Depends(get_current_user),
):
deleted = await database.delete_entry(entry_id, user_id=current_user.id)
if not deleted:
raise HTTPException(status_code=404, detail="History entry not found.")
return {"id": entry_id, "status": "deleted"}
76 changes: 42 additions & 34 deletions backend/app/routers/share.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import secrets
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
Expand All @@ -9,8 +11,26 @@
from ..models import SharedSnippet
from ..schemas import ShareCreateRequest, ShareRecord

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/share", tags=["Share"])

MAX_SHARE_CODE_BYTES = 50 * 1024 # 50KB
SHARE_EXPIRY_DAYS = 30


def _cleanup_expired_shares(db: Session) -> int:
"""Delete expired share records and return count deleted."""
cutoff = datetime.now(timezone.utc) - timedelta(days=SHARE_EXPIRY_DAYS)
stmt = select(SharedSnippet).where(SharedSnippet.expiry_at < cutoff)
expired = db.execute(stmt).scalars().all()
for record in expired:
db.delete(record)
if expired:
db.commit()
logger.info(f"Cleaned up {len(expired)} expired share records")
return len(expired)


@router.post("/", response_model=ShareRecord)
def create_share(payload: ShareCreateRequest, db: Session = Depends(get_db)):
Expand All @@ -20,6 +40,14 @@ def create_share(payload: ShareCreateRequest, db: Session = Depends(get_db)):

_Base.metadata.create_all(bind=db.get_bind())

if len(payload.code) > MAX_SHARE_CODE_BYTES:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"Code content exceeds {MAX_SHARE_CODE_BYTES // 1024}KB limit",
)

_cleanup_expired_shares(db)

token = ""
for _ in range(5):
candidate = secrets.token_urlsafe(8)
Expand All @@ -31,10 +59,13 @@ def create_share(payload: ShareCreateRequest, db: Session = Depends(get_db)):
if not token:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not create share token")

now = datetime.now(timezone.utc)
record = SharedSnippet(
token=token,
code=payload.code,
result_json=json.dumps(payload.result),
created_at=now,
expiry_at=now + timedelta(days=SHARE_EXPIRY_DAYS),
)
db.add(record)
db.commit()
Expand All @@ -57,49 +88,26 @@ def get_share(token: str, db: Session = Depends(get_db)):

_Base.metadata.create_all(bind=db.get_bind())

_cleanup_expired_shares(db)

record = db.execute(select(SharedSnippet).where(SharedSnippet.token == token)).scalar_one_or_none()
if record is None:
# fallback: try raw SQL in case ORM mapping/env differences hide the record
from sqlalchemy import text
raw = db.execute(text("SELECT token, code, result_json, created_at FROM shares WHERE token = :t"), {"t": token}).first()
if raw is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result not found or expired")

# parse created_at which may be string or datetime
token_val, code_val, result_json_val, created_at_val = raw
import datetime as _dt
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result not found or expired")

created_at = created_at_val
if isinstance(created_at, str):
try:
created_at = _dt.datetime.fromisoformat(created_at)
except Exception:
try:
created_at = _dt.datetime.strptime(created_at, "%Y-%m-%d %H:%M:%S.%f")
except Exception:
created_at = None
now = datetime.now(timezone.utc)
expiry_at = record.expiry_at
if expiry_at.tzinfo is None:
expiry_at = expiry_at.replace(tzinfo=timezone.utc)

if created_at is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result not found or expired")

if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=_dt.timezone.utc)

if created_at < _dt.datetime.now(_dt.timezone.utc) - _dt.timedelta(days=7):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result expired")

return ShareRecord(id=token_val, action="share", code=code_val, result=json.loads(result_json_val), created_at=created_at.isoformat())

# expire shares older than 7 days — normalize tzinfo if necessary
from datetime import datetime, timezone, timedelta
if expiry_at < now:
db.delete(record)
db.commit()
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result expired")

created_at = record.created_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)

if created_at < datetime.now(timezone.utc) - timedelta(days=7):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Shared result expired")

return ShareRecord(
id=record.token,
action="share",
Expand Down