diff --git a/api/app.py b/api/app.py index 2d10a8e1..d2c23752 100644 --- a/api/app.py +++ b/api/app.py @@ -6,7 +6,6 @@ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from fastapi.staticfiles import StaticFiles from api import accounts, ai, image_tasks, register, system from api.support import resolve_web_asset, start_limited_account_watcher @@ -43,8 +42,6 @@ async def lifespan(_: FastAPI): app.include_router(image_tasks.create_router()) app.include_router(register.create_router()) app.include_router(system.create_router(app_version)) - if config.images_dir.exists(): - app.mount("/images", StaticFiles(directory=str(config.images_dir)), name="images") @app.get("/{full_path:path}", include_in_schema=False) async def serve_web(full_path: str): diff --git a/api/system.py b/api/system.py index 431cf696..45cf8536 100644 --- a/api/system.py +++ b/api/system.py @@ -10,7 +10,8 @@ from api.support import require_admin, require_identity, resolve_image_base_url from services.backup_service import BackupError, backup_service from services.config import config -from services.image_service import delete_images, download_images_zip, get_image_download_response, get_thumbnail_response, list_images +from services.image_service import delete_images, download_images_zip, get_image_download_response, get_image_response, get_thumbnail_response, list_images +from services.image_storage_service import ImageStorageError, image_storage_service from services.image_tags_service import delete_tag, get_all_tags, set_tags from services.log_service import log_service from services.proxy_service import test_proxy @@ -69,13 +70,20 @@ async def get_settings(authorization: str | None = Header(default=None)): @router.post("/api/settings") async def save_settings(body: SettingsUpdateRequest, authorization: str | None = Header(default=None)): require_admin(authorization) - return {"config": config.update(body.model_dump(mode="python"))} + try: + return {"config": config.update(body.model_dump(mode="python"))} + except ValueError as exc: + raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc @router.get("/api/images") async def get_images(request: Request, start_date: str = "", end_date: str = "", authorization: str | None = Header(default=None)): require_admin(authorization) return list_images(resolve_image_base_url(request), start_date=start_date.strip(), end_date=end_date.strip()) + @router.get("/images/{image_path:path}", include_in_schema=False) + async def get_image(image_path: str): + return get_image_response(image_path) + @router.get("/image-thumbnails/{image_path:path}", include_in_schema=False) async def get_image_thumbnail(image_path: str): return get_thumbnail_response(image_path) @@ -135,6 +143,19 @@ async def test_backup_connection(authorization: str | None = Header(default=None except BackupError as exc: raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc + @router.post("/api/image-storage/test") + async def test_image_storage_endpoint(authorization: str | None = Header(default=None)): + require_admin(authorization) + return {"result": await run_in_threadpool(image_storage_service.test_webdav)} + + @router.post("/api/image-storage/sync") + async def sync_image_storage_endpoint(authorization: str | None = Header(default=None)): + require_admin(authorization) + try: + return {"result": await run_in_threadpool(image_storage_service.sync_all)} + except ImageStorageError as exc: + raise HTTPException(status_code=400, detail={"error": str(exc)}) from exc + @router.get("/api/backups") async def get_backups(authorization: str | None = Header(default=None)): require_admin(authorization) diff --git a/services/backup_service.py b/services/backup_service.py index ba71724f..785cb632 100644 --- a/services/backup_service.py +++ b/services/backup_service.py @@ -16,6 +16,7 @@ from curl_cffi import requests from services.config import BASE_DIR, CONFIG_FILE, DATA_DIR, config, load_backup_state, save_backup_state +from services.image_storage_service import IMAGE_INDEX_FILE from services.image_tags_service import TAGS_FILE @@ -631,6 +632,7 @@ def _build_backup_archive(self, settings: dict[str, object], *, trigger: str) -> self._add_file_to_archive(archive, DATA_DIR / "logs.jsonl", "data/logs.jsonl") if include.get("image_tasks"): self._add_file_to_archive(archive, DATA_DIR / "image_tasks.json", "data/image_tasks.json") + self._add_file_to_archive(archive, IMAGE_INDEX_FILE, "data/image_index.json") if include.get("accounts_snapshot"): self._add_bytes_to_archive( archive, diff --git a/services/config.py b/services/config.py index 74b46bb7..05d68de5 100644 --- a/services/config.py +++ b/services/config.py @@ -27,6 +27,16 @@ "images": False, } +DEFAULT_IMAGE_STORAGE = { + "enabled": False, + "mode": "local", + "webdav_url": "", + "webdav_username": "", + "webdav_password": "", + "webdav_root_path": "chatgpt2api/images", + "public_base_url": "", +} + def _normalize_bool(value: object, default: bool = False) -> bool: if isinstance(value, str): @@ -85,6 +95,35 @@ def _normalize_backup_state(value: object) -> dict[str, object]: } +def _normalize_image_storage_settings(value: object) -> dict[str, object]: + source = value if isinstance(value, dict) else {} + mode = str(source.get("mode") or "local").strip().lower() + if mode not in {"local", "webdav", "both"}: + mode = "local" + enabled = _normalize_bool(source.get("enabled"), False) + if not enabled: + mode = "local" + root_path = str(source.get("webdav_root_path") or DEFAULT_IMAGE_STORAGE["webdav_root_path"]).strip().strip("/") + return { + "enabled": enabled, + "mode": mode, + "webdav_url": str(source.get("webdav_url") or "").strip().rstrip("/"), + "webdav_username": str(source.get("webdav_username") or "").strip(), + "webdav_password": str(source.get("webdav_password") or "").strip(), + "webdav_root_path": root_path or str(DEFAULT_IMAGE_STORAGE["webdav_root_path"]), + "public_base_url": str(source.get("public_base_url") or "").strip().rstrip("/"), + } + + +def _validate_image_storage_settings(settings: dict[str, object]) -> None: + if not _normalize_bool(settings.get("enabled"), False): + return + if not str(settings.get("webdav_url") or "").strip(): + raise ValueError("启用 WebDAV 图片存储后必须填写 WebDAV URL") + if not str(settings.get("webdav_password") or "").strip(): + raise ValueError("启用 WebDAV 图片存储后必须填写 WebDAV 密码") + + @dataclass(frozen=True) class LoadedSettings: auth_key: str @@ -285,6 +324,7 @@ def get(self) -> dict[str, object]: data["ai_review"] = self.ai_review data["global_system_prompt"] = self.global_system_prompt data["backup"] = self.get_backup_settings() + data["image_storage"] = self.get_image_storage_settings() data.pop("auth-key", None) return data @@ -296,6 +336,9 @@ def update(self, data: dict[str, object]) -> dict[str, object]: next_data.update(dict(data or {})) if "backup" in next_data: next_data["backup"] = _normalize_backup_settings(next_data.get("backup")) + if "image_storage" in next_data: + next_data["image_storage"] = _normalize_image_storage_settings(next_data.get("image_storage")) + _validate_image_storage_settings(next_data["image_storage"]) next_data.pop("backup_state", None) self.data = next_data self._save() @@ -304,6 +347,9 @@ def update(self, data: dict[str, object]) -> dict[str, object]: def get_backup_settings(self) -> dict[str, object]: return _normalize_backup_settings(self.data.get("backup")) + def get_image_storage_settings(self) -> dict[str, object]: + return _normalize_image_storage_settings(self.data.get("image_storage")) + def get_storage_backend(self) -> StorageBackend: """获取存储后端实例(单例)""" if self._storage_backend is None: diff --git a/services/image_service.py b/services/image_service.py index 4fdcbcc8..1a0842f5 100644 --- a/services/image_service.py +++ b/services/image_service.py @@ -2,14 +2,14 @@ import io import zipfile -from datetime import datetime from pathlib import Path from fastapi import HTTPException -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, Response from PIL import Image, ImageOps from services.config import config +from services.image_storage_service import image_storage_service from services.image_tags_service import load_tags, remove_tags THUMBNAIL_SIZE = (320, 320) @@ -46,6 +46,12 @@ def _safe_image_path(relative_path: str) -> Path: return path +def get_image_response(relative_path: str) -> FileResponse | Response: + if image_storage_service.has_local(relative_path): + return FileResponse(_safe_image_path(relative_path)) + return Response(content=image_storage_service.get_bytes(relative_path), media_type="image/png") + + def _thumbnail_path(relative_path: str) -> Path: rel = _safe_relative_path(relative_path) return config.image_thumbnails_dir / f"{rel}.png" @@ -64,15 +70,19 @@ def _image_dimensions(path: Path) -> tuple[int, int] | None: def ensure_thumbnail(relative_path: str) -> Path: - source = _safe_image_path(relative_path) target = _thumbnail_path(relative_path) - source_mtime = source.stat().st_mtime - if target.exists() and target.stat().st_mtime >= source_mtime: + source_mtime = 0.0 + source: Path | None = None + if image_storage_service.has_local(relative_path): + source = _safe_image_path(relative_path) + source_mtime = source.stat().st_mtime + if target.exists() and (not source_mtime or target.stat().st_mtime >= source_mtime): return target target.parent.mkdir(parents=True, exist_ok=True) try: - with Image.open(source) as image: + image_source = source if source is not None else io.BytesIO(image_storage_service.get_bytes(relative_path)) + with Image.open(image_source) as image: image = ImageOps.exif_transpose(image) if image.mode not in {"RGB", "RGBA"}: image = image.convert("RGBA" if "A" in image.getbands() else "RGB") @@ -90,52 +100,30 @@ def get_thumbnail_response(relative_path: str) -> FileResponse: def get_image_download_response(relative_path: str) -> FileResponse: - path = _safe_image_path(relative_path) - return FileResponse(path, filename=path.name) + if image_storage_service.has_local(relative_path): + path = _safe_image_path(relative_path) + return FileResponse(path, filename=path.name) + rel = _safe_relative_path(relative_path) + return Response( + content=image_storage_service.get_bytes(rel), + media_type="image/png", + headers={"Content-Disposition": f'attachment; filename="{Path(rel).name}"'}, + ) def cleanup_image_thumbnails() -> int: thumbnails_root = config.image_thumbnails_dir - images_root = config.images_dir removed = 0 for path in thumbnails_root.rglob("*"): if not path.is_file(): continue rel = path.relative_to(thumbnails_root).as_posix() - if not rel.endswith(".png") or not (images_root / rel[:-4]).exists(): + if not rel.endswith(".png") or not image_storage_service.exists(rel[:-4]): path.unlink() removed += 1 _cleanup_empty_dirs(thumbnails_root) return removed - -def _image_items(start_date: str = "", end_date: str = "") -> list[dict[str, object]]: - items = [] - root = config.images_dir - for path in root.rglob("*"): - if not path.is_file(): - continue - rel = path.relative_to(root).as_posix() - parts = rel.split("/") - day = "-".join(parts[:3]) if len(parts) >= 4 else datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d") - if start_date and day < start_date: - continue - if end_date and day > end_date: - continue - dimensions = _image_dimensions(path) - items.append({ - "rel": rel, - "path": rel, - "name": path.name, - "date": day, - "size": path.stat().st_size, - "created_at": datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S"), - **({"width": dimensions[0], "height": dimensions[1]} if dimensions else {}), - }) - items.sort(key=lambda item: str(item["created_at"]), reverse=True) - return items - - def list_images(base_url: str, start_date: str = "", end_date: str = "") -> dict[str, object]: config.cleanup_old_images() cleanup_image_thumbnails() @@ -143,11 +131,11 @@ def list_images(base_url: str, start_date: str = "", end_date: str = "") -> dict items = [ { **item, - "url": f"{base_url.rstrip('/')}/images/{item['path']}", + "url": str(item.get("url") or f"{base_url.rstrip('/')}/images/{item['path']}"), "thumbnail_url": thumbnail_url(base_url, str(item["path"])), "tags": all_tags.get(str(item["path"]), []), } - for item in _image_items(start_date, end_date) + for item in image_storage_service.list_items(base_url, start_date, end_date) ] groups: dict[str, list[dict[str, object]]] = {} for item in items: @@ -157,7 +145,10 @@ def list_images(base_url: str, start_date: str = "", end_date: str = "") -> dict def delete_images(paths: list[str] | None = None, start_date: str = "", end_date: str = "", all_matching: bool = False) -> dict[str, int]: root = config.images_dir.resolve() - targets = [str(item["path"]) for item in _image_items(start_date, end_date)] if all_matching else (paths or []) + targets = [ + str(item["path"]) + for item in image_storage_service.list_items("", start_date=start_date, end_date=end_date) + ] if all_matching else (paths or []) removed = 0 for item in targets: path = (root / item).resolve() @@ -165,13 +156,12 @@ def delete_images(paths: list[str] | None = None, start_date: str = "", end_date path.relative_to(root) except ValueError: continue - if path.is_file(): - path.unlink() - for thumbnail in (_thumbnail_path(item), config.image_thumbnails_dir / _safe_relative_path(item)): - if thumbnail.is_file(): - thumbnail.unlink() - remove_tags(item) + if image_storage_service.delete(item): removed += 1 + for thumbnail in (_thumbnail_path(item), config.image_thumbnails_dir / _safe_relative_path(item)): + if thumbnail.is_file(): + thumbnail.unlink() + remove_tags(item) _cleanup_empty_dirs(root) _cleanup_empty_dirs(config.image_thumbnails_dir) return {"removed": removed} @@ -186,12 +176,18 @@ def download_images_zip(paths: list[str]) -> io.BytesIO: for item in paths: rel = _safe_relative_path(item) path = (root / rel).resolve() + payload: bytes | None = None try: path.relative_to(root) except ValueError: continue - if not path.is_file(): - continue + if path.is_file(): + payload = path.read_bytes() + else: + try: + payload = image_storage_service.get_bytes(rel) + except Exception: + continue name = path.name if name in used_names: stem = path.stem @@ -201,7 +197,7 @@ def download_images_zip(paths: list[str]) -> io.BytesIO: counter += 1 name = f"{stem}_{counter}{suffix}" used_names.add(name) - zf.write(path, name) + zf.writestr(name, payload) added += 1 if added == 0: raise HTTPException(status_code=404, detail="no images found") diff --git a/services/image_storage_service.py b/services/image_storage_service.py new file mode 100644 index 00000000..9c9c50f6 --- /dev/null +++ b/services/image_storage_service.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +import hashlib +import io +import json +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from urllib.parse import quote, urlparse + +from curl_cffi import requests +from fastapi import HTTPException +from PIL import Image + +from services.config import DATA_DIR, config + +IMAGE_INDEX_FILE = DATA_DIR / "image_index.json" +IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp"} + + +class ImageStorageError(RuntimeError): + pass + + +@dataclass(frozen=True) +class StoredImage: + rel: str + url: str + storage: str + size: int + + +def _clean(value: object) -> str: + return str(value or "").strip() + + +def _now_iso() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def _safe_relative_path(path: str) -> str: + value = str(path or "").strip().replace("\\", "/").lstrip("/") + if not value: + raise HTTPException(status_code=404, detail="image not found") + parts = Path(value).parts + if any(part in {"", ".", ".."} for part in parts): + raise HTTPException(status_code=404, detail="image not found") + return Path(*parts).as_posix() + + +def _image_dimensions(payload: bytes) -> tuple[int, int] | None: + try: + with Image.open(io.BytesIO(payload)) as image: + return image.size + except Exception: + return None + + +def _is_image_rel(path: str) -> bool: + try: + safe_rel = _safe_relative_path(path) + except HTTPException: + return False + return Path(safe_rel).suffix.lower() in IMAGE_EXTENSIONS + + +def _local_image_path(relative_path: str) -> Path: + rel = _safe_relative_path(relative_path) + root = config.images_dir.resolve() + path = (root / rel).resolve() + try: + path.relative_to(root) + except ValueError as exc: + raise HTTPException(status_code=404, detail="image not found") from exc + return path + + +def _read_json_object(path: Path) -> dict[str, object]: + if not path.exists(): + return {} + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _write_json_object(path: Path, data: dict[str, object]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + tmp_path.replace(path) + + +class WebDAVClient: + def __init__(self, settings: dict[str, object]): + self.url = _clean(settings.get("webdav_url")).rstrip("/") + self.username = _clean(settings.get("webdav_username")) + self.password = _clean(settings.get("webdav_password")) + self.root_path = _clean(settings.get("webdav_root_path")).strip("/") + self.session = requests.Session() + + def _auth_kwargs(self) -> dict[str, object]: + return {"auth": (self.username, self.password)} if self.username or self.password else {} + + def _request(self, method: str, url: str, **kwargs): + response = self.session.request(method, url, timeout=30, **self._auth_kwargs(), **kwargs) + if response.status_code >= 400 and not (method == "MKCOL" and response.status_code in {405}): + raise ImageStorageError(f"WebDAV {method} failed: HTTP {response.status_code}") + return response + + def remote_url(self, rel: str = "") -> str: + parts = [part for part in [self.root_path, _safe_relative_path(rel) if rel else ""] if part] + encoded = "/".join(quote(part, safe="") for item in parts for part in item.split("/") if part) + return f"{self.url}/{encoded}" if encoded else self.url + + def ensure_dirs(self, rel: str) -> None: + parts = [part for part in [self.root_path, Path(_safe_relative_path(rel)).parent.as_posix()] if part and part != "."] + current = self.url + for item in "/".join(parts).split("/"): + if not item: + continue + current = f"{current}/{quote(item, safe='')}" + response = self.session.request("MKCOL", current, timeout=30, **self._auth_kwargs()) + if response.status_code in {201, 405}: + continue + if response.status_code >= 400: + raise ImageStorageError(f"WebDAV MKCOL failed: HTTP {response.status_code}") + + def put(self, rel: str, payload: bytes, content_type: str = "image/png") -> str: + self.ensure_dirs(rel) + url = self.remote_url(rel) + self._request("PUT", url, data=payload, headers={"Content-Type": content_type}) + return url + + def get(self, rel: str) -> bytes: + response = self._request("GET", self.remote_url(rel)) + return bytes(response.content) + + def delete(self, rel: str) -> bool: + response = self.session.request("DELETE", self.remote_url(rel), timeout=30, **self._auth_kwargs()) + if response.status_code in {200, 202, 204, 404}: + return response.status_code != 404 + raise ImageStorageError(f"WebDAV DELETE failed: HTTP {response.status_code}") + + def test(self) -> dict[str, object]: + if not self.url: + return {"ok": False, "status": 0, "error": "WebDAV URL is required"} + if urlparse(self.url).scheme not in {"http", "https"}: + return {"ok": False, "status": 0, "error": "invalid WebDAV URL"} + test_rel = ".chatgpt2api_webdav_test.txt" + try: + self.put(test_rel, b"chatgpt2api webdav test\n", content_type="text/plain") + self.delete(test_rel) + return {"ok": True, "status": 200, "error": None} + except ImageStorageError as exc: + return {"ok": False, "status": 0, "error": str(exc)} + except Exception as exc: + return {"ok": False, "status": 0, "error": str(exc) or exc.__class__.__name__} + finally: + self.session.close() + + +class ImageStorageService: + def __init__(self, index_file: Path = IMAGE_INDEX_FILE): + self.index_file = index_file + + def settings(self) -> dict[str, object]: + return config.get_image_storage_settings() + + def mode(self) -> str: + return _clean(self.settings().get("mode")) or "local" + + def _load_index(self) -> dict[str, dict[str, object]]: + raw = _read_json_object(self.index_file) + items = raw.get("items") + if not isinstance(items, dict): + return {} + return {str(key): value for key, value in items.items() if isinstance(value, dict)} + + def _load_clean_index(self) -> dict[str, dict[str, object]]: + items = self._load_index() + cleaned = {rel: item for rel, item in items.items() if _is_image_rel(rel)} + if len(cleaned) != len(items): + self._save_index(cleaned) + return cleaned + + def _save_index(self, items: dict[str, dict[str, object]]) -> None: + _write_json_object(self.index_file, {"items": items}) + + def _public_url(self, rel: str, base_url: str | None = None) -> str: + settings = self.settings() + public_base_url = _clean(settings.get("public_base_url")) + if public_base_url: + return f"{public_base_url.rstrip('/')}/{_safe_relative_path(rel)}" + return f"{(base_url or config.base_url).rstrip('/')}/images/{_safe_relative_path(rel)}" + + def make_relative_path(self, image_data: bytes) -> str: + file_hash = hashlib.md5(image_data).hexdigest() + filename = f"{int(time.time())}_{file_hash}.png" + relative_dir = Path(time.strftime("%Y"), time.strftime("%m"), time.strftime("%d")) + return f"{relative_dir.as_posix()}/{filename}" + + def save(self, image_data: bytes, base_url: str | None = None) -> StoredImage: + config.cleanup_old_images() + rel = self.make_relative_path(image_data) + mode = self.mode() + if mode not in {"local", "webdav", "both"}: + mode = "local" + stored_local = False + stored_webdav = False + remote_url = "" + + if mode in {"local", "both"}: + path = _local_image_path(rel) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(image_data) + stored_local = True + + if mode in {"webdav", "both"}: + remote_url = WebDAVClient(self.settings()).put(rel, image_data) + stored_webdav = True + + dimensions = _image_dimensions(image_data) + item = { + "rel": rel, + "path": rel, + "name": Path(rel).name, + "date": "-".join(rel.split("/")[:3]), + "size": len(image_data), + "created_at": _now_iso(), + "storage": "both" if stored_local and stored_webdav else ("webdav" if stored_webdav else "local"), + "local": stored_local, + "webdav": stored_webdav, + "remote_url": remote_url, + } + if dimensions: + item["width"], item["height"] = dimensions + items = self._load_clean_index() + items[rel] = item + self._save_index(items) + return StoredImage(rel=rel, url=self._public_url(rel, base_url), storage=str(item["storage"]), size=len(image_data)) + + def get_bytes(self, rel: str) -> bytes: + safe_rel = _safe_relative_path(rel) + if not _is_image_rel(safe_rel): + raise HTTPException(status_code=404, detail="image not found") + path = _local_image_path(safe_rel) + if path.is_file(): + return path.read_bytes() + item = self._load_clean_index().get(safe_rel, {}) + if item.get("webdav"): + return WebDAVClient(self.settings()).get(safe_rel) + raise HTTPException(status_code=404, detail="image not found") + + def exists(self, rel: str) -> bool: + safe_rel = _safe_relative_path(rel) + if not _is_image_rel(safe_rel): + return False + if _local_image_path(safe_rel).is_file(): + return True + item = self._load_clean_index().get(safe_rel, {}) + return bool(item.get("webdav")) + + def has_local(self, rel: str) -> bool: + safe_rel = _safe_relative_path(rel) + return _is_image_rel(safe_rel) and _local_image_path(safe_rel).is_file() + + def list_items(self, base_url: str, start_date: str = "", end_date: str = "") -> list[dict[str, object]]: + indexed = self._load_clean_index() + root = config.images_dir + changed = False + for path in root.rglob("*"): + if not path.is_file() or not _is_image_rel(path.name): + continue + rel = path.relative_to(root).as_posix() + if rel in indexed: + continue + dimensions = None + try: + dimensions = _image_dimensions(path.read_bytes()) + except Exception: + dimensions = None + indexed[rel] = { + "rel": rel, + "path": rel, + "name": path.name, + "date": "-".join(rel.split("/")[:3]) if len(rel.split("/")) >= 4 else datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d"), + "size": path.stat().st_size, + "created_at": datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S"), + "storage": "local", + "local": True, + "webdav": False, + **({"width": dimensions[0], "height": dimensions[1]} if dimensions else {}), + } + changed = True + + items: list[dict[str, object]] = [] + for rel, item in list(indexed.items()): + if not _is_image_rel(rel): + indexed.pop(rel, None) + changed = True + continue + local = _local_image_path(rel).is_file() + webdav = bool(item.get("webdav")) + if not local and not webdav: + indexed.pop(rel, None) + changed = True + continue + storage = "both" if local and webdav else ("webdav" if webdav else "local") + if item.get("local") != local or item.get("storage") != storage: + item = { + **item, + "local": local, + "storage": storage, + } + indexed[rel] = item + changed = True + day = str(item.get("date") or "") + if start_date and day < start_date: + continue + if end_date and day > end_date: + continue + items.append({ + **item, + "rel": rel, + "path": rel, + "url": self._public_url(rel, base_url), + }) + if changed: + self._save_index(indexed) + items.sort(key=lambda item: str(item.get("created_at") or ""), reverse=True) + return items + + def delete(self, rel: str) -> bool: + safe_rel = _safe_relative_path(rel) + removed = False + path = _local_image_path(safe_rel) + if path.is_file(): + path.unlink() + removed = True + items = self._load_clean_index() + item = items.get(safe_rel, {}) + if item.get("webdav"): + try: + removed = WebDAVClient(self.settings()).delete(safe_rel) or removed + except ImageStorageError: + if not removed: + raise + if safe_rel in items: + items.pop(safe_rel, None) + self._save_index(items) + return removed + + def sync_all(self) -> dict[str, int]: + settings = self.settings() + if self.mode() not in {"webdav", "both"}: + raise ImageStorageError("WebDAV 图片存储未启用") + uploaded = 0 + skipped = 0 + failed = 0 + items = self._load_clean_index() + client = WebDAVClient(settings) + for path in sorted(config.images_dir.rglob("*")): + if not path.is_file() or not _is_image_rel(path.name): + continue + rel = path.relative_to(config.images_dir).as_posix() + item = items.get(rel, {}) + if item.get("webdav"): + skipped += 1 + continue + try: + payload = path.read_bytes() + remote_url = client.put(rel, payload) + dimensions = _image_dimensions(payload) + items[rel] = { + **item, + "rel": rel, + "path": rel, + "name": path.name, + "date": "-".join(rel.split("/")[:3]) if len(rel.split("/")) >= 4 else datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d"), + "size": len(payload), + "created_at": str(item.get("created_at") or datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")), + "storage": "both", + "local": True, + "webdav": True, + "remote_url": remote_url, + **({"width": dimensions[0], "height": dimensions[1]} if dimensions else {}), + } + uploaded += 1 + except Exception: + failed += 1 + self._save_index(items) + return {"uploaded": uploaded, "skipped": skipped, "failed": failed} + + def test_webdav(self) -> dict[str, object]: + return WebDAVClient(self.settings()).test() + + +image_storage_service = ImageStorageService() diff --git a/services/protocol/conversation.py b/services/protocol/conversation.py index 1cddea84..bc7a1500 100644 --- a/services/protocol/conversation.py +++ b/services/protocol/conversation.py @@ -1,18 +1,17 @@ from __future__ import annotations import base64 -import hashlib import json import re import time from dataclasses import dataclass, field -from pathlib import Path from typing import Any, Iterable, Iterator import tiktoken from services.account_service import account_service from services.config import config +from services.image_storage_service import image_storage_service from services.openai_backend_api import OpenAIBackendAPI from utils.helper import IMAGE_MODELS, extract_image_from_message_content from utils.log import logger @@ -67,14 +66,7 @@ def encode_images(images: Iterable[tuple[bytes, str, str]]) -> list[str]: def save_image_bytes(image_data: bytes, base_url: str | None = None) -> str: - config.cleanup_old_images() - file_hash = hashlib.md5(image_data).hexdigest() - filename = f"{int(time.time())}_{file_hash}.png" - relative_dir = Path(time.strftime("%Y"), time.strftime("%m"), time.strftime("%d")) - file_path = config.images_dir / relative_dir / filename - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_bytes(image_data) - return f"{(base_url or config.base_url)}/images/{relative_dir.as_posix()}/{filename}" + return image_storage_service.save(image_data, base_url).url def message_text(content: Any) -> str: diff --git a/test/test_image_storage_service.py b/test/test_image_storage_service.py new file mode 100644 index 00000000..ea926de9 --- /dev/null +++ b/test/test_image_storage_service.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from PIL import Image + +from services.image_storage_service import ImageStorageService + + +def png_bytes() -> bytes: + path = Path(tempfile.gettempdir()) / "chatgpt2api-test-image.png" + Image.new("RGB", (2, 2), color=(255, 0, 0)).save(path, format="PNG") + return path.read_bytes() + + +class FakeWebDAVClient: + uploaded: dict[str, bytes] = {} + deleted: list[str] = [] + + def __init__(self, _settings): + pass + + def put(self, rel: str, payload: bytes) -> str: + self.uploaded[rel] = payload + return f"https://dav.example.test/{rel}" + + def get(self, rel: str) -> bytes: + return self.uploaded[rel] + + def delete(self, rel: str) -> bool: + self.deleted.append(rel) + self.uploaded.pop(rel, None) + return True + + def test(self) -> dict[str, object]: + self.put(".chatgpt2api_webdav_test.txt", b"chatgpt2api webdav test\n") + self.delete(".chatgpt2api_webdav_test.txt") + return {"ok": True, "status": 200, "error": None} + + +class ImageStorageServiceTests(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.addCleanup(self.tmp.cleanup) + self.data_dir = Path(self.tmp.name) + self.images_dir = self.data_dir / "images" + self.settings = { + "enabled": False, + "mode": "local", + "webdav_url": "", + "webdav_username": "", + "webdav_password": "", + "webdav_root_path": "chatgpt2api/images", + "public_base_url": "", + } + self.config_patcher = mock.patch("services.image_storage_service.config") + self.mock_config = self.config_patcher.start() + self.addCleanup(self.config_patcher.stop) + self.mock_config.images_dir = self.images_dir + self.mock_config.base_url = "http://app.test" + self.mock_config.cleanup_old_images.return_value = 0 + self.mock_config.get_image_storage_settings.side_effect = lambda: dict(self.settings) + FakeWebDAVClient.uploaded = {} + FakeWebDAVClient.deleted = [] + + def service(self) -> ImageStorageService: + return ImageStorageService(self.data_dir / "image_index.json") + + def test_local_mode_saves_to_local_directory(self): + stored = self.service().save(png_bytes(), "http://app.test") + + self.assertEqual(stored.storage, "local") + self.assertTrue((self.images_dir / stored.rel).is_file()) + self.assertEqual(stored.url, f"http://app.test/images/{stored.rel}") + + def test_webdav_mode_uploads_without_local_file(self): + self.settings.update({ + "enabled": True, + "mode": "webdav", + "webdav_url": "https://dav.example.test", + "webdav_password": "secret", + }) + with mock.patch("services.image_storage_service.WebDAVClient", FakeWebDAVClient): + stored = self.service().save(png_bytes(), "http://app.test") + payload = self.service().get_bytes(stored.rel) + + self.assertEqual(stored.storage, "webdav") + self.assertFalse((self.images_dir / stored.rel).exists()) + self.assertIn(stored.rel, FakeWebDAVClient.uploaded) + self.assertEqual(payload, FakeWebDAVClient.uploaded[stored.rel]) + + def test_list_items_ignores_non_image_files(self): + image = png_bytes() + image_path = self.images_dir / "2026" / "05" / "07" / "sample.png" + image_path.parent.mkdir(parents=True, exist_ok=True) + image_path.write_bytes(image) + (self.images_dir / ".DS_Store").write_text("not an image", encoding="utf-8") + (self.images_dir / "2026" / ".DS_Store").write_text("not an image", encoding="utf-8") + + items = self.service().list_items("http://app.test") + + self.assertEqual([item["rel"] for item in items], ["2026/05/07/sample.png"]) + self.assertEqual(items[0]["storage"], "local") + + def test_both_mode_saves_to_local_and_webdav(self): + self.settings.update({ + "enabled": True, + "mode": "both", + "webdav_url": "https://dav.example.test", + "webdav_password": "secret", + "public_base_url": "https://cdn.example.test/images", + }) + with mock.patch("services.image_storage_service.WebDAVClient", FakeWebDAVClient): + stored = self.service().save(png_bytes(), "http://app.test") + + self.assertEqual(stored.storage, "both") + self.assertTrue((self.images_dir / stored.rel).is_file()) + self.assertIn(stored.rel, FakeWebDAVClient.uploaded) + self.assertEqual(stored.url, f"https://cdn.example.test/images/{stored.rel}") + + def test_test_webdav_writes_and_deletes_probe_file(self): + self.settings.update({ + "enabled": True, + "mode": "webdav", + "webdav_url": "https://dav.example.test", + "webdav_password": "secret", + }) + with mock.patch("services.image_storage_service.WebDAVClient", FakeWebDAVClient): + result = self.service().test_webdav() + + self.assertTrue(result["ok"]) + self.assertIn(".chatgpt2api_webdav_test.txt", FakeWebDAVClient.deleted) + + +if __name__ == "__main__": + unittest.main() diff --git a/web/src/app/image-manager/page.tsx b/web/src/app/image-manager/page.tsx index a8e94973..e3f51cf8 100644 --- a/web/src/app/image-manager/page.tsx +++ b/web/src/app/image-manager/page.tsx @@ -18,6 +18,16 @@ import { useAuthGuard } from "@/lib/use-auth-guard"; const LONG_PRESS_MS = 800; +function storageBadge(item: ManagedImage) { + if (item.local && item.webdav) { + return { label: "双端", className: "border-sky-200 bg-sky-50 text-sky-700" }; + } + if (item.webdav || item.storage === "webdav") { + return { label: "WebDAV", className: "border-violet-200 bg-violet-50 text-violet-700" }; + } + return { label: "本机", className: "border-stone-200 bg-stone-50 text-stone-600" }; +} + function formatSize(size: number) { return size > 1024 * 1024 ? `${(size / 1024 / 1024).toFixed(2)} MB` : `${Math.ceil(size / 1024)} KB`; } @@ -364,6 +374,7 @@ function ImageManagerContent() {