Skip to content

Commit 83000d1

Browse files
authored
Merge pull request #8 from ssoto/feat/implements_image_generation_improvement
VAE encoder
2 parents 87ac007 + dec372e commit 83000d1

File tree

3 files changed

+54
-15
lines changed

3 files changed

+54
-15
lines changed

ai_platform/api/public/image_tasks/router.py

+28-5
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
import logging
22
from fastapi.responses import JSONResponse
3+
from pydantic import Field, BaseModel, field_validator
34
from fastapi import APIRouter, Request, status, BackgroundTasks
4-
from fastapi.encoders import jsonable_encoder
5+
from typing import Optional
56

67
from ai_platform.domain.image_tasks.models import ImageTask
78
from ai_platform.domain.image_tasks.use_cases import create_image_task, afind_image_task_by_id
@@ -16,6 +17,28 @@
1617
)
1718

1819

20+
class GenerationRequest(BaseModel):
21+
prompt: str = Field(
22+
...,
23+
title="Prompt to generate the image"
24+
)
25+
generation_steps: Optional[int] = Field(
26+
50,
27+
title="Number of steps to generate the image. Values between 1 and 999"
28+
)
29+
seed: Optional[int] = Field(
30+
None,
31+
title="Seed to generate the image"
32+
)
33+
34+
@field_validator("generation_steps") # noqa
35+
@classmethod
36+
def validate_generation_steps(cls, v):
37+
if v < 1 or v > 999:
38+
raise ValueError("Generation steps must be between 1 and 999")
39+
return v
40+
41+
1942
@router.get("/")
2043
async def retrieve(request: Request, id_task: str):
2144

@@ -34,13 +57,13 @@ async def retrieve(request: Request, id_task: str):
3457
@router.post("/")
3558
async def generate(
3659
request: Request,
37-
prompt: str,
60+
body: GenerationRequest,
3861
background_tasks: BackgroundTasks
3962
):
40-
prompt = jsonable_encoder(prompt)
41-
4263
image_task = ImageTask(
43-
prompt=prompt,
64+
prompt=body.prompt,
65+
generation_steps=body.generation_steps,
66+
seed=body.seed
4467
)
4568
# FIXME: this image service is a local endpoint, it should be a service
4669
image_task.url = get_image_url(image_task.id)

ai_platform/domain/image_tasks/models.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ImageTask(BaseModel):
6161
description="Task status"
6262
)
6363
reason: Optional[str] = Field(
64-
default=None,
64+
None,
6565
title="Reason of failure",
6666
description="Reason of failure"
6767
)
@@ -70,6 +70,11 @@ class ImageTask(BaseModel):
7070
title="Number of steps to generate the image",
7171
description="Number of steps to generate the image"
7272
)
73+
seed: Optional[int] = Field(
74+
None,
75+
title="Seed to generate the image",
76+
description="Seed to generate the image"
77+
)
7378
url: str = Field(
7479
None,
7580
title="URL to download the image",

ai_platform/task_queue/images_creation.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import platform
33

44
from diffusers import DiffusionPipeline
5+
from diffusers.models import AutoencoderKL
6+
from diffusers import StableDiffusionPipeline
57
import torch
68

79
from ai_platform.domain.image_tasks.models import ImageTask
@@ -13,12 +15,16 @@
1315

1416

1517
def startup_pipeline(only_download=False):
16-
# https://huggingface.co/docs/diffusers/tutorials/basic_training
17-
pipe = DiffusionPipeline.from_pretrained(
18-
"runwayml/stable-diffusion-v1-5",
19-
torch_dtype=torch.float16,
20-
variant="fp16",
18+
logger.info(f"Starting pipeline")
19+
logger.info("Loading VAE model")
20+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
21+
model = "runwayml/stable-diffusion-v1-5"
22+
logger.info(f"Loading model: {model}")
23+
pipe = StableDiffusionPipeline.from_pretrained(
24+
model,
25+
vae=vae
2126
)
27+
2228
if only_download:
2329
return
2430

@@ -40,11 +46,16 @@ def startup_pipeline(only_download=False):
4046

4147

4248
def create_image(pipe: DiffusionPipeline, image: ImageTask):
    """Run the diffusion pipeline for *image* and return the generated image.

    A CPU-seeded generator is attached only when the task carries a seed,
    so unseeded tasks keep the pipeline's default (non-deterministic)
    sampling behaviour.
    """
    pipeline_args = {
        "prompt": image.prompt,
        "num_inference_steps": image.generation_steps,
    }
    if image.seed is not None:
        pipeline_args["generator"] = torch.Generator(device="cpu").manual_seed(image.seed)

    # Results match those from the CPU device after the warmup pass.
    result = pipe(**pipeline_args)
    image_file = result.images[0]
    logger.info(f"Image generated: {image_file}")
    return image_file

0 commit comments

Comments
 (0)