diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..21bfe17d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +* text=auto +*.yaml text eol=lf +*.yml text eol=lf +*.md text eol=lf diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index f7bdd844..753b6e9e 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -2,6 +2,8 @@ name: Publish Docker Image on: push: + branches: + - main tags: - "v*" workflow_dispatch: @@ -41,9 +43,10 @@ jobs: with: images: ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }} tags: | + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/v') }} type=ref,event=tag - type=sha + type=sha,prefix=sha- type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} diff --git a/9router-test b/9router-test new file mode 160000 index 00000000..7ad538bc --- /dev/null +++ b/9router-test @@ -0,0 +1 @@ +Subproject commit 7ad538bcf248a563e78f984bf64e66286058917f diff --git a/Dockerfile.lite b/Dockerfile.lite new file mode 100644 index 00000000..12802634 --- /dev/null +++ b/Dockerfile.lite @@ -0,0 +1,25 @@ +FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim + +LABEL org.opencontainers.image.description="chatgpt2api Lite - OpenAI-compatible proxy for ChatGPT Web API (no Web UI)" + +WORKDIR /app + +# Install system dependencies (only openssl for curl-cffi) +RUN apt-get update && apt-get install -y --no-install-recommends \ + openssl \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY pyproject.toml uv.lock ./ +RUN uv sync --frozen --no-dev --no-install-project + +# Copy application code +COPY main_lite.py ./main.py +COPY VERSION ./ +COPY api ./api +COPY services ./services +COPY utils ./utils + +EXPOSE 80 + +CMD ["uv", "run", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--access-log"] diff --git a/README.md 
b/README.md index ac59ba33..d764c0f3 100644 --- a/README.md +++ b/README.md @@ -1,320 +1,304 @@ -
ChatGPT2API 主要是对 ChatGPT 官网相关能力进行逆向整理与封装,提供面向 ChatGPT 图片生成、图片编辑、多图组图编辑场景的 OpenAI 兼容图片 API / 代理,并集成在线画图、号池管理、多种账号导入方式与 Docker 自托管部署能力。
+## Mục lục -> [!WARNING] -> 免责声明: -> -> 本项目涉及对 ChatGPT 官网文本生成、图片生成与图片编辑等相关接口的逆向研究,仅供个人学习、技术研究与非商业性技术交流使用。 -> -> - 严禁将本项目用于任何商业用途、盈利性使用、批量操作、自动化滥用或规模化调用。 -> - 严禁将本项目用于破坏市场秩序、恶意竞争、套利倒卖、二次售卖相关服务,以及任何违反 OpenAI 服务条款或当地法律法规的行为。 -> - 严禁将本项目用于生成、传播或协助生成违法、暴力、色情、未成年人相关内容,或用于诈骗、欺诈、骚扰等非法或不当用途。 -> - 使用者应自行承担全部风险,包括但不限于账号被限制、临时封禁或永久封禁以及因违规使用等所导致的法律责任。 -> - 使用本项目即视为你已充分理解并同意本免责声明全部内容;如因滥用、违规或违法使用造成任何后果,均由使用者自行承担。 +- [Cài đặt qua Home Assistant Addon](#cai-dat-qua-home-assistant-addon) +- [Cài đặt qua Docker](#cai-dat-qua-docker) +- [Cài đặt qua Docker Compose / Portainer](#cai-dat-qua-docker-compose--portainer) +- [Cài đặt trực tiếp (source)](#cai-dat-truc-tiep-source) +- [Cấu hình Home Assistant](#cau-hinh-home-assistant) +- [Thêm tài khoản](#them-tai-khoan) +- [Model](#model) +- [Tìm kiếm (Search)](#tim-kiem-search) +- [API Endpoints](#api-endpoints) +- [Troubleshooting](#troubleshooting) -> [!IMPORTANT] -> 本项目基于对 ChatGPT 官网相关能力的逆向研究实现,存在账号受限、临时封禁或永久封禁的风险。请勿使用你自己的重要账号、常用账号或高价值账号进行测试。 +--- -> [!CAUTION] -> 旧版本存在已知漏洞,请尽快升级到最新版本。公网部署时请尽量不要放置敏感信息,并自行做好访问控制与隔离。 +## Cài đặt qua Home Assistant Addon -## 快速开始 +[](https://my.home-assistant.io/redirect/supervisor_add_addon_repository/?repository_url=https%3A%2F%2Fgithub.com%2FTriTue2011%2Fhas-addons) -已发布镜像支持 `linux/amd64` 与 `linux/arm64`,在 x86 服务器和 Apple Silicon / ARM Linux 设备上都会自动拉取匹配架构的版本。 +**Bước 1:** Vào **Settings → Add-ons → Add-on Store** -### Docker 运行 +**Bước 2:** Nhấn **⋮** (góc phải trên) → **Repositories** -```bash -git clone git@github.com:basketikun/chatgpt2api.git -cd chatgpt2api -docker compose up -d -``` +**Bước 3:** Thêm URL: `https://github.com/TriTue2011/has-addons` → **Add** -启动前请先在 `config.json` 中设置 `auth-key`,也可以在 `docker-compose.yml` 中通过 `CHATGPT2API_AUTH_KEY` 覆盖。 +**Bước 4:** Tìm **chatgpt2api** trong store → **Install** -- Web 面板:`http://localhost:3000` -- API 地址:`http://localhost:3000/v1` -- 数据目录:`./data` +**Bước 5:** Tab **Configuration** → sửa `auth_key` (mặc định: `sk-chatgpt2api`) → **Save** -### 本地开发 
+**Bước 6:** **Start** → mở Web UI tại `http://HA_IP:3030` -启动后端: +**Bước 7:** Đăng nhập bằng `auth_key` đã đặt -```bash -git clone git@github.com:basketikun/chatgpt2api.git -cd chatgpt2api -uv sync -uv run main.py -``` +Sau đó vào [Cấu hình Home Assistant](#cau-hinh-home-assistant). + +--- -启动前端: +## Cài đặt qua Docker ```bash -cd chatgpt2api/web -bun install -bun run dev +docker run -d \ + --name chatgpt2api \ + --restart unless-stopped \ + -p 3030:80 \ + -v chatgpt2api_data:/app/data \ + -e CHATGPT2API_AUTH_KEY=your_secret_key_here \ + ghcr.io/tritue2011/chatgpt2api:latest ``` -### 存储后端配置 +Sau khi chạy: +- Web UI: `http://IP:3030` +- API: `http://IP:3030/v1/chat/completions` +- Đăng nhập Web UI bằng `your_secret_key_here` -支持通过环境变量 `STORAGE_BACKEND` 切换存储方式: +> **Quan trọng**: Volume `chatgpt2api_data` lưu TOÀN BỘ dữ liệu: accounts, API keys (Gemini, NVIDIA, DeepSeek...), custom providers, model settings, combos, ảnh, backup. **Không được xóa volume này** nếu không muốn mất hết cài đặt. -- `json` - 本地 JSON 文件(默认) -- `sqlite` - 本地 SQLite 数据库 -- `postgres` - 外部 PostgreSQL(需配置 `DATABASE_URL`) -- `git` - Git 私有仓库(需配置 `GIT_REPO_URL` 和 `GIT_TOKEN`) +--- -示例:使用 PostgreSQL +## Cài đặt qua Docker Compose / Portainer ```yaml -environment: - - STORAGE_BACKEND=postgres - - DATABASE_URL=postgresql://user:password@host:5432/dbname +services: + chatgpt2api: + image: ghcr.io/tritue2011/chatgpt2api:latest + container_name: chatgpt2api + restart: unless-stopped + ports: + - "3030:80" + volumes: + # [QUAN TRỌNG] Bind mount → thư mục thật trên host, không bao giờ mất khi build lại + - ./chatgpt2api-data:/app/data + environment: + # [BẮT BUỘC] Đổi thành key bảo mật của bạn + CHATGPT2API_AUTH_KEY: your_secret_key_here + STORAGE_BACKEND: json + +# Không cần khai báo volumes ở cuối nếu dùng bind mount ``` -## 功能 +**Portainer:** Stacks → Add Stack → Web Editor → paste nội dung trên → Deploy. 
+ +### Cập nhật lên phiên bản mới (giữ nguyên dữ liệu) + +```bash +docker compose pull +docker compose up -d +``` + +Dữ liệu trong `./chatgpt2api-data/` (thư mục thật trên host) **không bao giờ mất** khi pull image mới. + +--- + +## Cài đặt trực tiếp (source) -### API 兼容能力 +Yêu cầu: Python 3.12+, Node.js 20+, Git -- 兼容 `POST /v1/images/generations` 图片生成接口 -- 兼容 `POST /v1/images/edits` 图片编辑接口 -- 兼容面向图片场景的 `POST /v1/chat/completions` -- 兼容面向图片场景的 `POST /v1/responses` -- `GET /v1/models` 返回 `gpt-image-2`、`codex-gpt-image-2`、`auto`、`gpt-5`、`gpt-5-1`、`gpt-5-2`、`gpt-5-3`、`gpt-5-3-mini`、 - `gpt-5-mini` -- 支持通过 `n` 返回多张生成结果 -- 支持 Codex 中的画图接口逆向,仅 `Plus` / `Team` / `Pro` 订阅可用,模型别名为 `codex-gpt-image-2`,如有需要可自行在其他场景映射回 - `gpt-image-2`,用于和官网画图区分;也就意味着同一账号会同时有官网和 Codex 两份生图额度 +```bash +git clone https://github.com/TriTue2011/chatgpt2api +cd chatgpt2api -### 在线画图功能 +# Cài Python dependencies +pip install uv +uv sync -- 内置在线画图工作台,支持生成、图片编辑与多图组图编辑 -- 支持 `gpt-image-2`、`codex-gpt-image-2`、`auto`、`gpt-5`、`gpt-5-1`、`gpt-5-2`、`gpt-5-3`、`gpt-5-3-mini`、`gpt-5-mini` 模型选择 -- 编辑模式支持参考图上传 -- 前端支持多图生成交互 -- 本地保存图片会话历史,支持回看、删除和清空 -- 支持服务端缓存图片URL +# Build web UI +cd web && npm install && npm run build && cd .. -### 号池管理功能 +# Chạy +cp .env.example .env +# Sửa CHATGPT2API_AUTH_KEY trong .env +uv run uvicorn main:app --host 0.0.0.0 --port 3030 +``` -- 自动刷新账号邮箱、类型、额度和恢复时间 -- 轮询可用账号执行图片生成与图片编辑 -- 遇到 Token 失效类错误时自动剔除无效 Token -- 定时检查限流账号并自动刷新 -- 支持网页端配置全局 HTTP / HTTPS / SOCKS5 / SOCKS5H 代理 -- 支持搜索、筛选、批量刷新、导出、手动编辑和清理账号 -- 支持四种导入方式:本地 CPA JSON 文件导入、远程 CPA 服务器导入、`sub2api` 服务器导入、`access_token` 导入 -- 支持在设置页配置 `sub2api` 服务器,筛选并批量导入其中的 OpenAI OAuth 账号 +--- -### 实验性 / 规划中 +## Cấu hình Home Assistant -- `/v1/complete` 文本补全与流式输出已实现,但仍在测试,目前会出现对话重复的问题,请谨慎测试使用 -- 详细状态说明见:[功能清单](./docs/feature-status.en.md) +Sau khi chatgpt2api đã chạy, cấu hình HA để dùng nó làm conversation agent. 
-## Screenshots +### Dùng OpenAI Conversation (có sẵn trong HA) -文生图界面: +**Settings → Devices & Services → Add Integration → OpenAI Conversation:** - +| Field | Value | +|-------|-------| +| Base URL | `http://localhost:3030/v1` | +| API Key | Key đã đặt ở bước trên | +| Model | `ha-agent` | -编辑图: +### Dùng với hass_local_openai_llm - +```yaml +# configuration.yaml +openai_llm: + - name: chatgpt2api + base_url: http://localhost:3030 + api_key: your_secret_key_here + model: ha-agent +``` -Cherry Studio 中使用,支持作为绘图接口接入: +### Voice Pipeline - +**Settings → Voice Assistants →** chọn pipeline → **Conversation Agent** → chọn agent đã config. -号池管理: +--- - +## Thêm tài khoản -New Api 接入: +### Token ChatGPT Web (chat + tạo ảnh) - +1. Mở browser, đăng nhập https://chatgpt.com +2. Vào https://chatgpt.com/api/auth/session +3. Copy giá trị `accessToken` +4. Web UI → **Tài khoản → Nhập tài khoản → Nhập Access Token** → paste +5. Token JWT (bắt đầu `eyJ`) tự động dùng được cho cả chat (`cx/auto`) và tạo ảnh (`gpt-image-2`) -## API +### Token Codex OAuth từ 9router (chat không giới hạn) -所有 AI 接口都需要请求头: +1. Web UI → **Sao lưu** → kéo thả file backup `.json` từ 9router +2. 10 token tự động thêm vào pool +3. Dùng model `cx/auto` — không giới hạn 24KB, native tool calling -```http -Authorization: BearerGET /v1/modelsPOST /v1/images/generationsPOST /v1/images/editsPOST /v1/chat/completionsPOST /v1/responsesBạn có thể đóng tab này.
+ + + """) + except Exception as exc: + return HTMLResponse(content=f""" + +Copy URL này và dùng API exchange thủ công.
+ + """, status_code=400) + + @router.get("/api/oauth/codex/callback") + async def codex_oauth_callback(code: str = "", state: str = ""): + """Handle Codex OAuth callback — exchange code for token.""" + if not code or not state: + raise HTTPException(status_code=400, detail={"error": "Missing code or state"}) + try: + result = exchange_codex_code(code, state) + return HTMLResponse(content=f""" + +Bạn có thể đóng tab này.
+ + + """) + except Exception as exc: + raise HTTPException(status_code=400, detail={"error": str(exc)}) + + class CodexExchangeRequest(BaseModel): + redirect_url: str = "" + + @router.post("/api/oauth/codex/exchange") + async def codex_oauth_exchange(body: CodexExchangeRequest, authorization: str | None = Header(default=None)): + """Exchange Codex OAuth code manually — user pastes redirect URL.""" + require_admin(authorization) + from urllib.parse import urlparse, parse_qs + url = (body.redirect_url or "").strip() + if not url: + raise HTTPException(status_code=400, detail={"error": "redirect_url is required"}) + # Replace localhost with actual host if needed + parsed = urlparse(url) + params = parse_qs(parsed.query) + code = (params.get("code") or [""])[0] + state = (params.get("state") or [""])[0] + if not code or not state: + raise HTTPException(status_code=400, detail={"error": "URL không chứa code và state. Copy TOÀN Bộ URL sau khi redirect."}) + try: + result = exchange_codex_code(code, state) + return result + except Exception as exc: + raise HTTPException(status_code=400, detail={"error": str(exc)}) + + @router.get("/api/oauth/session-url") + async def get_session_url(authorization: str | None = Header(default=None)): + """Return chatgpt.com session URL for getting image token.""" + require_admin(authorization) + return {"url": get_chatgpt_session_url()} + + @router.post("/api/oauth/detect-token") + async def detect_token(body: dict, authorization: str | None = Header(default=None)): + """Detect token type (codex vs google).""" + require_admin(authorization) + token = str(body.get("token") or "").strip() + if not token: + raise HTTPException(status_code=400, detail={"error": "token is required"}) + return {"type": detect_token_type(token)} + + # ── Custom Providers ── + + @router.get("/api/v1/custom-providers") + async def list_custom_providers(authorization: str | None = Header(default=None)): + """Lấy danh sách custom providers.""" + 
require_admin(authorization) + from services.providers.custom_openai import get_custom_providers + return {"custom_providers": get_custom_providers()} + + @router.post("/api/v1/custom-providers") + async def save_custom_provider(body: dict, authorization: str | None = Header(default=None)): + """Thêm hoặc cập nhật một custom provider.""" + require_admin(authorization) + provider = body.get("provider") or {} + if not isinstance(provider, dict): + raise HTTPException(status_code=400, detail={"error": "provider object is required"}) + + provider_id = str(provider.get("prefix") or provider.get("name") or "").strip().lower().replace(" ", "_") + if not provider_id: + raise HTTPException(status_code=400, detail={"error": "provider prefix or name is required"}) + + base_url = str(provider.get("base_url") or "").strip().rstrip("/") + api_key = str(provider.get("api_key") or "").strip() + api_keys = provider.get("api_keys") or [] + if not isinstance(api_keys, list): + api_keys = [] + api_keys = [k.strip() for k in api_keys if k.strip()] + # Support api_key + api_keys combination + if api_key and api_key not in api_keys: + api_keys.insert(0, api_key) + name = str(provider.get("name") or provider_id).strip() + enabled = provider.get("enabled", True) + prefix = str(provider.get("prefix") or provider_id).strip().lower().replace(" ", "_") + + # Validate: test connection with first key + test_key = api_keys[0] if api_keys else api_key + if not test_key: + raise HTTPException(status_code=400, detail={"error": "At least one API key is required"}) + try: + from curl_cffi import requests as cffi_req + resp = cffi_req.get( + f"{base_url}/v1/models", + headers={"Authorization": f"Bearer {test_key}"}, + timeout=10, + ) + if resp.status_code >= 400: + raise HTTPException( + status_code=400, + detail={"error": f"Cannot connect to {base_url}: HTTP {resp.status_code}"}, + ) + except Exception as exc: + if not isinstance(exc, HTTPException): + raise HTTPException( + status_code=400, + 
detail={"error": f"Cannot connect to {base_url}: {exc}"}, + ) + raise + + # Save to config + custom_providers = dict(config.data.get("custom_providers") or {}) + if not isinstance(custom_providers, dict): + custom_providers = {} + custom_providers[provider_id] = { + "name": name, + "base_url": base_url, + "api_key": api_keys[0] if api_keys else "", + "api_keys": api_keys, + "prefix": prefix, + "enabled": enabled, + } + config.data["custom_providers"] = custom_providers + config._save() + from services.protocol.openai_v1_models import invalidate_models_cache + invalidate_models_cache() + + return { + "custom_providers": custom_providers, + "saved": True, + "provider_id": provider_id, + } + + @router.delete("/api/v1/custom-providers/{provider_id}") + async def delete_custom_provider(provider_id: str, authorization: str | None = Header(default=None)): + """Xóa một custom provider.""" + require_admin(authorization) + custom_providers = dict(config.data.get("custom_providers") or {}) + if not isinstance(custom_providers, dict): + custom_providers = {} + if provider_id in custom_providers: + del custom_providers[provider_id] + config.data["custom_providers"] = custom_providers + config._save() + from services.protocol.openai_v1_models import invalidate_models_cache + invalidate_models_cache() + return {"deleted": True, "provider_id": provider_id} + raise HTTPException(status_code=404, detail={"error": f"provider '{provider_id}' not found"}) + + # ── Model Settings ── + + @router.get("/api/v1/model-settings") + async def get_model_settings(authorization: str | None = Header(default=None)): + """Lấy cấu hình model (enabled models + defaults per provider).""" + require_admin(authorization) + ms = config.data.get("model_settings") or {} + if not isinstance(ms, dict): + ms = {} + return { + "model_settings": { + "enabled_models": ms.get("enabled_models") or {}, + "default_models": ms.get("default_models") or {}, + } + } + + @router.get("/api/v1/available-models") + async def 
get_available_models(authorization: str | None = Header(default=None), refresh: str = ""): + """Lấy toàn bộ model có sẵn từ cache (nhanh). Thêm ?refresh=true để tải lại.""" + require_admin(authorization) + from services.protocol.openai_v1_models import list_models + force = refresh.lower() == "true" + result = list_models(force_refresh=force) + # Group models by owned_by + grouped: dict[str, list[str]] = {} + for item in result.get("data", []): + owner = str(item.get("owned_by") or "chatgpt") + mid = str(item.get("id") or "").strip() + if mid: + if owner not in grouped: + grouped[owner] = [] + grouped[owner].append(mid) + # Sort each group + for owner in grouped: + grouped[owner].sort() + return {"providers": grouped} + + @router.get("/api/v1/models-with-capabilities") + async def get_models_with_capabilities(authorization: str | None = Header(default=None)): + """Lấy danh sách model kèm phân loại capability (chat/vision/image).""" + require_admin(authorization) + from utils.helper import classify_model_capability, get_model_capability_label + + # Get enabled models from model_settings + ms = config.data.get("model_settings") or {} + if not isinstance(ms, dict): + ms = {} + enabled_by_provider = ms.get("enabled_models") or {} + if not isinstance(enabled_by_provider, dict): + enabled_by_provider = {} + + # Fetch all available models + from services.protocol.openai_v1_models import list_models + result = list_models() + all_models = result.get("data", []) + + enriched: list[dict] = [] + for model in all_models: + mid = str(model.get("id") or "").strip() + if not mid: + continue + caps = classify_model_capability(mid) + enriched.append({ + "id": mid, + "owned_by": str(model.get("owned_by") or ""), + "capability": caps[0] if caps else "chat", # Primary capability + "capabilities": caps, # All capabilities + "capability_labels": [get_model_capability_label(c) for c in caps], + "enabled": _is_model_enabled(mid, enabled_by_provider), + }) + + # Sort: chat first, then 
vision, then image, then video + def _sort_key(m): + caps = m.get("capabilities", ["chat"]) + if "video" in caps: return 3 + if "image" in caps: return 2 + if "vision" in caps: return 1 + return 0 + enriched.sort(key=lambda m: (_sort_key(m), m["id"])) + + return { + "models": enriched, + "counts": { + "chat": sum(1 for m in enriched if "chat" in (m.get("capabilities") or ["chat"])), + "vision": sum(1 for m in enriched if "vision" in (m.get("capabilities") or [])), + "image": sum(1 for m in enriched if "image" in (m.get("capabilities") or [])), + "video": sum(1 for m in enriched if "video" in (m.get("capabilities") or [])), + }, + } + + @router.post("/api/v1/model-settings") + async def save_model_settings(body: dict, authorization: str | None = Header(default=None)): + """Lưu cấu hình model.""" + require_admin(authorization) + model_settings = body.get("model_settings") + if not isinstance(model_settings, dict): + raise HTTPException(status_code=400, detail={"error": "model_settings is required"}) + enabled = model_settings.get("enabled_models") + defaults = model_settings.get("default_models") + if not isinstance(enabled, dict): + enabled = {} + if not isinstance(defaults, dict): + defaults = {} + config.data["model_settings"] = { + "enabled_models": enabled, + "default_models": defaults, + } + config._save() + from services.protocol.openai_v1_models import invalidate_models_cache + invalidate_models_cache() + return {"model_settings": config.data["model_settings"], "saved": True} + return router diff --git a/api/veo_video.py b/api/veo_video.py new file mode 100644 index 00000000..b8af8333 --- /dev/null +++ b/api/veo_video.py @@ -0,0 +1,80 @@ +""" +Veo Video Generation endpoint — OpenAI-compatible /v1/video/generations. 
+""" + +from __future__ import annotations + +import json +from typing import Any, Iterator + +from fastapi import Header, HTTPException +from pydantic import BaseModel, ConfigDict + +from services.config import config +from services.image_providers.veo_video import veo_adapter +from utils.log import logger + + +class VideoGenerationRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str = "veo/veo-3.1-generate-preview" + prompt: str + n: int = 1 + aspect_ratio: str = "16:9" + duration: str | None = None + resolution: str | None = None + image: str | None = None # base64 image for image→video + last_frame: str | None = None + + +async def handle_video_generation( + body: dict[str, Any], + authorization: str | None = None, +) -> dict[str, Any] | Iterator[dict[str, Any]]: + """Handle POST /v1/video/generations.""" + prompt = str(body.get("prompt") or "") + if not prompt: + raise HTTPException(status_code=400, detail={"error": "prompt is required"}) + + n = max(1, min(1, int(body.get("n") or 1))) # Veo only supports 1 per request + aspect_ratio = str(body.get("aspect_ratio") or "16:9") + duration = body.get("duration") + resolution = body.get("resolution") + image = body.get("image") + last_frame = body.get("last_frame") + + # Get credentials from gemini_free config + providers_cfg = config.data.get("providers") or {} + provider_config = providers_cfg.get("gemini_free") or {} + + credentials = { + "apiKey": str(provider_config.get("api_key") or ""), + "apiKeys": provider_config.get("api_keys") or [], + } + + all_data = [] + for idx in range(n): + try: + result = veo_adapter.generate( + body={ + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "duration": duration, + "resolution": resolution, + "image": image, + "last_frame": last_frame, + }, + credentials=credentials, + ) + all_data.extend(result.get("data") or []) + except Exception as exc: + logger.error({"event": "veo_generation_error", "error": str(exc)}) + raise HTTPException( + 
status_code=500, + detail={"error": f"Video generation failed: {exc}"}, + ) from exc + + return { + "created": result.get("created", 0) if all_data else 0, + "data": all_data, + } diff --git a/config.json b/config.json index 60536b0f..ee77320d 100644 --- a/config.json +++ b/config.json @@ -46,5 +46,42 @@ "images": false } }, - "image_account_concurrency": 3 + "image_account_concurrency": 3, + "backends": { + "chat": ["chatgpt", "opencode", "gemini_free", "openrouter"], + "image": ["chatgpt", "sdwebui", "huggingface", "cloudflare"], + "default_chat": "auto", + "default_image": "1792x1024" + }, + "providers": { + "opencode": {"enabled": true, "noAuth": true}, + "gemini_free": {"enabled": false, "api_key": ""}, + "openrouter": {"enabled": false, "api_key": ""}, + "sdwebui": {"enabled": false, "base_url": "http://localhost:7860"}, + "huggingface": {"enabled": false, "api_key": ""}, + "cloudflare_ai": {"enabled": false, "account_id": "", "api_token": ""}, + "serper": {"enabled": false, "api_key": ""}, + "searxng": {"enabled": false, "base_url": "http://localhost:8080"}, + "brave": {"enabled": false, "api_key": ""} + }, + "rate_limit": { + "backoff_base_ms": 2000, + "backoff_max_ms": 300000, + "max_levels": 15 + }, + "combo_models": { + "ha-agent": ["oc/auto", "cx/auto", "chatgpt/auto"], + "ha-agent-image": ["sdwebui/stable-diffusion", "chatgpt/gpt-image-2"] + }, + "search": { + "enabled": true, + "backend": "chatgpt", + "auto_detect": true, + "max_results": 3, + "inject_as": "user_message" + }, + "ninerouter": { + "base_url": "http://localhost:20128", + "api_key": "" + } } diff --git a/docker-compose.lite.yml b/docker-compose.lite.yml new file mode 100644 index 00000000..2816fc8d --- /dev/null +++ b/docker-compose.lite.yml @@ -0,0 +1,32 @@ +services: + app: + build: + context: . 
+ dockerfile: Dockerfile.lite + container_name: chatgpt2api-lite + restart: unless-stopped + ports: + - "3030:80" + volumes: + - ./data:/app/data + environment: + # === REQUIRED: Authentication key === + # Used by hass_local_openai_llm to authenticate API calls + - CHATGPT2API_AUTH_KEY=sk-your-secret-key + + # === REQUIRED: ChatGPT tokens (at least one) === + # Multiple accounts supported via numbered env vars + - CHATGPT_TOKEN_1=your_chatgpt_token_here + # - CHATGPT_TOKEN_2=another_token + # - CHATGPT_TOKEN_3=third_token + + # === OPTIONAL: Storage backend === + # Supported: json (default), sqlite, postgres, git + - STORAGE_BACKEND=json + + # === OPTIONAL: Database URL (for sqlite/postgres) === + # - DATABASE_URL=sqlite:///app/data/accounts.db + # - DATABASE_URL=postgresql://user:password@host:5432/dbname + + # === OPTIONAL: Global system prompt === + # - GLOBAL_SYSTEM_PROMPT=You are a helpful assistant. diff --git a/docker-compose.local.yml b/docker-compose.local.yml index 28890190..dbe11e4a 100644 --- a/docker-compose.local.yml +++ b/docker-compose.local.yml @@ -9,7 +9,6 @@ services: - "8000:80" volumes: - ./data:/app/data - - ./config.json:/app/config.json environment: STORAGE_BACKEND: sqlite DATABASE_URL: sqlite:////app/data/accounts.db diff --git a/docker-compose.yml b/docker-compose.yml index a80dc643..6adf0425 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,29 +1,57 @@ +# ============================================================ +# chatgpt2api — Docker Compose cho Portainer +# Tích hợp 9router-chat: OpenCode free, image adapters, search, backup +# ============================================================ +# +# Cách dùng trong Portainer: +# 1. Tạo Stack mới → Paste nội dung file này +# 2. Sửa CHATGPT2API_AUTH_KEY thành key của bạn +# 3. 
Deploy +# +# Hoặc dùng CLI: +# docker compose up -d +# ============================================================ + services: - app: - image: ghcr.io/basketikun/chatgpt2api:latest + chatgpt2api: + # ── Image từ GitHub Container Registry (khuyên dùng) ── + image: ghcr.io/tritue2011/chatgpt2api:latest + + # ── Hoặc build từ source ── + # build: + # context: . + # dockerfile: Dockerfile + container_name: chatgpt2api restart: unless-stopped + ports: - "3000:80" + volumes: - - ./data:/app/data - - ./config.json:/app/config.json + # Dữ liệu persistent: accounts, config, models cache, ảnh, backup, logs + # Dùng bind mount đến thư mục thật trên host → không bao giờ mất khi build lại + - ./chatgpt2api-data:/app/data + environment: - # 存储后端配置 (可选值: json, sqlite, postgres, git) - - STORAGE_BACKEND=json - - # 数据库配置 (当 STORAGE_BACKEND=sqlite/postgres 时使用) - # - DATABASE_URL=postgresql://user:password@host:5432/dbname - # - DATABASE_URL=sqlite:///app/data/accounts.db - - # Git 仓库配置 (当 STORAGE_BACKEND=git 时使用) - # - GIT_REPO_URL=https://github.com/user/repo.git - # - GIT_TOKEN=ghp_xxxxxxxxxxxx - # - GIT_BRANCH=main - # - GIT_FILE_PATH=accounts.json - - # 认证密钥 (可选,覆盖 config.json) - # - CHATGPT2API_AUTH_KEY=your_secret_key - - # 基础 URL (可选) - # - CHATGPT2API_BASE_URL=https://your-domain.com + # [BẮT BUỘC] Auth key bảo vệ API + CHATGPT2API_AUTH_KEY: your_secret_key_here + + # [Tùy chọn] Public base URL cho link ảnh + # CHATGPT2API_BASE_URL: https://your-domain.com + + # [Tùy chọn] Storage backend: json (mặc định), sqlite, postgres, git + STORAGE_BACKEND: json + + # [Tùy chọn] Database URL nếu dùng sqlite/postgres + # DATABASE_URL: sqlite:///app/data/accounts.db + + # Healthcheck — kiểm tra API có sống không + healthcheck: + test: ["CMD", "python3", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:80/version')"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 15s + +# Không cần khai báo volumes ở cuối nữa — dùng bind mount ./chatgpt2api-data diff 
--git a/main_lite.py b/main_lite.py new file mode 100644 index 00000000..64b447d0 --- /dev/null +++ b/main_lite.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import uvicorn +from api.app_lite import create_app + +app = create_app() + +if __name__ == "__main__": + uvicorn.run(app, access_log=False, log_level="info") diff --git a/services/account_service.py b/services/account_service.py index e273bec4..79f64620 100644 --- a/services/account_service.py +++ b/services/account_service.py @@ -10,8 +10,30 @@ LOG_TYPE_ACCOUNT, log_service, ) +from services.rate_limit_backoff import rate_limit_backoff from services.storage.base import StorageBackend from utils.helper import anonymize_token +from utils.log import logger + +# Status migration: Chinese → English (backward compatible) +_STATUS_MIGRATION = { + "正常": "active", + "限流": "limited", + "异常": "error", + "禁用": "disabled", +} +_STATUS_REVERSE = {v: k for k, v in _STATUS_MIGRATION.items()} + +DISPLAY_STATUS = { + "active": "Hoạt động", + "limited": "Giới hạn", + "error": "Lỗi", + "disabled": "Vô hiệu", +} + + +# NoAuth providers — virtual connections (port from 9router FREE_PROVIDERS) +NO_AUTH_PROVIDERS = {"opencode"} class AccountService: @@ -40,7 +62,7 @@ def _save_accounts(self) -> None: def _is_image_account_available(account: dict) -> bool: if not isinstance(account, dict): return False - if account.get("status") in {"禁用", "限流", "异常"}: + if account.get("status") in {"disabled", "limited", "error"}: return False if bool(account.get("image_quota_unknown")): return True @@ -55,7 +77,9 @@ def _normalize_account(self, item: dict) -> dict | None: normalized = dict(item) normalized["access_token"] = access_token normalized["type"] = normalized.get("type") or "free" - normalized["status"] = normalized.get("status") or "正常" + # Auto-migrate Chinese status to English + raw_status = normalized.get("status") or "active" + normalized["status"] = _STATUS_MIGRATION.get(raw_status, raw_status) normalized["quota"] = 
max(0, int(normalized.get("quota") if normalized.get("quota") is not None else 0)) normalized["image_quota_unknown"] = bool(normalized.get("image_quota_unknown")) normalized["email"] = normalized.get("email") or None @@ -135,7 +159,7 @@ def get_text_access_token(self, excluded_tokens: set[str] | None = None) -> str: candidates = [ token for account in self._accounts.values() - if account.get("status") not in {"禁用", "异常"} + if account.get("status") not in {"disabled", "error"} and (token := account.get("access_token") or "") and token not in excluded ] @@ -162,14 +186,14 @@ def mark_text_used(self, access_token: str) -> None: def remove_invalid_token(self, access_token: str, event: str) -> bool: if not config.auto_remove_invalid_accounts: - self.update_account(access_token, {"status": "异常", "quota": 0}) + self.update_account(access_token, {"status": "error", "quota": 0}) return False removed = bool(self.delete_accounts([access_token])["removed"]) if removed: - log_service.add(LOG_TYPE_ACCOUNT, "自动移除异常账号", + log_service.add(LOG_TYPE_ACCOUNT, "Tự động xóa tài khoản lỗi", {"source": event, "token": anonymize_token(access_token)}) elif access_token: - self.update_account(access_token, {"status": "异常", "quota": 0}) + self.update_account(access_token, {"status": "error", "quota": 0}) return removed def get_account(self, access_token: str) -> dict | None: @@ -188,7 +212,7 @@ def list_limited_tokens(self) -> list[str]: return [ token for item in self._accounts.values() - if item.get("status") == "限流" + if item.get("status") == "limited" and (token := item.get("access_token") or "") ] @@ -218,10 +242,48 @@ def add_accounts(self, tokens: list[str]) -> dict: self._accounts[access_token] = account self._save_accounts() items = [dict(item) for item in self._accounts.values()] - log_service.add(LOG_TYPE_ACCOUNT, f"新增 {added} 个账号,跳过 {skipped} 个", + log_service.add(LOG_TYPE_ACCOUNT, f"新增 {added} tài khoản,跳过 {skipped} 个", {"added": added, "skipped": skipped}) return {"added": 
added, "skipped": skipped, "items": items} + def add_accounts_with_type(self, tokens: list[str], account_type: str = "codex") -> dict: + """Add accounts with a specific type (e.g. 'codex' for 9router OAuth tokens).""" + tokens = list(dict.fromkeys(token for token in tokens if token)) + if not tokens: + return {"added": 0, "skipped": 0, "items": self.list_accounts()} + + with self._lock: + added = 0 + skipped = 0 + updated = 0 + for access_token in tokens: + current = self._accounts.get(access_token) + if current is not None: + # Merge type: add new type to existing (e.g. existing "free" + new "codex" → "free,codex") + existing_types = set(str(current.get("type") or "").split(",")) + new_types = set(str(account_type).split(",")) + merged = ",".join(sorted(existing_types | new_types)) + if merged != str(current.get("type") or ""): + current["type"] = merged + updated += 1 + logger.info({"event": "account_type_merged", "token": anonymize_token(access_token), "new_type": merged}) + else: + skipped += 1 + continue + added += 1 + account = self._normalize_account({ + "access_token": access_token, + "type": account_type, + "status": "active", + }) + if account is not None: + self._accounts[access_token] = account + self._save_accounts() + items = [dict(item) for item in self._accounts.values()] + log_service.add(LOG_TYPE_ACCOUNT, f"Thêm {added} tài khoản {account_type}, cập nhật {updated}, bỏ qua {skipped}", + {"added": added, "skipped": skipped, "updated": updated, "type": account_type}) + return {"added": added, "skipped": skipped, "updated": updated, "items": items} + def delete_accounts(self, tokens: list[str]) -> dict: target_set = set(token for token in tokens if token) if not target_set: @@ -236,7 +298,7 @@ def delete_accounts(self, tokens: list[str]) -> dict: else: self._index = 0 self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, f"删除 {removed} 个账号", {"removed": removed}) + log_service.add(LOG_TYPE_ACCOUNT, f"删除 {removed} tài khoản", {"removed": removed}) 
items = [dict(item) for item in self._accounts.values()] return {"removed": removed, "items": items} @@ -250,14 +312,14 @@ def update_account(self, access_token: str, updates: dict) -> dict | None: account = self._normalize_account({**current, **updates, "access_token": access_token}) if account is None: return None - if account.get("status") == "限流" and config.auto_remove_rate_limited_accounts: + if account.get("status") == "limited" and config.auto_remove_rate_limited_accounts: self._accounts.pop(access_token, None) self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, "自动移除限流账号", {"token": anonymize_token(access_token)}) + log_service.add(LOG_TYPE_ACCOUNT, "Tự động xóa tài khoản giới hạn", {"token": anonymize_token(access_token)}) return None self._accounts[access_token] = account self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, "更新账号", + log_service.add(LOG_TYPE_ACCOUNT, "Cập nhật tài khoản", {"token": anonymize_token(access_token), "status": account.get("status")}) return dict(account) return None @@ -278,19 +340,19 @@ def mark_image_result(self, access_token: str, success: bool) -> dict | None: if not image_quota_unknown: next_item["quota"] = max(0, int(next_item.get("quota") or 0) - 1) if not image_quota_unknown and next_item["quota"] == 0: - next_item["status"] = "限流" + next_item["status"] = "limited" next_item["restore_at"] = next_item.get("restore_at") or None - elif next_item.get("status") == "限流": - next_item["status"] = "正常" + elif next_item.get("status") == "limited": + next_item["status"] = "active" else: next_item["fail"] = int(next_item.get("fail") or 0) + 1 account = self._normalize_account(next_item) if account is None: return None - if account.get("status") == "限流" and config.auto_remove_rate_limited_accounts: + if account.get("status") == "limited" and config.auto_remove_rate_limited_accounts: self._accounts.pop(access_token, None) self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, "自动移除限流账号", {"token": 
anonymize_token(access_token)}) + log_service.add(LOG_TYPE_ACCOUNT, "Tự động xóa tài khoản giới hạn", {"token": anonymize_token(access_token)}) return None self._accounts[access_token] = account self._save_accounts() @@ -339,4 +401,111 @@ def refresh_accounts(self, access_tokens: list[str]) -> dict[str, Any]: } + def get_health_score(self, access_token: str) -> float: + """Calculate health score for an account (0.0-1.0). + + Ported from 9router health scoring pattern: + - 0.35: rate-limit status + - 0.20: response latency (placeholder) + - 0.20: concurrency saturation + - 0.15: token last used recency + - 0.10: success/fail ratio + """ + account = self.get_account(access_token) + if not account: + return 0.0 + + score = 0.0 + + # Rate-limit status (0.35) + status = str(account.get("status") or "active") + if status == "active": + score += 0.35 + elif status == "limited": + score += 0.0 + else: + score += 0.1 + + # Concurrency saturation (0.20) + max_conc = max(1, int(config.image_account_concurrency or 1)) + inflight = int(self._image_inflight.get(access_token, 0)) + saturation = inflight / max_conc + score += (1 - saturation) * 0.20 + + # Token recency (0.15) + last_used = account.get("last_used_at") + if last_used: + try: + from datetime import datetime + last_dt = datetime.strptime(str(last_used), "%Y-%m-%d %H:%M:%S") + age_minutes = (datetime.now() - last_dt).total_seconds() / 60 + if age_minutes < 5: + score += 0.15 + elif age_minutes < 30: + score += 0.10 + else: + score += 0.03 + except (ValueError, TypeError): + score += 0.05 + else: + score += 0.05 + + # Success/fail ratio (0.10) + success = int(account.get("success") or 0) + fail = int(account.get("fail") or 0) + total = success + fail + if total > 0: + score += (success / total) * 0.10 + else: + score += 0.05 + + # Latency placeholder (0.20) — default to mid-range + score += 0.10 + + return max(0.0, min(1.0, score)) + + def get_provider_credentials( + self, + provider_id: str, + exclude_connection_ids: 
set[str] | None = None, + model: str = "", + ) -> dict[str, Any] | None: + """Get credentials for a provider, supporting noAuth virtual connections. + + Ported from 9router src/sse/services/auth.js getProviderCredentials(). + Returns None if no credentials available. + + For noAuth providers (opencode): returns a virtual connection with + id="noauth" and accessToken="public". + """ + # Check for noAuth provider first (port from 9router FREE_PROVIDERS check) + if provider_id in NO_AUTH_PROVIDERS: + return { + "id": "noauth", + "connectionName": "Public", + "isActive": True, + "accessToken": "public", + "noAuth": True, + } + + # For chatgpt provider, use existing token pool + if provider_id == "chatgpt": + token = self.get_text_access_token(exclude_connection_ids) + if not token: + return None + return { + "id": anonymize_token(token), + "connectionName": "ChatGPT", + "isActive": True, + "accessToken": token, + "noAuth": False, + } + + return None + + def is_noauth_provider(self, provider_id: str) -> bool: + """Check if a provider uses noAuth virtual connections.""" + return provider_id in NO_AUTH_PROVIDERS + + account_service = AccountService(config.get_storage_backend()) diff --git a/services/auth_service.py b/services/auth_service.py index ef9a29bf..21b6ac1a 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -35,7 +35,7 @@ def _clean(value: object) -> str: @staticmethod def _default_name(role: object) -> str: - return "管理员密钥" if str(role or "").strip().lower() == "admin" else "普通用户" + return "Khóa quản trị" if str(role or "").strip().lower() == "admin" else "Người dùng" def _normalize_item(self, raw: object) -> dict[str, object] | None: if not isinstance(raw, dict): @@ -105,13 +105,13 @@ def _has_key_hash_locked(self, key_hash: str, *, exclude_id: str = "") -> bool: def _build_key_hash_locked(self, raw_key: str, *, exclude_id: str = "") -> str: candidate = self._clean(raw_key) if not candidate: - raise ValueError("请输入新的专用密钥") + raise 
ValueError("Vui lòng nhập khóa mới") admin_key = self._clean(config.auth_key) if admin_key and hmac.compare_digest(candidate, admin_key): - raise ValueError("这个密钥和管理员密钥冲突了,请换一个新的密钥") + raise ValueError("这个密钥和Khóa quản trị冲突了,请换一个新的密钥") key_hash = _hash_key(candidate) if self._has_key_hash_locked(key_hash, exclude_id=exclude_id): - raise ValueError("这个专用密钥已经存在,请换一个新的密钥") + raise ValueError("Khóa này đã tồn tại, vui lòng dùng khóa khác") return key_hash def _has_name_locked(self, name: str, *, role: AuthRole | None = None, exclude_id: str = "") -> bool: @@ -144,7 +144,7 @@ def _build_name_locked(self, name: str, *, role: AuthRole, exclude_id: str = "") if not candidate: return self._build_default_name_locked(role, exclude_id=exclude_id) if self._has_name_locked(candidate, role=role, exclude_id=exclude_id): - raise ValueError("这个名称已经在使用中了,换一个更容易区分的名称吧") + raise ValueError("Tên này đã được dùng, hãy chọn tên khác") return candidate def create_key(self, *, role: AuthRole, name: str = "") -> tuple[dict[str, object], str]: diff --git a/services/backend_router.py b/services/backend_router.py new file mode 100644 index 00000000..15c05894 --- /dev/null +++ b/services/backend_router.py @@ -0,0 +1,282 @@ +""" +BackendRouter — route requests to the appropriate AI backend. + +Port pattern from 9router getProviderCredentials() + combo model routing: +- Model prefix determines provider: oc/ → OpenCode, gw/ → Grok Web, etc. 
+- Payload > 24KB → ưu tiên provider không giới hạn (opencode, gemini, openrouter) +- Payload ≤ 24KB → dùng ChatGPT free +- Image models stay on ChatGPT DALL-E path +- Combo models fallback qua nhiều provider +""" + +from __future__ import annotations + +import json +from typing import Any + +from services.config import config +from utils.helper import IMAGE_MODELS + +# Provider prefixes ported from 9router src/shared/constants/providers.js +PROVIDER_PREFIXES: dict[str, str] = { + "9r/": "ninerouter", + "cx/": "openai_oauth", + "codex/": "openai_oauth", + "oc/": "opencode", + "ocg/": "opencode_go", + "gemini_free/": "gemini_free", + "gemini/": "gemini_free", + "gw/": "grok_web", + "pw/": "perplexity_web", + "gc/": "gemini_cli", + "kr/": "kiro", + "qw/": "qwen", + "if/": "iflow", + "gh/": "github", + "cu/": "cursor", + "cc/": "claude", + "cx/": "codex", + "nv/": "nvidia_nim", +} + +# NoAuth providers — no credentials needed (port from 9router FREE_PROVIDERS) +NO_AUTH_PROVIDERS: set[str] = {"opencode"} + +# Providers that accept API key (not OAuth) +API_KEY_PROVIDERS: set[str] = { + "gemini_free", + "openrouter", + "deepseek", + "groq", + "xai", + "mistral", + "perplexity", + "together", + "nvidia_nim", +} + +# Image providers from 9router image adapter system +IMAGE_PROVIDER_PREFIXES: dict[str, str] = { + "sdwebui/": "sdwebui", + "comfyui/": "comfyui", + "huggingface/": "huggingface", + "fal-ai/": "fal_ai", + "stability/": "stability_ai", + "bfl/": "black_forest_labs", + "cloudflare/": "cloudflare_ai", + "recraft/": "recraft", + "runwayml/": "runwayml", + "nv-image/": "nvidia_nim_image", + "gemini-image/": "gemini", +} + + +class BackendRoute: + """Result of routing decision.""" + def __init__( + self, + provider: str, + model: str, + no_auth: bool = False, + api_key: str = "", + base_url: str = "", + is_image: bool = False, + fallback_providers: list[str] | None = None, + ): + self.provider = provider + self.model = model + self.no_auth = no_auth + self.api_key = 
api_key + self.base_url = base_url + self.is_image = is_image + self.fallback_providers = fallback_providers or [] + + +class BackendRouter: + """ + Route request đến backend phù hợp nhất: + - Payload > 24KB → ưu tiên provider không giới hạn (opencode, gemini, openrouter) + - Payload ≤ 24KB → có thể dùng ChatGPT free + - Model có prefix oc/ → OpenCode, gw/ → Grok Web, v.v. + - Combo model → fallback qua nhiều provider + """ + + # Payload threshold for free ChatGPT accounts (24KB) + FREE_PAYLOAD_LIMIT = 24_000 + + # Default model per provider (for "auto" resolution) + PROVIDER_DEFAULT_MODELS: dict[str, str] = { + "ninerouter": "auto", + "openai_oauth": "gpt-5.3-codex", + "opencode": "nemotron-3-super-free", + "chatgpt": "auto", + "gemini_free": "gemini-3-flash-preview", + "openrouter": "openai/gpt-4o", + "nvidia_nim": "openai/gpt-oss-120b", + } + + @staticmethod + def resolve_model(model_str: str) -> tuple[str, str]: + """Parse model string → (provider, model_name). + + Examples: + "gpt-4" → ("chatgpt", "gpt-4") + "oc/nemotron-free" → ("opencode", "nemotron-free") + "sdwebui/sd-v1.5" → ("sdwebui", "sd-v1.5") + "huggingface/black-forest-labs/FLUX.1-schnell" → ("huggingface", "black-forest-labs/FLUX.1-schnell") + """ + model_str = str(model_str or "").strip() + + # Check image provider prefixes first + for prefix, provider in IMAGE_PROVIDER_PREFIXES.items(): + if model_str.startswith(prefix): + return (provider, model_str[len(prefix):]) + + # Check chat provider prefixes + for prefix, provider in PROVIDER_PREFIXES.items(): + if model_str.startswith(prefix): + return (provider, model_str[len(prefix):]) + + # Check custom providers (dynamic, configured via UI) + from services.providers.custom_openai import resolve_custom_provider + custom_cfg, custom_rest = resolve_custom_provider(model_str) + if custom_cfg is not None: + provider_id = str(custom_cfg.get("prefix") or "") + return (f"custom:{provider_id}", custom_rest) + + # Default: use ChatGPT + return ("chatgpt", 
model_str) + + @staticmethod + def is_image_model(model_str: str) -> bool: + """Check if model is an image generation model.""" + model_str = str(model_str or "").strip() + if model_str in IMAGE_MODELS: + return True + for prefix in IMAGE_PROVIDER_PREFIXES: + if model_str.startswith(prefix): + return True + return False + + @staticmethod + def get_payload_size(messages: list[dict[str, Any]]) -> int: + """Calculate JSON payload size in bytes.""" + try: + payload = json.dumps(messages, ensure_ascii=False, default=str) + return len(payload.encode("utf-8")) + except Exception: + return 0 + + def route( + self, + model: str, + messages: list[dict[str, Any]] | None = None, + payload_size: int | None = None, + ) -> BackendRoute: + """Determine the best backend for a request. + + Args: + model: Model string from request + messages: Normalized messages (for payload size calculation) + payload_size: Pre-calculated payload size in bytes (optional) + + Returns: + BackendRoute with provider, model, auth info + """ + provider, resolved_model = self.resolve_model(model) + is_image = self.is_image_model(model) + + # Resolve "auto" to provider's default model (check user config first) + if resolved_model == "auto" or not resolved_model: + provider_cfg = (config.data.get("providers") or {}).get(provider) or {} + user_model = str(provider_cfg.get("model") or "").strip() + resolved_model = user_model or self.PROVIDER_DEFAULT_MODELS.get(provider, "auto") + + # Calculate payload size if not provided + if payload_size is None and messages: + payload_size = self.get_payload_size(messages) + + # Image models always use their configured provider + if is_image: + if provider == "chatgpt": + # ChatGPT DALL-E — existing path + return BackendRoute( + provider="chatgpt", + model=model, + is_image=True, + ) + else: + # External image provider (sdwebui, huggingface, etc.) 
+ return BackendRoute( + provider=provider, + model=resolved_model, + no_auth=provider in NO_AUTH_PROVIDERS, + is_image=True, + ) + + # Text chat routing + if provider == "chatgpt": + # If payload is large and we have free providers, suggest fallback + if payload_size and payload_size > self.FREE_PAYLOAD_LIMIT: + # Check if OpenCode is enabled in config + opencode_config = (config.data.get("providers") or {}).get("opencode") or {} + if opencode_config.get("enabled", True): + return BackendRoute( + provider="opencode", + model=model if model != "auto" else "auto", + no_auth=True, + fallback_providers=["chatgpt"], + ) + + # Use ChatGPT as normal + return BackendRoute( + provider="chatgpt", + model=model, + ) + + # Non-ChatGPT provider (opencode, gemini_free, etc.) + provider_config = (config.data.get("providers") or {}).get(provider) or {} + return BackendRoute( + provider=provider, + model=resolved_model or model, + no_auth=provider in NO_AUTH_PROVIDERS, + api_key=str(provider_config.get("api_key") or ""), + base_url=str(provider_config.get("base_url") or ""), + fallback_providers=["chatgpt"], + ) + + def route_combo(self, combo_name: str) -> list[BackendRoute]: + """Resolve a combo model into its fallback chain (case-insensitive).""" + models = self._get_combo_models(combo_name) + if not models: + return [] + + routes: list[BackendRoute] = [] + for model_str in models: + route = self.route(str(model_str)) + routes.append(route) + + return routes + + def is_combo(self, model_str: str) -> bool: + """Check if a model string is a combo name (case-insensitive).""" + combos = config.data.get("combo_models") or {} + if not isinstance(combos, dict): + return False + model_lower = model_str.lower().strip() + return any(k.lower().strip() == model_lower for k in combos) + + def _get_combo_models(self, combo_name: str) -> list[str] | None: + """Get combo model list by name (case-insensitive).""" + combos = config.data.get("combo_models") or {} + if not isinstance(combos, 
dict): + return None + name_lower = combo_name.lower().strip() + for k, v in combos.items(): + if k.lower().strip() == name_lower and isinstance(v, list): + return v + return None + + +# Singleton +backend_router = BackendRouter() diff --git a/services/config.py b/services/config.py index 74b46bb7..30082fcd 100644 --- a/services/config.py +++ b/services/config.py @@ -12,6 +12,7 @@ BASE_DIR = Path(__file__).resolve().parents[1] DATA_DIR = BASE_DIR / "data" CONFIG_FILE = BASE_DIR / "config.json" +CONFIG_DATA_FILE = DATA_DIR / "config.json" VERSION_FILE = BASE_DIR / "VERSION" BACKUP_STATE_FILE = DATA_DIR / "backup_state.json" @@ -119,6 +120,12 @@ def _load_settings() -> LoadedSettings: DATA_DIR.mkdir(parents=True, exist_ok=True) raw_config = _read_json_object(CONFIG_FILE, name="config.json") auth_key = _normalize_auth_key(os.getenv("CHATGPT2API_AUTH_KEY") or raw_config.get("auth-key")) + + # HA addon fallback: read from /data/options.json if auth_key still empty + if _is_invalid_auth_key(auth_key): + addon_options = _read_json_object(Path("/data/options.json"), name="HA addon options") + auth_key = _normalize_auth_key(addon_options.get("auth_key") or "") + if _is_invalid_auth_key(auth_key): raise ValueError( "❌ auth-key 未设置!\n" @@ -153,14 +160,28 @@ def __init__(self, path: Path): ) def _load(self) -> dict[str, object]: + # Load from data dir first (persists across restarts), fallback to root + if CONFIG_DATA_FILE.exists(): + return _read_json_object(CONFIG_DATA_FILE, name="data/config.json") return _read_json_object(self.path, name="config.json") def _save(self) -> None: - self.path.write_text(json.dumps(self.data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + DATA_DIR.mkdir(parents=True, exist_ok=True) + CONFIG_DATA_FILE.write_text(json.dumps(self.data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + # Also sync to root if different (backward compat) + if self.path != CONFIG_DATA_FILE: + self.path.write_text(json.dumps(self.data, 
ensure_ascii=False, indent=2) + "\n", encoding="utf-8") @property def auth_key(self) -> str: - return _normalize_auth_key(os.getenv("CHATGPT2API_AUTH_KEY") or self.data.get("auth-key")) + # Priority: 1) ENV var 2) HA addon config 3) config.json + key = _normalize_auth_key(os.getenv("CHATGPT2API_AUTH_KEY")) + if _is_invalid_auth_key(key): + addon_options = _read_json_object(Path("/data/options.json"), name="HA addon options") + key = _normalize_auth_key(addon_options.get("auth_key") or "") + if _is_invalid_auth_key(key): + key = _normalize_auth_key(str(self.data.get("auth-key") or "")) + return key @property def accounts_file(self) -> Path: @@ -230,6 +251,22 @@ def ai_review(self) -> dict[str, object]: def global_system_prompt(self) -> str: return str(self.data.get("global_system_prompt") or "").strip() + @property + def karpathy_mode(self) -> bool: + return _normalize_bool(self.data.get("karpathy_mode"), False) + + @property + def auto_refresh_enabled(self) -> bool: + return _normalize_bool(self.data.get("auto_refresh_enabled"), True) + + @property + def default_image_size(self) -> str: + size = str(self.data.get("default_image_size") or "1792x1024").strip() + # Validate: must be WxH format + if "x" in size: + return size + return "1792x1024" + @property def images_dir(self) -> Path: path = DATA_DIR / "images" @@ -258,11 +295,16 @@ def cleanup_old_images(self) -> int: @property def base_url(self) -> str: - return str( + url = str( os.getenv("CHATGPT2API_BASE_URL") or self.data.get("base_url") or "" ).strip().rstrip("/") + # HA addon fallback + if not url: + addon_options = _read_json_object(Path("/data/options.json"), name="HA addon options") + url = str(addon_options.get("base_url") or "").strip().rstrip("/") + return url @property def app_version(self) -> str: @@ -299,6 +341,12 @@ def update(self, data: dict[str, object]) -> dict[str, object]: next_data.pop("backup_state", None) self.data = next_data self._save() + # Invalidate model cache when settings change 
(combo_models, providers, etc.) + try: + from services.protocol.openai_v1_models import invalidate_models_cache + invalidate_models_cache() + except Exception: + pass return self.get() def get_backup_settings(self) -> dict[str, object]: diff --git a/services/image_providers/__init__.py b/services/image_providers/__init__.py new file mode 100644 index 00000000..95de04a6 --- /dev/null +++ b/services/image_providers/__init__.py @@ -0,0 +1,61 @@ +""" +Image Adapter Registry — port from 9router open-sse/handlers/imageProviders/index.js. + +Maps provider keys to adapter instances. +""" + +from __future__ import annotations + +from typing import Any + +from services.image_providers._base import BaseImageAdapter +from services.image_providers.sdwebui import SDWebUIAdapter +from services.image_providers.huggingface import HuggingFaceAdapter +from services.image_providers.cloudflare import CloudflareAIAdapter +from services.image_providers.fal_ai import FalAIAdapter +from services.image_providers.stability import StabilityAIAdapter +from services.image_providers.bfl import BFLAdapter +from services.image_providers.gemini_image import GeminiImageAdapter +from services.image_providers.nvidia_nim_image import NvidiaNimImageAdapter + + +# Registry — matches 9router ADAPTERS mapping +IMAGE_ADAPTERS: dict[str, BaseImageAdapter] = { + "sdwebui": SDWebUIAdapter(), + "huggingface": HuggingFaceAdapter(), + "cloudflare_ai": CloudflareAIAdapter(), + "fal_ai": FalAIAdapter(), + "stability_ai": StabilityAIAdapter(), + "black_forest_labs": BFLAdapter(), + "gemini": GeminiImageAdapter(), + "nvidia_nim_image": NvidiaNimImageAdapter(), +} + + +def get_image_adapter(provider: str) -> BaseImageAdapter | None: + """Look up an image adapter by provider key. + + For custom providers (custom: prefix), returns a generic adapter + that uses the chat completions endpoint for image generation. 
+ """ + if provider in IMAGE_ADAPTERS: + return IMAGE_ADAPTERS[provider] + if provider.startswith("custom:"): + from services.image_providers.custom_openai_image import CustomOpenAIImageAdapter + cp_id = provider[len("custom:"):] + return CustomOpenAIImageAdapter(cp_id) + return None + + +def is_image_provider(provider: str) -> bool: + """Check if a provider has an image adapter registered.""" + return provider in IMAGE_ADAPTERS + + +# NoAuth image providers (no API key needed) +NO_AUTH_IMAGE_PROVIDERS: set[str] = {"sdwebui"} + + +def is_noauth_image_provider(provider: str) -> bool: + """Check if an image provider requires no authentication.""" + return provider in NO_AUTH_IMAGE_PROVIDERS diff --git a/services/image_providers/_base.py b/services/image_providers/_base.py new file mode 100644 index 00000000..4b3cb64d --- /dev/null +++ b/services/image_providers/_base.py @@ -0,0 +1,122 @@ +""" +Image Provider Adapters — port from 9router open-sse/handlers/imageProviders/. + +Base utilities shared across all image adapters: +- POLL_INTERVAL_MS / POLL_TIMEOUT_MS for async adapters +- size_to_aspect_ratio: convert OpenAI size string to width/height +- url_to_base64: download image URL and convert to base64 +- Size constants and default (16:9) +""" + +from __future__ import annotations + +import base64 +import time +from typing import Any + +from curl_cffi import requests + +# Polling config (port from 9router _base.js) +POLL_INTERVAL_S = 1.5 +POLL_TIMEOUT_S = 120 + +# OpenAI size → width x height (16:9 mặc định) +SIZE_MAP: dict[str, tuple[int, int]] = { + "1024x1024": (1024, 1024), # 1:1 + "1792x1024": (1792, 1024), # 16:9 ← DEFAULT + "1024x1792": (1024, 1792), # 9:16 + "1280x896": (1280, 896), # ~4:3 landscape + "896x1280": (896, 1280), # ~4:3 portrait +} + +DEFAULT_SIZE = "1792x1024" + + +def size_to_width_height(size: str | None) -> tuple[int, int]: + """Convert OpenAI size string → (width, height). 
Default 16:9.""" + if not size: + return SIZE_MAP[DEFAULT_SIZE] + if size in SIZE_MAP: + return SIZE_MAP[size] + # Try parsing "WxH" format + try: + parts = size.split("x") + if len(parts) == 2: + return (int(parts[0]), int(parts[1])) + except (ValueError, TypeError): + pass + return SIZE_MAP[DEFAULT_SIZE] + + +def size_to_aspect_ratio(size: str | None) -> str: + """Convert OpenAI size → aspect ratio string (e.g. '16:9').""" + w, h = size_to_width_height(size) + if w == h: + return "1:1" + if w > h: + if w / h > 1.6: + return "16:9" + return "4:3" + else: + if h / w > 1.6: + return "9:16" + return "3:4" + + +def url_to_base64(url: str, timeout: int = 30) -> str: + """Download image from URL and return as base64 string.""" + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode("ascii") + + +def now_sec() -> int: + """Current time in seconds (Unix timestamp).""" + return int(time.time()) + + +def sleep_s(seconds: float) -> None: + """Sleep for seconds.""" + time.sleep(seconds) + + +class BaseImageAdapter: + """Base class for image generation adapters. + + Ported from 9router imageProviders adapters. 
+ Each adapter implements: + - build_url(model, credentials) -> str + - build_body(model, body) -> dict + - build_headers(credentials, request_body, model, body) -> dict + - normalize(parsed) -> dict (OpenAI-compatible format) + - parse_response(response) -> dict | None (optional, for async/polling) + """ + + no_auth: bool = False + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + raise NotImplementedError + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + raise NotImplementedError + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + raise NotImplementedError + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Optional: custom response parsing (async polling, SSE, etc.).""" + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + """Convert provider response to OpenAI format: {created, data: [{b64_json}]}.""" + raise NotImplementedError + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + """Quick connection test.""" + return True diff --git a/services/image_providers/bfl.py b/services/image_providers/bfl.py new file mode 100644 index 00000000..6545a9b6 --- /dev/null +++ b/services/image_providers/bfl.py @@ -0,0 +1,125 @@ +""" +Black Forest Labs Adapter — port from 9router blackForestLabs.js. + +BFL/FLUX API: https://api.bfl.ai/v1 +Async polling-based adapter. +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import ( + BaseImageAdapter, + POLL_INTERVAL_S, + POLL_TIMEOUT_S, + now_sec, + sleep_s, +) +from utils.log import logger + + +class BFLAdapter(BaseImageAdapter): + """Black Forest Labs (FLUX) adapter — async polling. 
+ + Model format: bfl/flux-pro-1.1, bfl/flux-dev, bfl/flux-schnell + """ + + BASE_URL = "https://api.bfl.ai/v1" + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + # Map model to endpoint + if "pro" in model: + return f"{self.BASE_URL}/flux-pro-1.1" + elif "dev" in model: + return f"{self.BASE_URL}/flux-dev" + else: + return f"{self.BASE_URL}/{model}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + size = str(body.get("size") or "1792x1024") + + # BFL uses width/height + from services.image_providers._base import size_to_width_height + w, h = size_to_width_height(size) + + return { + "prompt": prompt, + "width": w, + "height": h, + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + api_key = "" + if credentials and isinstance(credentials, dict): + api_key = str(credentials.get("apiKey") or credentials.get("accessToken") or "") + return { + "X-Key": api_key, + "Content-Type": "application/json", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Submit and poll for result.""" + if not hasattr(response, "json"): + return None + + data = response.json() + task_id = data.get("id") + if not task_id: + return None + + # Poll for result + elapsed = 0.0 + while elapsed < POLL_TIMEOUT_S: + sleep_s(POLL_INTERVAL_S) + elapsed += POLL_INTERVAL_S + + try: + poll_resp = requests.get( + f"{self.BASE_URL}/get_result", + params={"id": task_id}, + timeout=30, + ) + poll_data = poll_resp.json() + except Exception as exc: + logger.warning({"event": "bfl_poll_error", "error": str(exc)}) + continue + + status = poll_data.get("status", "") + if status == "Ready": + result = poll_data.get("result", {}) + sample_url = result.get("sample") + if sample_url: + from services.image_providers._base import url_to_base64 + return {"data": [{"b64_json": 
url_to_base64(sample_url)}]} + + elif status in ("Error", "Failed"): + logger.error({"event": "bfl_failed", "status": status}) + return None + + logger.error({"event": "bfl_timeout", "elapsed": elapsed}) + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + data = parsed.get("data") or [] + return {"created": now_sec(), "data": data} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get("https://api.bfl.ai/v1", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +bfl_adapter = BFLAdapter() diff --git a/services/image_providers/cloudflare.py b/services/image_providers/cloudflare.py new file mode 100644 index 00000000..94e9e2e5 --- /dev/null +++ b/services/image_providers/cloudflare.py @@ -0,0 +1,80 @@ +""" +Cloudflare Workers AI Adapter — port from 9router cloudflareAi.js. + +Free tier available with Cloudflare account. +Models: @cf/black-forest-labs/flux-1-schnell, @cf/bytedance/stable-diffusion-xl-lightning +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import BaseImageAdapter, now_sec +from utils.log import logger + + +class CloudflareAIAdapter(BaseImageAdapter): + """Cloudflare Workers AI adapter. + + Requires Cloudflare Account ID + API Token (free tier available). 
+ Model format: cloudflare/@cf/black-forest-labs/flux-1-schnell + """ + + BASE_URL = "https://api.cloudflare.com/client/v4/accounts" + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + account_id = "" + if credentials and isinstance(credentials, dict): + account_id = str(credentials.get("accountId") or credentials.get("account_id") or "") + return f"{self.BASE_URL}/{account_id}/ai/run/{model}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + n = max(1, min(4, int(body.get("n") or 1))) + return { + "prompt": prompt, + "num_steps": 4 if "schnell" in model else 8, + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + api_token = "" + if credentials and isinstance(credentials, dict): + api_token = str(credentials.get("apiToken") or credentials.get("api_token") or credentials.get("accessToken") or "") + return { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + # Cloudflare returns {"result": {"image": "base64..."}} + if hasattr(response, "json"): + data = response.json() + result = data.get("result", {}) + if isinstance(result, dict) and result.get("image"): + return {"image_base64": result["image"]} + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + b64 = parsed.get("image_base64") + if b64 and isinstance(b64, str): + return {"created": now_sec(), "data": [{"b64_json": b64}]} + return {"created": now_sec(), "data": []} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get("https://api.cloudflare.com/client/v4/user/tokens/verify", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +cloudflare_adapter = 
CloudflareAIAdapter() diff --git a/services/image_providers/custom_openai_image.py b/services/image_providers/custom_openai_image.py new file mode 100644 index 00000000..7b89574e --- /dev/null +++ b/services/image_providers/custom_openai_image.py @@ -0,0 +1,145 @@ +""" +Custom OpenAI-compatible Image Adapter — uses chat endpoint for image generation. + +For custom providers that support image gen via their chat API (e.g., Gemini API +server with /v1/responses or built-in image generation tools). +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import ( + BaseImageAdapter, + now_sec, + size_to_width_height, +) +from services.config import config +from utils.log import logger + + +class CustomOpenAIImageAdapter(BaseImageAdapter): + """Generic image adapter for custom providers — uses chat endpoint.""" + + def __init__(self, provider_id: str): + self.provider_id = provider_id + + def _get_provider_config(self) -> dict[str, Any] | None: + providers = config.data.get("custom_providers") or {} + return providers.get(self.provider_id) + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + cfg = self._get_provider_config() + base_url = str(cfg.get("base_url") or "").rstrip("/") if cfg else "" + return f"{base_url}/v1/chat/completions" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + size = str(body.get("size") or "1792x1024") + w, h = size_to_width_height(size) + + return { + "model": model, + "messages": [{ + "role": "user", + "content": ( + f"Generate an image based on this description: {prompt}\n" + f"Size: {w}x{h}\n" + f"Return the image as a base64 data URL." 
+ ), + }], + "max_tokens": 4096, + "temperature": 0.9, + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + cfg = self._get_provider_config() + api_key = "" + if cfg: + keys = cfg.get("api_keys") or [] + if not keys: + api_key = str(cfg.get("api_key") or "") + else: + api_key = keys[0] + return { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Parse chat response to extract generated image (base64).""" + if not hasattr(response, "json"): + return None + + try: + data = response.json() + except Exception as exc: + logger.error({"event": "custom_image_parse_error", "error": str(exc)}) + return None + + choices = data.get("choices") or [] + for choice in choices: + content = choice.get("message", {}).get("content") or "" + if not content: + continue + + # Try to extract base64 image from response + import re + # Match data:image/...;base64,... 
+ match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)', content) + if match: + return {"data": [{"b64_json": match.group(1)}]} + + # Maybe the entire response is a base64 image + if content.startswith('/9j/') or content.startswith('iVBOR'): + return {"data": [{"b64_json": content}]} + + # If the response contains a URL to an image + url_match = re.search(r'https?://[^\s"]+\.(?:png|jpg|jpeg|webp)[^\s"]*', content) + if url_match: + from services.image_providers._base import url_to_base64 + try: + b64 = url_to_base64(url_match.group(0)) + return {"data": [{"b64_json": b64}]} + except Exception: + pass + + logger.warning({"event": "custom_image_no_data", "provider": self.provider_id}) + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + data = parsed.get("data") or [] + normalized_data = [] + for item in data: + b64 = item.get("b64_json") or "" + if b64 and not b64.startswith("data:"): + b64 = f"data:image/png;base64,{b64}" + if b64: + normalized_data.append({"b64_json": b64, "revised_prompt": str(body.get("prompt") or "")}) + return {"created": now_sec(), "data": normalized_data} if normalized_data else {"created": now_sec(), "data": []} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + cfg = self._get_provider_config() + if not cfg: + return False + base_url = str(cfg.get("base_url") or "").rstrip("/") + keys = cfg.get("api_keys") or [str(cfg.get("api_key") or "")] + api_key = keys[0] if keys else "" + try: + resp = requests.get( + f"{base_url}/v1/models", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, + ) + return resp.status_code == 200 + except Exception: + return False diff --git a/services/image_providers/fal_ai.py b/services/image_providers/fal_ai.py new file mode 100644 index 00000000..74c5c74f --- /dev/null +++ b/services/image_providers/fal_ai.py @@ -0,0 +1,133 @@ +""" +Fal.ai Adapter — port from 9router falAi.js. 
+ +Async (polling-based) adapter for Fal.ai queue API. +Model format: fal_ai/fal-ai/flux/schnell +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import ( + BaseImageAdapter, + POLL_INTERVAL_S, + POLL_TIMEOUT_S, + now_sec, + sleep_s, +) +from utils.log import logger + + +class FalAIAdapter(BaseImageAdapter): + """Fal.ai async queue adapter. + + Submits to queue API, polls status_url until completion. + """ + + BASE_URL = "https://queue.fal.run" + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + return f"{self.BASE_URL}/{model}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + size = str(body.get("size") or "1792x1024") + + # Map OpenAI size to fal.ai image_size + size_map = { + "1024x1024": "square_hd", + "1792x1024": "landscape_16_9", + "1024x1792": "portrait_16_9", + "1280x896": "landscape_4_3", + "896x1280": "portrait_4_3", + } + + return { + "prompt": prompt, + "image_size": size_map.get(size, "landscape_16_9"), + "num_images": max(1, min(4, int(body.get("n") or 1))), + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + api_key = "" + if credentials and isinstance(credentials, dict): + api_key = str(credentials.get("apiKey") or credentials.get("accessToken") or "") + return { + "Authorization": f"Key {api_key}", + "Content-Type": "application/json", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Submit to queue, poll for result.""" + if not hasattr(response, "json"): + return None + + data = response.json() + status_url = data.get("status_url") + if not status_url: + logger.error({"event": "fal_ai_no_status_url", "response": str(data)[:200]}) + return None + + # Poll for completion + elapsed = 0.0 + while 
elapsed < POLL_TIMEOUT_S: + sleep_s(POLL_INTERVAL_S) + elapsed += POLL_INTERVAL_S + + try: + poll_resp = requests.get(status_url, timeout=30) + poll_data = poll_resp.json() + except Exception as exc: + logger.warning({"event": "fal_ai_poll_error", "error": str(exc)}) + continue + + status = poll_data.get("status", "") + if status == "COMPLETED": + result = poll_data.get("result") or poll_data + images = ( + result.get("images") or + [result.get("image")] if result.get("image") else + [] + ) + if images: + b64_list = [] + for img in images: + if isinstance(img, dict) and img.get("url"): + from services.image_providers._base import url_to_base64 + b64_list.append({"b64_json": url_to_base64(img["url"])}) + return {"data": b64_list} + + logger.error({"event": "fal_ai_no_images", "result": str(result)[:200]}) + return None + + elif status in ("FAILED", "CANCELLED"): + error_msg = str(poll_data.get("error") or "unknown") + logger.error({"event": "fal_ai_failed", "status": status, "error": error_msg}) + return None + + logger.error({"event": "fal_ai_timeout", "elapsed": elapsed}) + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + data = parsed.get("data") or [] + return {"created": now_sec(), "data": data} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get("https://queue.fal.run", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +fal_ai_adapter = FalAIAdapter() diff --git a/services/image_providers/gemini_image.py b/services/image_providers/gemini_image.py new file mode 100644 index 00000000..19bc0728 --- /dev/null +++ b/services/image_providers/gemini_image.py @@ -0,0 +1,144 @@ +""" +Gemini Image Adapter — port from 9router gemini.js. + +Google Gemini image generation via Imagen. 
+Uses Gemini API: https://generativelanguage.googleapis.com/v1beta/models/ +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import BaseImageAdapter, now_sec +from utils.log import logger + + +class GeminiImageAdapter(BaseImageAdapter): + """Gemini Imagen image generation adapter. + + Model format: gemini-image/imagen-3.0-generate-001 + Uses Gemini generateContent API with image generation config. + Supports API key rotation from api_keys array. + """ + + BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models" + _key_index: int = 0 + + def _get_api_keys(self, credentials: dict[str, Any] | None) -> list[str]: + """Get all available API keys from credentials.""" + if not credentials or not isinstance(credentials, dict): + return [] + keys = credentials.get("apiKeys") or credentials.get("api_keys") or [] + if isinstance(keys, list) and keys: + return [str(k) for k in keys if k] + single = str(credentials.get("apiKey") or credentials.get("api_key") or "") + return [single] if single else [] + + def build_url(self, model: str, credentials: dict[str, Any] | None, key_index: int = 0) -> str: + api_key = "" + if credentials and isinstance(credentials, dict): + keys = self._get_api_keys(credentials) + if keys: + api_key = keys[key_index % len(keys)] + return f"{self.BASE_URL}/{model}:generateContent?key={api_key}" + + def get_key_count(self, credentials: dict[str, Any] | None) -> int: + return len(self._get_api_keys(credentials)) + + # Size → aspect ratio mapping (OpenAI format → Gemini format) + _SIZE_TO_RATIO = { + # 16:9 + "1792x1024": "16:9", "1344x768": "16:9", + # 9:16 + "1024x1792": "9:16", "768x1344": "9:16", + # 1:1 + "1024x1024": "1:1", "768x768": "1:1", "512x512": "1:1", "256x256": "1:1", + # 4:3 + "1792x1344": "4:3", "1200x896": "4:3", "1024x768": "4:3", + # 3:2 + "1536x1024": "3:2", "1264x848": "3:2", + # 3:4 + "768x1024": "3:4", "896x1200": 
"3:4", + } + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + images = body.get("images") or [] + n = max(1, min(4, int(body.get("n") or 1))) + size = str(body.get("size") or "") + + parts = [{"text": prompt}] + for img in images: + if isinstance(img, bytes): + import base64 as b64 + parts.append({"inlineData": {"mimeType": "image/png", "data": b64.b64encode(img).decode()}}) + elif isinstance(img, str) and img.startswith("data:"): + header, data = img.split(",", 1) + mime = header.split(";")[0].replace("data:", "") + parts.append({"inlineData": {"mimeType": mime, "data": data}}) + + gen_config: dict[str, Any] = { + "responseModalities": ["TEXT", "IMAGE"], + } + + # Note: generateContent does NOT support responseFormat. + # Aspect ratio and image size are controlled via model-specific parameters + # that vary by model version. Default is model-dependent. + + return { + "contents": [{"parts": parts}], + "generationConfig": gen_config, + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + return {"Content-Type": "application/json"} + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Extract inline image data from Gemini response.""" + if not hasattr(response, "json"): + return None + + data = response.json() + + # Check for error + if "error" in data: + err = data["error"] + raise RuntimeError(f"Gemini API error {err.get('status','')}: {err.get('message','')[:200]}") + + images = [] + + candidates = data.get("candidates") or [] + for candidate in candidates: + content = candidate.get("content") or {} + parts = content.get("parts") or [] + for part in parts: + if "inlineData" in part: + inline = part["inlineData"] + b64 = inline.get("data") or "" + if b64: + images.append({"b64_json": b64}) + + return {"data": images} if images else None + + def normalize(self, 
parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + data = parsed.get("data") or [] + return {"created": now_sec(), "data": data} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get("https://generativelanguage.googleapis.com", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +gemini_image_adapter = GeminiImageAdapter() diff --git a/services/image_providers/huggingface.py b/services/image_providers/huggingface.py new file mode 100644 index 00000000..41e93be6 --- /dev/null +++ b/services/image_providers/huggingface.py @@ -0,0 +1,70 @@ +""" +HuggingFace Inference API Adapter — port from 9router huggingface.js. + +Free tier available for many models (e.g., black-forest-labs/FLUX.1-schnell). +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import BaseImageAdapter, now_sec +from utils.log import logger + + +class HuggingFaceAdapter(BaseImageAdapter): + """HuggingFace Inference API adapter. + + Supports free-tier models with optional API token. 
+ Model format: huggingface/owner/model-name + """ + + BASE_URL = "https://api-inference.huggingface.co/models" + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + return f"{self.BASE_URL}/{model}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + return {"inputs": prompt} + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + api_key = "" + if credentials and isinstance(credentials, dict): + api_key = str(credentials.get("apiKey") or credentials.get("accessToken") or "") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def parse_response(self, response: Any) -> dict[str, Any] | None: + # HuggingFace returns raw image bytes + if hasattr(response, "content") and response.headers.get("content-type", "").startswith("image/"): + return {"image_bytes": response.content} + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + image_bytes = parsed.get("image_bytes") + if image_bytes and isinstance(image_bytes, bytes): + b64 = base64.b64encode(image_bytes).decode("ascii") + return {"created": now_sec(), "data": [{"b64_json": b64}]} + return {"created": now_sec(), "data": []} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get(f"{self.BASE_URL}/black-forest-labs/FLUX.1-schnell", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +huggingface_adapter = HuggingFaceAdapter() diff --git a/services/image_providers/nvidia_nim_image.py b/services/image_providers/nvidia_nim_image.py new file mode 100644 index 00000000..18e7621e --- /dev/null +++ b/services/image_providers/nvidia_nim_image.py @@ -0,0 +1,159 @@ +""" +NVIDIA NIM Image Generation Adapter. 
+ +Endpoint: https://ai.api.nvidia.com/v1/genai/{model} +Auth: Bearer token from build.nvidia.com +Format: Custom request/response — needs conversion to/from OpenAI format. + +Example model: black-forest-labs/flux.2-klein-4b +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import ( + BaseImageAdapter, + now_sec, + size_to_width_height, +) +from services.config import config +from utils.log import logger + + +class NvidiaNimImageAdapter(BaseImageAdapter): + """NVIDIA NIM Image Generation adapter. + + Uses NVIDIA's image generation endpoint (different from chat endpoint). + """ + + BASE_URL = "https://ai.api.nvidia.com/v1/genai" + + def _get_keys(self) -> list[str]: + cfg = config.data.get("providers") or {} + nv_cfg = cfg.get("nvidia_nim") or {} + single = str(nv_cfg.get("api_key") or "").strip() + multi = nv_cfg.get("api_keys") or [] + if not isinstance(multi, list): + multi = [] + keys = [k.strip() for k in multi if k.strip()] + if single and single not in keys: + keys.insert(0, single) + return keys + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + return f"{self.BASE_URL}/{model}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + size = str(body.get("size") or "1792x1024") + w, h = size_to_width_height(size) + + return { + "prompt": prompt, + "width": w, + "height": h, + "seed": body.get("seed", 0), + "steps": body.get("steps", 4), + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + # Use API key from provider config + keys = self._get_keys() + api_key = keys[0] if keys else "" + if credentials and isinstance(credentials, dict): + api_key = str(credentials.get("apiKey") or credentials.get("accessToken") or api_key) + return { + 
"Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + """Parse NVIDIA image gen response → OpenAI image format.""" + if not hasattr(response, "json"): + return None + + try: + data = response.json() + except Exception as exc: + logger.error({"event": "nvidia_image_parse_error", "error": str(exc)}) + return None + + # Response format: {"artifacts":[{"base64":"..."}]} or {"image":"..."} + image_b64 = "" + + # NVIDIA returns artifacts array with base64 + artifacts = data.get("artifacts") or [] + if isinstance(artifacts, list) and artifacts: + first = artifacts[0] + if isinstance(first, dict): + image_b64 = first.get("base64") or first.get("image") or "" + elif isinstance(first, str): + image_b64 = first + + if not image_b64: + # Try direct image field + image_b64 = data.get("image") or "" + + if not image_b64: + images = data.get("images") or data.get("data") or [] + if isinstance(images, list) and images: + first = images[0] + if isinstance(first, dict): + image_b64 = first.get("image") or first.get("b64_json") or first.get("url") or first.get("base64") or "" + elif isinstance(first, str): + image_b64 = first + + if not image_b64: + logger.error({"event": "nvidia_image_no_data", "keys": list(data.keys())[:5]}) + return None + + # If it's a URL, convert to base64 + if image_b64.startswith("http"): + from services.image_providers._base import url_to_base64 + image_b64 = url_to_base64(image_b64) + + return {"data": [{"b64_json": image_b64}]} + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + """Convert parsed result → OpenAI image response format.""" + data = parsed.get("data") or [] + normalized_data = [] + for item in data: + b64 = item.get("b64_json") or "" + if b64: + if not b64.startswith("data:"): + b64 = f"data:image/png;base64,{b64}" + normalized_data.append({"b64_json": b64, "revised_prompt": 
str(body.get("prompt") or "")}) + + if not normalized_data: + return {"created": now_sec(), "data": []} + + return {"created": now_sec(), "data": normalized_data} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + """Test connectivity to NVIDIA image gen endpoint.""" + try: + keys = self._get_keys() + api_key = keys[0] if keys else "" + resp = requests.get( + "https://integrate.api.nvidia.com/v1/models", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, + ) + return resp.status_code == 200 + except Exception: + return False + + +nvidia_nim_image_adapter = NvidiaNimImageAdapter() diff --git a/services/image_providers/sdwebui.py b/services/image_providers/sdwebui.py new file mode 100644 index 00000000..f39b3a57 --- /dev/null +++ b/services/image_providers/sdwebui.py @@ -0,0 +1,78 @@ +""" +SD WebUI Adapter — port from 9router sdwebui.js. + +AUTOMATIC1111 Stable Diffusion Web UI at localhost:7860. +NoAuth — completely free, runs on local GPU. +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import BaseImageAdapter, now_sec, size_to_width_height +from utils.log import logger + + +class SDWebUIAdapter(BaseImageAdapter): + """Stable Diffusion Web UI (AUTOMATIC1111) adapter. + + NoAuth — runs locally at http://localhost:7860. 
+ """ + + no_auth = True + + def __init__(self, base_url: str = "http://localhost:7860"): + self.base_url = base_url.rstrip("/") + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + return f"{self.base_url}/sdapi/v1/txt2img" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + n = max(1, min(4, int(body.get("n") or 1))) + size = str(body.get("size") or "1792x1024") + w, h = size_to_width_height(size) + + return { + "prompt": prompt, + "negative_prompt": "", + "width": w, + "height": h, + "steps": 20, + "batch_size": n, + "cfg_scale": 7, + "sampler_name": "Euler a", + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + return {"Content-Type": "application/json"} + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + # SD WebUI returns {"images": ["base64...", ...]} + images = parsed.get("images") or [] + data = [ + {"b64_json": img} + for img in images + if isinstance(img, str) + ] + return {"created": now_sec(), "data": data} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get(f"{self.base_url}/sdapi/v1/sd-models", timeout=5) + return resp.status_code == 200 + except Exception: + return False + + +sdwebui_adapter = SDWebUIAdapter() diff --git a/services/image_providers/stability.py b/services/image_providers/stability.py new file mode 100644 index 00000000..7dbc6524 --- /dev/null +++ b/services/image_providers/stability.py @@ -0,0 +1,91 @@ +""" +Stability AI Adapter — port from 9router stabilityAi.js. 
+ +Stability AI v2 API: https://api.stability.ai/v2beta/stable-image/generate/ +Models: ultra, sd3, core +""" + +from __future__ import annotations + +import base64 +from typing import Any + +from curl_cffi import requests + +from services.image_providers._base import BaseImageAdapter, now_sec +from utils.log import logger + + +class StabilityAIAdapter(BaseImageAdapter): + """Stability AI v2 adapter. + + Model format: stability/sd3, stability/ultra, stability/core + """ + + BASE_URL = "https://api.stability.ai/v2beta/stable-image/generate" + + # Model name → API endpoint + ENDPOINT_MAP = { + "ultra": "ultra", + "sd3": "sd3", + "core": "core", + } + + def build_url(self, model: str, credentials: dict[str, Any] | None) -> str: + endpoint = self.ENDPOINT_MAP.get(model, "sd3") + return f"{self.BASE_URL}/{endpoint}" + + def build_body(self, model: str, body: dict[str, Any]) -> dict[str, Any]: + prompt = str(body.get("prompt") or "") + size = str(body.get("size") or "1792x1024") + + aspect_ratio_map = { + "1024x1024": "1:1", + "1792x1024": "16:9", + "1024x1792": "9:16", + "1280x896": "4:3", + "896x1280": "3:4", + } + + return { + "prompt": prompt, + "aspect_ratio": aspect_ratio_map.get(size, "16:9"), + "output_format": "png", + } + + def build_headers( + self, + credentials: dict[str, Any] | None, + request_body: dict[str, Any], + model: str, + body: dict[str, Any], + ) -> dict[str, str]: + api_key = "" + if credentials and isinstance(credentials, dict): + api_key = str(credentials.get("apiKey") or credentials.get("accessToken") or "") + return { + "Authorization": f"Bearer {api_key}", + "Accept": "image/*", + } + + def parse_response(self, response: Any) -> dict[str, Any] | None: + if hasattr(response, "content") and response.headers.get("content-type", "").startswith("image/"): + return {"image_bytes": response.content} + return None + + def normalize(self, parsed: dict[str, Any], body: dict[str, Any]) -> dict[str, Any]: + image_bytes = parsed.get("image_bytes") + if 
image_bytes and isinstance(image_bytes, bytes): + b64 = base64.b64encode(image_bytes).decode("ascii") + return {"created": now_sec(), "data": [{"b64_json": b64}]} + return {"created": now_sec(), "data": []} + + def test_connection(self, credentials: dict[str, Any] | None = None) -> bool: + try: + resp = requests.get("https://api.stability.ai", timeout=10) + return resp.status_code < 500 + except Exception: + return False + + +stability_adapter = StabilityAIAdapter() diff --git a/services/image_providers/veo_video.py b/services/image_providers/veo_video.py new file mode 100644 index 00000000..88da1a50 --- /dev/null +++ b/services/image_providers/veo_video.py @@ -0,0 +1,194 @@ +""" +Veo Video Adapter — Google Veo 3.1 video generation. + +Endpoint: :predictLongRunning (async operation with polling) +Model: veo-3.1-generate-preview +Supports: text→video, image→video, video extension, reference images +""" + +from __future__ import annotations + +import base64 +import time +from typing import Any + +from curl_cffi import requests + +from utils.log import logger + +VEO_BASE = "https://generativelanguage.googleapis.com/v1beta/models" +VEO_MODEL = "veo-3.1-generate-preview" +VEO_POLL_INTERVAL = 10 # seconds +VEO_MAX_WAIT = 600 # 10 minutes max + + +class VeoVideoAdapter: + """Google Veo 3.1 video generation adapter. + + Model format: veo/veo-3.1-generate-preview + Uses Veo predictLongRunning API with operation polling. 
+ """ + + def __init__(self): + self._key_index = 0 + + def _get_api_keys(self, credentials: dict[str, Any] | None) -> list[str]: + if not credentials or not isinstance(credentials, dict): + return [] + keys = credentials.get("apiKeys") or credentials.get("api_keys") or [] + if isinstance(keys, list) and keys: + return [str(k) for k in keys if k] + single = str(credentials.get("apiKey") or credentials.get("api_key") or "") + return [single] if single else [] + + def get_key_count(self, credentials: dict[str, Any] | None) -> int: + return len(self._get_api_keys(credentials)) + + def _build_url(self, credentials: dict[str, Any] | None, key_index: int = 0) -> str: + keys = self._get_api_keys(credentials) + api_key = keys[key_index % len(keys)] if keys else "" + return f"{VEO_BASE}/{VEO_MODEL}:predictLongRunning?key={api_key}" + + def _build_body(self, body: dict[str, Any]) -> dict[str, Any]: + """Build Veo API request body.""" + prompt = str(body.get("prompt") or "") + instance: dict[str, Any] = {"prompt": prompt} + + # Optional: input image for image→video + image_b64 = body.get("image") + if image_b64: + instance["image"] = { + "inlineData": {"mimeType": "image/png", "data": image_b64} + } + + # Optional: last frame for interpolation + last_frame = body.get("last_frame") + if last_frame: + instance["lastFrame"] = { + "inlineData": {"mimeType": "image/png", "data": last_frame} + } + + request: dict[str, Any] = {"instances": [instance]} + + # Parameters + params: dict[str, Any] = {} + aspect_ratio = body.get("aspect_ratio") or "16:9" + if aspect_ratio: + params["aspectRatio"] = aspect_ratio + + duration = body.get("duration") + if duration: + params["durationSeconds"] = duration + + resolution = body.get("resolution") + if resolution: + params["resolution"] = resolution + + if params: + request["parameters"] = params + + return request + + def generate( + self, + body: dict[str, Any], + credentials: dict[str, Any] | None, + ) -> dict[str, Any]: + """Generate video — 
submit, poll, download. + + Returns: + {"data": [{"b64_json": "