Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions api/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field

from api.support import require_identity, resolve_image_base_url
from utils.helper import normalize_json_edit_images, parse_image_count
from services.content_filter import check_request, request_text
from services.log_service import LoggedCall
from services.protocol import (
Expand All @@ -27,6 +28,36 @@ class ImageGenerationRequest(BaseModel):
stream: bool | None = None


class ImageEditJsonRequest(BaseModel):
    """JSON request body for an image edit.

    Keys that are not declared below are accepted rather than rejected
    (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    # Required edit instruction; an empty string fails validation.
    prompt: str = Field(..., min_length=1)
    model: str | None = None
    # Deliberately untyped here: the raw value is handed back to the caller
    # as-is, so any range/format checking happens outside this model.
    n: object = 1
    size: str | None = None
    response_format: str | None = None
    stream: bool | None = None
    # Inline image input(s); the accepted shapes are interpreted by
    # normalize_json_edit_images, not by this model.
    image: object | None = None
    images: object | None = None


def _is_json_request(request: Request) -> bool:
return "application/json" in request.headers.get("content-type", "").lower()


async def _collect_json_edit_payload(
    request: Request,
) -> tuple[str, str, object, str | None, str, bool | None, list[tuple[bytes, str, str]]]:
    """Parse a JSON image-edit request into its individual edit parameters.

    Returns:
        ``(prompt, model, n, size, response_format, stream, images)``.
        ``model`` falls back to ``"gpt-image-2"`` and ``response_format`` to
        ``"b64_json"`` when the body omits them or sends an empty value;
        ``n`` is passed through exactly as received.

    Raises:
        HTTPException: 400 when the body cannot be read as JSON or does not
            validate against ``ImageEditJsonRequest``. Errors raised while
            normalizing the images propagate unchanged.
    """
    try:
        raw_body = await request.json()
        parsed = ImageEditJsonRequest.model_validate(raw_body)
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid image edit JSON request"}) from exc

    decoded_images = normalize_json_edit_images(image=parsed.image, images=parsed.images)
    return (
        parsed.prompt,
        parsed.model or "gpt-image-2",
        parsed.n,
        parsed.size,
        parsed.response_format or "b64_json",
        parsed.stream,
        decoded_images,
    )


class ChatCompletionRequest(BaseModel):
model_config = ConfigDict(extra="allow")
model: str | None = None
Expand Down Expand Up @@ -92,27 +123,32 @@ async def edit_images(
authorization: str | None = Header(default=None),
image: list[UploadFile] | None = File(default=None),
image_list: list[UploadFile] | None = File(default=None, alias="image[]"),
prompt: str = Form(...),
prompt: str | None = Form(default=None),
model: str = Form(default="gpt-image-2"),
n: int = Form(default=1),
size: str | None = Form(default=None),
response_format: str = Form(default="b64_json"),
stream: bool | None = Form(default=None),
):
identity = require_identity(authorization)
if _is_json_request(request):
prompt, model, n, size, response_format, stream, images = await _collect_json_edit_payload(request)
else:
if not prompt:
raise HTTPException(status_code=422, detail={"error": "prompt is required"})
uploads = [*(image or []), *(image_list or [])]
if not uploads:
raise HTTPException(status_code=400, detail={"error": "image file is required"})
images: list[tuple[bytes, str, str]] = []
for upload in uploads:
image_data = await upload.read()
if not image_data:
raise HTTPException(status_code=400, detail={"error": "image file is empty"})
images.append((image_data, upload.filename or "image.png", upload.content_type or "image/png"))

n = parse_image_count(n)
call = LoggedCall(identity, "/v1/images/edits", model, "图生图", request_text=prompt)
if n < 1 or n > 4:
raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
await filter_or_log(call, prompt)
uploads = [*(image or []), *(image_list or [])]
if not uploads:
raise HTTPException(status_code=400, detail={"error": "image file is required"})
images: list[tuple[bytes, str, str]] = []
for upload in uploads:
image_data = await upload.read()
if not image_data:
raise HTTPException(status_code=400, detail={"error": "image file is empty"})
images.append((image_data, upload.filename or "image.png", upload.content_type or "image/png"))
payload = {
"prompt": prompt,
"images": images,
Expand Down
130 changes: 130 additions & 0 deletions test/test_v1_images_edits_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import base64
import os
import unittest
from unittest import mock

os.environ.setdefault("CHATGPT2API_AUTH_KEY", "chatgpt2api")

from fastapi import FastAPI
from fastapi.testclient import TestClient

import api.ai as ai_module

# Authorization header carrying the same key this module seeds into the
# CHATGPT2API_AUTH_KEY environment variable before importing the app.
AUTH_HEADERS = {"Authorization": "Bearer chatgpt2api"}
# Data URLs wrapping placeholder bytes; the payloads are not real images,
# the tests only check that the bytes and MIME type round-trip.
PNG_DATA_URL = "data:image/png;base64," + base64.b64encode(b"fake-png").decode("ascii")
JPEG_DATA_URL = "data:image/jpeg;base64," + base64.b64encode(b"fake-jpeg").decode("ascii")


class ImageEditsJsonApiTests(unittest.TestCase):
    """Route-level tests for ``/v1/images/edits`` with JSON and multipart bodies.

    The image-edit handler and the content filter are patched out, so every
    test inspects the payload the route handed to the (fake) handler.
    """

    EDITS_PATH = "/v1/images/edits"

    def setUp(self):
        self.calls = []

        def record_payload(payload):
            # Capture what the route forwards and answer with a canned result.
            self.calls.append(payload)
            return {"created": 1, "data": [{"b64_json": "ZmFrZQ=="}]}

        self.handle_patcher = mock.patch.object(ai_module.openai_v1_image_edit, "handle", record_payload)
        self.filter_patcher = mock.patch.object(ai_module, "filter_or_log", mock.AsyncMock())
        patchers = (self.handle_patcher, self.filter_patcher)
        for patcher in patchers:
            patcher.start()
        for patcher in patchers:
            self.addCleanup(patcher.stop)

        app = FastAPI()
        app.include_router(ai_module.create_router())
        self.client = TestClient(app)

    def _post_json(self, body):
        """POST *body* as JSON to the edits endpoint with the test credentials."""
        return self.client.post(self.EDITS_PATH, headers=AUTH_HEADERS, json=body)

    def test_json_model_omitted_uses_existing_default_logic(self):
        response = self._post_json({"prompt": "未传 model", "image": PNG_DATA_URL})

        self.assertEqual(response.status_code, 200, response.text)
        self.assertEqual(self.calls[0]["model"], "gpt-image-2")

    def test_json_model_is_not_overwritten_when_provided(self):
        body = {"model": "codex-gpt-image-2", "prompt": "保留 model", "image": PNG_DATA_URL}
        response = self._post_json(body)

        self.assertEqual(response.status_code, 200, response.text)
        self.assertEqual(self.calls[0]["model"], "codex-gpt-image-2")

    def test_image_edit_accepts_json_image_url(self):
        body = {
            "model": "gpt-image-2",
            "prompt": "把图片改成夜景风格",
            "n": 1,
            "size": "1024x1536",
            "response_format": "b64_json",
            "images": [{"image_url": PNG_DATA_URL}],
        }
        response = self._post_json(body)

        self.assertEqual(response.status_code, 200, response.text)
        forwarded = self.calls[0]
        self.assertEqual(forwarded["images"], [(b"fake-png", "image_1.png", "image/png")])
        self.assertEqual(forwarded["size"], "1024x1536")

    def test_image_edit_accepts_json_multiple_images_and_b64_json(self):
        raw_jpeg_b64 = base64.b64encode(b"raw-jpeg").decode("ascii")
        body = {
            "prompt": "把两张图合成海报",
            "images": [
                PNG_DATA_URL,
                {"b64_json": raw_jpeg_b64, "mime_type": "image/jpeg", "filename": "two.jpg"},
                {"image_url": {"url": JPEG_DATA_URL}},
            ],
        }
        response = self._post_json(body)

        self.assertEqual(response.status_code, 200, response.text)
        expected_images = [
            (b"fake-png", "image_1.png", "image/png"),
            (b"raw-jpeg", "two.jpg", "image/jpeg"),
            (b"fake-jpeg", "image_3.jpg", "image/jpeg"),
        ]
        self.assertEqual(self.calls[0]["images"], expected_images)

    def test_image_edit_keeps_original_multipart_multiple_image_logic(self):
        form_fields = {"prompt": "multipart 多图仍然可用", "model": "gpt-image-2", "n": "1"}
        uploads = [
            ("image", ("one.png", b"one", "image/png")),
            ("image", ("two.jpg", b"two", "image/jpeg")),
            ("image[]", ("three.webp", b"three", "image/webp")),
        ]
        response = self.client.post(
            self.EDITS_PATH,
            headers=AUTH_HEADERS,
            data=form_fields,
            files=uploads,
        )

        self.assertEqual(response.status_code, 200, response.text)
        expected_images = [
            (b"one", "one.png", "image/png"),
            (b"two", "two.jpg", "image/jpeg"),
            (b"three", "three.webp", "image/webp"),
        ]
        self.assertEqual(self.calls[0]["images"], expected_images)

    def test_image_edit_rejects_json_without_image(self):
        response = self._post_json({"prompt": "缺少图片"})

        self.assertEqual(response.status_code, 400, response.text)
        self.assertIn("image file is required", response.text)

    def test_image_edit_rejects_remote_json_url(self):
        body = {"prompt": "不允许远程拉图", "images": [{"image_url": "https://example.com/a.png"}]}
        response = self._post_json(body)

        self.assertEqual(response.status_code, 400, response.text)
        self.assertIn("remote image URLs are not supported", response.text)

    def test_image_edit_rejects_json_n_out_of_range(self):
        response = self._post_json({"prompt": "n 越界", "n": 5, "image": PNG_DATA_URL})

        self.assertEqual(response.status_code, 400, response.text)
        # The request must be refused before the handler is ever invoked.
        self.assertFalse(self.calls)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
75 changes: 75 additions & 0 deletions utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,81 @@
IMAGE_MODELS = {"gpt-image-2", "codex-gpt-image-2"}
OUTPUT_DIR = Path(__file__).resolve().parent / "output"

# MIME types accepted for images supplied inline in a JSON edit request.
# "image/jpg" is listed, although _decode_json_image_string rewrites it to
# "image/jpeg" before checking membership.
SUPPORTED_JSON_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/jpg", "image/webp", "image/gif"}
# Largest decoded size allowed for a single inline image (10 MiB).
MAX_JSON_IMAGE_BYTES = 10 * 1024 * 1024
# Largest number of image entries allowed in one JSON edit request.
MAX_JSON_EDIT_IMAGES = 10
# Matches "data:<mime>;base64,<payload>"; DOTALL lets the payload group
# span line breaks.
DATA_URL_IMAGE_RE = re.compile(r"^data:(?P<mime>[-+./\w]+);base64,(?P<data>.*)$", re.DOTALL)


def _image_extension(mime_type: str) -> str:
image_type = mime_type.split("/", 1)[1].split(";", 1)[0].lower() if "/" in mime_type else "png"
return "jpg" if image_type == "jpeg" else image_type or "png"


def _decode_json_image_string(value: str, index: int, filename: str | None = None, mime_type: str | None = None) -> tuple[bytes, str, str]:
    """Decode one inline image given as a base64 data URL or bare base64 text.

    Args:
        value: ``data:<mime>;base64,<payload>`` or a bare base64 string.
        index: Position of the image in the request; used to build a
            default filename.
        filename: Name to report for the image; defaults to
            ``image_<index>.<ext>`` with the extension derived from the
            resolved MIME type.
        mime_type: MIME type to assume for bare base64 input; ignored when
            ``value`` is a data URL, which carries its own. Defaults to
            ``image/png``.

    Returns:
        ``(image_bytes, filename, mime_type)`` with ``image/jpg`` normalized
        to ``image/jpeg``.

    Raises:
        HTTPException: 400 when the value is empty, is an http(s) URL, has an
            unsupported MIME type, is not valid base64, decodes to nothing,
            or exceeds ``MAX_JSON_IMAGE_BYTES``.
    """
    text = value.strip()
    if not text:
        raise HTTPException(status_code=400, detail={"error": "image file is empty"})
    match = DATA_URL_IMAGE_RE.match(text)
    if match:
        resolved_mime = (match.group("mime") or "image/png").lower()
        encoded = match.group("data")
    else:
        # URL schemes are case-insensitive. Only a short prefix is lower-cased
        # so a multi-megabyte payload is not copied just for this check.
        if text[:8].lower().startswith(("http://", "https://")):
            raise HTTPException(status_code=400, detail={"error": "remote image URLs are not supported"})
        resolved_mime = (mime_type or "image/png").lower()
        encoded = text
    if resolved_mime == "image/jpg":
        resolved_mime = "image/jpeg"
    if resolved_mime not in SUPPORTED_JSON_IMAGE_MIME_TYPES:
        raise HTTPException(status_code=400, detail={"error": "unsupported image mime type"})
    # Base64 is commonly emitted line-wrapped, and strict validation below
    # rejects any whitespace, so drop it first.
    encoded = "".join(encoded.split())
    # Base64 carries 3 bytes per 4 characters, so well-formed input longer
    # than this cannot fit within the limit. Rejecting here avoids decoding
    # an oversized payload only to discard it.
    max_encoded_chars = 4 * ((MAX_JSON_IMAGE_BYTES + 2) // 3)
    if len(encoded) > max_encoded_chars:
        raise HTTPException(status_code=400, detail={"error": "image file is too large"})
    try:
        image_data = base64.b64decode(encoded, validate=True)
    except ValueError as exc:
        # binascii.Error (bad alphabet/padding) is a ValueError subclass, as is
        # the error raised for non-ASCII input.
        raise HTTPException(status_code=400, detail={"error": "invalid base64 image data"}) from exc
    if not image_data:
        raise HTTPException(status_code=400, detail={"error": "image file is empty"})
    if len(image_data) > MAX_JSON_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image file is too large"})
    return image_data, filename or f"image_{index}.{_image_extension(resolved_mime)}", resolved_mime


def _extract_json_image_value(item: object) -> tuple[str, str | None, str | None]:
if isinstance(item, str):
return item, None, None
if not isinstance(item, dict):
raise HTTPException(status_code=400, detail={"error": "image entry must be a base64 string or object"})
filename = str(item.get("filename") or item.get("file_name") or "").strip() or None
mime_type = str(item.get("mime_type") or item.get("mimeType") or "").strip() or None
value = item.get("b64_json") or item.get("base64")
if not value:
image_url = item.get("image_url") or item.get("url")
if isinstance(image_url, dict):
filename = filename or str(image_url.get("filename") or image_url.get("file_name") or "").strip() or None
mime_type = mime_type or str(image_url.get("mime_type") or image_url.get("mimeType") or "").strip() or None
value = image_url.get("url") or image_url.get("image_url")
else:
value = image_url
if not isinstance(value, str) or not value.strip():
raise HTTPException(status_code=400, detail={"error": "image entry must include image data"})
return value, filename, mime_type


def normalize_json_edit_images(image: object = None, images: object = None) -> list[tuple[bytes, str, str]]:
    """Turn the ``image``/``images`` fields of a JSON edit request into decoded images.

    ``images`` wins whenever it is not ``None``; ``image`` is consulted only
    as a fallback. A non-list value is treated as a single entry.

    Returns:
        One ``(image_bytes, filename, mime_type)`` tuple per entry, in
        request order, with entries numbered from 1 for default filenames.

    Raises:
        HTTPException: 400 when no image is supplied, the list is empty, or
            it holds more than ``MAX_JSON_EDIT_IMAGES`` entries. Errors from
            extracting or decoding an individual entry propagate unchanged.
    """
    source = image if images is None else images
    if source is None:
        raise HTTPException(status_code=400, detail={"error": "image file is required"})
    candidates = source if isinstance(source, list) else [source]
    if not candidates:
        raise HTTPException(status_code=400, detail={"error": "image file is required"})
    if len(candidates) > MAX_JSON_EDIT_IMAGES:
        raise HTTPException(status_code=400, detail={"error": f"images supports up to {MAX_JSON_EDIT_IMAGES} items"})
    # The generator is lazy, so each entry is extracted and then decoded
    # before the next one is touched.
    extracted = (_extract_json_image_value(entry) for entry in candidates)
    return [
        _decode_json_image_string(payload, position, name, mime)
        for position, (payload, name, mime) in enumerate(extracted, start=1)
    ]


def new_uuid() -> str:
    """Return a newly generated random (version 4) UUID as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
Expand Down