diff --git a/api/ai.py b/api/ai.py index 34d36e08..d95137d4 100644 --- a/api/ai.py +++ b/api/ai.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field from api.support import require_identity, resolve_image_base_url +from utils.helper import normalize_json_edit_images, parse_image_count from services.content_filter import check_request, request_text from services.log_service import LoggedCall from services.protocol import ( @@ -27,6 +28,36 @@ class ImageGenerationRequest(BaseModel): stream: bool | None = None +class ImageEditJsonRequest(BaseModel): + model_config = ConfigDict(extra="allow") + prompt: str = Field(..., min_length=1) + model: str | None = None + n: object = 1 + size: str | None = None + response_format: str | None = None + stream: bool | None = None + image: object | None = None + images: object | None = None + + +def _is_json_request(request: Request) -> bool: + return "application/json" in request.headers.get("content-type", "").lower() + + +async def _collect_json_edit_payload( + request: Request, +) -> tuple[str, str, object, str | None, str, bool | None, list[tuple[bytes, str, str]]]: + try: + body = ImageEditJsonRequest.model_validate(await request.json()) + except Exception as exc: + raise HTTPException(status_code=400, detail={"error": "invalid image edit JSON request"}) from exc + prompt = body.prompt + model = body.model or "gpt-image-2" + response_format = body.response_format or "b64_json" + images = normalize_json_edit_images(image=body.image, images=body.images) + return prompt, model, body.n, body.size, response_format, body.stream, images + + class ChatCompletionRequest(BaseModel): model_config = ConfigDict(extra="allow") model: str | None = None @@ -92,7 +123,7 @@ async def edit_images( authorization: str | None = Header(default=None), image: list[UploadFile] | None = File(default=None), image_list: list[UploadFile] | None = File(default=None, alias="image[]"), - prompt: str = Form(...), + prompt: str | None = 
Form(default=None), model: str = Form(default="gpt-image-2"), n: int = Form(default=1), size: str | None = Form(default=None), @@ -100,19 +131,24 @@ async def edit_images( stream: bool | None = Form(default=None), ): identity = require_identity(authorization) + if _is_json_request(request): + prompt, model, n, size, response_format, stream, images = await _collect_json_edit_payload(request) + else: + if not prompt: + raise HTTPException(status_code=422, detail={"error": "prompt is required"}) + uploads = [*(image or []), *(image_list or [])] + if not uploads: + raise HTTPException(status_code=400, detail={"error": "image file is required"}) + images: list[tuple[bytes, str, str]] = [] + for upload in uploads: + image_data = await upload.read() + if not image_data: + raise HTTPException(status_code=400, detail={"error": "image file is empty"}) + images.append((image_data, upload.filename or "image.png", upload.content_type or "image/png")) + + n = parse_image_count(n) call = LoggedCall(identity, "/v1/images/edits", model, "图生图", request_text=prompt) - if n < 1 or n > 4: - raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"}) await filter_or_log(call, prompt) - uploads = [*(image or []), *(image_list or [])] - if not uploads: - raise HTTPException(status_code=400, detail={"error": "image file is required"}) - images: list[tuple[bytes, str, str]] = [] - for upload in uploads: - image_data = await upload.read() - if not image_data: - raise HTTPException(status_code=400, detail={"error": "image file is empty"}) - images.append((image_data, upload.filename or "image.png", upload.content_type or "image/png")) payload = { "prompt": prompt, "images": images, diff --git a/test/test_v1_images_edits_json.py b/test/test_v1_images_edits_json.py new file mode 100644 index 00000000..c3d83502 --- /dev/null +++ b/test/test_v1_images_edits_json.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import base64 +import os +import unittest +from 
unittest import mock + +os.environ.setdefault("CHATGPT2API_AUTH_KEY", "chatgpt2api") + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import api.ai as ai_module + +AUTH_HEADERS = {"Authorization": "Bearer chatgpt2api"} +PNG_DATA_URL = "data:image/png;base64," + base64.b64encode(b"fake-png").decode("ascii") +JPEG_DATA_URL = "data:image/jpeg;base64," + base64.b64encode(b"fake-jpeg").decode("ascii") + + +class ImageEditsJsonApiTests(unittest.TestCase): + def setUp(self): + self.calls = [] + + def fake_handle(payload): + self.calls.append(payload) + return {"created": 1, "data": [{"b64_json": "ZmFrZQ=="}]} + + self.handle_patcher = mock.patch.object(ai_module.openai_v1_image_edit, "handle", fake_handle) + self.filter_patcher = mock.patch.object(ai_module, "filter_or_log", mock.AsyncMock()) + self.handle_patcher.start() + self.filter_patcher.start() + self.addCleanup(self.handle_patcher.stop) + self.addCleanup(self.filter_patcher.stop) + + app = FastAPI() + app.include_router(ai_module.create_router()) + self.client = TestClient(app) + + def test_json_model_omitted_uses_existing_default_logic(self): + response = self.client.post("/v1/images/edits", headers=AUTH_HEADERS, json={"prompt": "未传 model", "image": PNG_DATA_URL}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(self.calls[0]["model"], "gpt-image-2") + + def test_json_model_is_not_overwritten_when_provided(self): + response = self.client.post( + "/v1/images/edits", + headers=AUTH_HEADERS, + json={"model": "codex-gpt-image-2", "prompt": "保留 model", "image": PNG_DATA_URL}, + ) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(self.calls[0]["model"], "codex-gpt-image-2") + + def test_image_edit_accepts_json_image_url(self): + response = self.client.post( + "/v1/images/edits", + headers=AUTH_HEADERS, + json={ + "model": "gpt-image-2", + "prompt": "把图片改成夜景风格", + "n": 1, + "size": "1024x1536", + "response_format": "b64_json", +
"images": [{"image_url": PNG_DATA_URL}], + }, + ) + self.assertEqual(response.status_code, 200, response.text) + payload = self.calls[0] + self.assertEqual(payload["images"], [(b"fake-png", "image_1.png", "image/png")]) + self.assertEqual(payload["size"], "1024x1536") + + def test_image_edit_accepts_json_multiple_images_and_b64_json(self): + response = self.client.post( + "/v1/images/edits", + headers=AUTH_HEADERS, + json={ + "prompt": "把两张图合成海报", + "images": [ + PNG_DATA_URL, + {"b64_json": base64.b64encode(b"raw-jpeg").decode("ascii"), "mime_type": "image/jpeg", "filename": "two.jpg"}, + {"image_url": {"url": JPEG_DATA_URL}}, + ], + }, + ) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(self.calls[0]["images"], [ + (b"fake-png", "image_1.png", "image/png"), + (b"raw-jpeg", "two.jpg", "image/jpeg"), + (b"fake-jpeg", "image_3.jpg", "image/jpeg"), + ]) + + def test_image_edit_keeps_original_multipart_multiple_image_logic(self): + response = self.client.post( + "/v1/images/edits", + headers=AUTH_HEADERS, + data={"prompt": "multipart 多图仍然可用", "model": "gpt-image-2", "n": "1"}, + files=[ + ("image", ("one.png", b"one", "image/png")), + ("image", ("two.jpg", b"two", "image/jpeg")), + ("image[]", ("three.webp", b"three", "image/webp")), + ], + ) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(self.calls[0]["images"], [ + (b"one", "one.png", "image/png"), + (b"two", "two.jpg", "image/jpeg"), + (b"three", "three.webp", "image/webp"), + ]) + + def test_image_edit_rejects_json_without_image(self): + response = self.client.post("/v1/images/edits", headers=AUTH_HEADERS, json={"prompt": "缺少图片"}) + self.assertEqual(response.status_code, 400, response.text) + self.assertIn("image file is required", response.text) + + def test_image_edit_rejects_remote_json_url(self): + response = self.client.post( + "/v1/images/edits", + headers=AUTH_HEADERS, + json={"prompt": "不允许远程拉图", "images": [{"image_url":
"https://example.com/a.png"}]}, + ) + self.assertEqual(response.status_code, 400, response.text) + self.assertIn("remote image URLs are not supported", response.text) + + def test_image_edit_rejects_json_n_out_of_range(self): + response = self.client.post("/v1/images/edits", headers=AUTH_HEADERS, json={"prompt": "n 越界", "n": 5, "image": PNG_DATA_URL}) + self.assertEqual(response.status_code, 400, response.text) + self.assertFalse(self.calls) + + +if __name__ == "__main__": + unittest.main() diff --git a/utils/helper.py b/utils/helper.py index f5546568..ddc5930e 100644 --- a/utils/helper.py +++ b/utils/helper.py @@ -14,6 +14,81 @@ IMAGE_MODELS = {"gpt-image-2", "codex-gpt-image-2"} OUTPUT_DIR = Path(__file__).resolve().parent / "output" +SUPPORTED_JSON_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/jpg", "image/webp", "image/gif"} +MAX_JSON_IMAGE_BYTES = 10 * 1024 * 1024 +MAX_JSON_EDIT_IMAGES = 10 +DATA_URL_IMAGE_RE = re.compile(r"^data:(?P<mime>[-+./\w]+);base64,(?P<data>.*)$", re.DOTALL) + + +def _image_extension(mime_type: str) -> str: + image_type = mime_type.split("/", 1)[1].split(";", 1)[0].lower() if "/" in mime_type else "png" + return "jpg" if image_type == "jpeg" else image_type or "png" + + +def _decode_json_image_string(value: str, index: int, filename: str | None = None, mime_type: str | None = None) -> tuple[bytes, str, str]: + text = value.strip() + if not text: + raise HTTPException(status_code=400, detail={"error": "image file is empty"}) + match = DATA_URL_IMAGE_RE.match(text) + if match: + resolved_mime = (match.group("mime") or "image/png").lower() + encoded = match.group("data") + else: + if text.startswith(("http://", "https://")): + raise HTTPException(status_code=400, detail={"error": "remote image URLs are not supported"}) + resolved_mime = (mime_type or "image/png").lower() + encoded = text + if resolved_mime == "image/jpg": + resolved_mime = "image/jpeg" + if resolved_mime not in SUPPORTED_JSON_IMAGE_MIME_TYPES: + raise
HTTPException(status_code=400, detail={"error": "unsupported image mime type"}) + try: + image_data = base64.b64decode(encoded, validate=True) + except Exception as exc: + raise HTTPException(status_code=400, detail={"error": "invalid base64 image data"}) from exc + if not image_data: + raise HTTPException(status_code=400, detail={"error": "image file is empty"}) + if len(image_data) > MAX_JSON_IMAGE_BYTES: + raise HTTPException(status_code=400, detail={"error": "image file is too large"}) + return image_data, filename or f"image_{index}.{_image_extension(resolved_mime)}", resolved_mime + + +def _extract_json_image_value(item: object) -> tuple[str, str | None, str | None]: + if isinstance(item, str): + return item, None, None + if not isinstance(item, dict): + raise HTTPException(status_code=400, detail={"error": "image entry must be a base64 string or object"}) + filename = str(item.get("filename") or item.get("file_name") or "").strip() or None + mime_type = str(item.get("mime_type") or item.get("mimeType") or "").strip() or None + value = item.get("b64_json") or item.get("base64") + if not value: + image_url = item.get("image_url") or item.get("url") + if isinstance(image_url, dict): + filename = filename or str(image_url.get("filename") or image_url.get("file_name") or "").strip() or None + mime_type = mime_type or str(image_url.get("mime_type") or image_url.get("mimeType") or "").strip() or None + value = image_url.get("url") or image_url.get("image_url") + else: + value = image_url + if not isinstance(value, str) or not value.strip(): + raise HTTPException(status_code=400, detail={"error": "image entry must include image data"}) + return value, filename, mime_type + + +def normalize_json_edit_images(image: object = None, images: object = None) -> list[tuple[bytes, str, str]]: + raw_images = images if images is not None else image + if raw_images is None: + raise HTTPException(status_code=400, detail={"error": "image file is required"}) + entries = raw_images 
if isinstance(raw_images, list) else [raw_images] + if not entries: + raise HTTPException(status_code=400, detail={"error": "image file is required"}) + if len(entries) > MAX_JSON_EDIT_IMAGES: + raise HTTPException(status_code=400, detail={"error": f"images supports up to {MAX_JSON_EDIT_IMAGES} items"}) + normalized = [] + for index, item in enumerate(entries, start=1): + value, filename, mime_type = _extract_json_image_value(item) + normalized.append(_decode_json_image_string(value, index, filename, mime_type)) + return normalized + def new_uuid() -> str: return str(uuid.uuid4())