Skip to content

Commit 1c65c4a

Browse files
author
Sergio Soto Núñez
committed
Renames services. Move entrypoints to a path
1 parent dc8d840 commit 1c65c4a

24 files changed

+728
-68
lines changed

.env.docker

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
ENV=docker
2-
DB_URL=mongodb://admin:admin@mongodb-service/
3-
DB_NAME=ai_platform
2+
DB_URL=mongodb://admin:admin@mongodb-service:27017/
3+
DB_NAME=ai_platform
4+
REDIS_URL=redis://redis-service:6379/0

.env.local

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
ENV=local
22
DB_URL=mongodb://admin:[email protected]/
3-
DB_NAME=ai_platform
3+
DB_NAME=ai_platform
4+
REDIS_URL=redis://127.0.0.1:6379/0

Dockerfile.api

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ COPY --from=requirements-stage /tmp/requirements.txt /code/requirements.txt
1111
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
1212

1313
COPY ./ai_platform /code/ai_platform
14-
COPY ./api_entrypoint.sh /code/api_entrypoint.sh
14+
COPY ./scripts/api_entrypoint.sh /code/api_entrypoint.sh
1515
RUN chmod +x /code/api_entrypoint.sh
1616

17+
COPY .env.* /code/
18+
1719
CMD ["./api_entrypoint.sh"]

Dockerfile.workers

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ RUN poetry install --no-interaction --no-ansi
2020
COPY ./ai_platform /code/ai_platform
2121
COPY .env.* /code/
2222

23-
COPY ./worker_initialize.sh /code/worker_initialize.sh
24-
COPY ./worker_entrypoint.sh /code/worker_entrypoint.sh
23+
COPY ./scripts/worker_initialize.sh /code/worker_initialize.sh
24+
COPY ./scripts/worker_entrypoint.sh /code/worker_entrypoint.sh
25+
2526
RUN chmod +x /code/worker_entrypoint.sh
2627
RUN chmod +x /code/worker_initialize.sh
28+
2729
# Execute worker initialization script
2830
RUN ./worker_initialize.sh
2931

README.md

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
2+
# Project overview
3+
This project is a simple API that allows you to create, update, delete and list machine learning models. It also allows you to create, update, delete and list machine learning model versions. It is a simple API that allows you to manage machine learning models and their versions.
4+
5+
## FastAPI
6+
7+
## Celery
8+
9+
## Redis
10+
11+
## MongoDB
12+
13+
14+
# Project local setup
15+
Clone the project:
16+
```bash
17+
git clone https://github.com/ssoto/ai-platform-api.git
18+
cd ai-platform-api
19+
```
20+
You are going to need to have [Docker](https://www.docker.com/) installed in your machine or something similar like [Rancher](https://rancher.com/) to run the project.
21+
22+
Then start compiling proper images and running the project:
23+
```bash
24+
docker-compose build --build-arg ENV=docker workers ai-platform-api
25+
docker-compose up -d
26+
```

ai_platform/api/main.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
import logging
22
from fastapi import FastAPI, Request
3+
from fastapi.responses import JSONResponse
34
from motor.motor_asyncio import AsyncIOMotorClient
5+
46
from ai_platform.api.public.image_tasks.router import router as image_tasks_router
57
from ai_platform.api.public.images.router import router as images_router
6-
from ai_platform.sandbox.images_creation import startup_pipeline
8+
from ai_platform.task_queue.main import app as celery_app
9+
from ai_platform.utils import is_redis_ok, is_mongo_ok
710
from ai_platform.config import settings
811

912
logging.basicConfig(
1013
level=logging.INFO,
1114
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
15+
handlers=[logging.StreamHandler()]
1216
)
1317

1418

1519
def initialize_app(app: FastAPI):
1620
logging.info("Initializing app...")
17-
# do some initialization here
18-
# startup_pipeline()
1921
app.mongodb_client = AsyncIOMotorClient(settings.DB_URL)
2022
app.mongodb = app.mongodb_client[settings.DB_NAME]
23+
app.celery_app = celery_app
2124
logging.info("App initialized")
2225
yield
2326
app.mongodb_client.close()
@@ -26,14 +29,22 @@ def initialize_app(app: FastAPI):
2629
app = FastAPI(
2730
title="AI Platform",
2831
root_path=settings.API_ROOT_PATH,
32+
debug=True,
2933
lifespan=initialize_app
3034
)
3135

3236

3337
@app.get("/health")
3438
async def health(request: Request):
35-
result = await request.app.mongodb_client.server_info()
36-
return result
39+
40+
result = {
41+
"mongodb": await is_mongo_ok(request.app.mongodb_client),
42+
"redis": is_redis_ok(),
43+
}
44+
return JSONResponse(
45+
content=result,
46+
status_code=200 if all(result.values()) else 500
47+
)
3748

3849
app.include_router(image_tasks_router)
3950
app.include_router(images_router)

ai_platform/api/public/image_tasks/router.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from fastapi.encoders import jsonable_encoder
55

66
from ai_platform.domain.image_tasks.models import ImageTask
7-
from ai_platform.sandbox.images_creation import generate_image
8-
from ai_platform.domain.image_tasks.use_cases import create_image_task, find_image_task_by_id
9-
from ai_platform.domain.image_repository.use_cases import get_image_url
7+
from ai_platform.domain.image_tasks.use_cases import create_image_task, afind_image_task_by_id
8+
from ai_platform.domain.image_repository.use_cases import get_image_url
9+
from ai_platform.domain.task_queues.use_cases import send_generation_message
10+
1011

1112
router = APIRouter(
1213
prefix="/imageTasks",
@@ -18,7 +19,7 @@
1819
@router.get("/")
1920
async def retrieve(request: Request, id_task: str):
2021

21-
result = await find_image_task_by_id(id_task, request.app.mongodb["image_tasks"])
22+
result = await afind_image_task_by_id(id_task, request.app.mongodb["image_tasks"])
2223
if not result:
2324
return JSONResponse(
2425
status_code=status.HTTP_404_NOT_FOUND,
@@ -40,17 +41,17 @@ async def generate(
4041

4142
image_task = ImageTask(
4243
prompt=prompt,
43-
status="processing"
4444
)
4545
# FIXME: this image service is a local endpoint, it should be a service
4646
image_task.url = get_image_url(image_task.id)
4747
await create_image_task(
4848
image_task,
4949
request.app.mongodb["image_tasks"]
5050
)
51-
52-
background_tasks.add_task(generate_image, image_task)
53-
logging.info(f"Task {image_task.id} added to the queue")
51+
background_tasks.add_task(
52+
send_generation_message, image_task, request.app.celery_app,
53+
)
54+
logging.info(f"Task {image_task.id} send to the queue")
5455
return JSONResponse(
5556
status_code=status.HTTP_201_CREATED,
5657
content=image_task.dict()

ai_platform/api/public/images/router.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
from fastapi import APIRouter, Request, Response
2-
from ai_platform.domain.image_tasks.use_cases import find_image_task_by_id
1+
import os.path
2+
3+
from fastapi import APIRouter, Request
4+
from fastapi.responses import Response
5+
from ai_platform.domain.image_tasks.use_cases import afind_image_task_by_id
36

47
router = APIRouter(
58
prefix="/images",
@@ -19,10 +22,16 @@
1922
response_class=Response,
2023
)
2124
async def retrieve(request: Request, image_id: str):
22-
image_task = await find_image_task_by_id(
25+
image_task = await afind_image_task_by_id(
2326
image_id,
2427
request.app.mongodb["image_tasks"]
2528
)
29+
if not os.path.exists(image_task.image_path):
30+
return Response(
31+
status_code=404,
32+
content=f"Image {image_id}.png not found"
33+
)
34+
2635
with open(image_task.image_path, "rb") as image_file:
2736
image_bytes = image_file.read()
2837
return Response(

ai_platform/config/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ class DatabaseSettings(BaseSettings):
3333
DB_NAME: str
3434

3535

36-
class Settings(CommonSettings, ServerSettings, DatabaseSettings):
36+
class RedisSettings(BaseSettings):
37+
REDIS_URL: str
38+
39+
40+
class Settings(CommonSettings, ServerSettings, DatabaseSettings, RedisSettings):
3741
pass
3842

3943

ai_platform/domain/image_tasks/use_cases.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@ async def create_image_task(image_task: ImageTask, db_collection):
1010
return await db_collection.insert_one(image_task.dict(by_alias=True))
1111

1212

13-
async def find_image_task_by_id(id_task: str, db_collection):
13+
async def afind_image_task_by_id(id_task: str, db_collection):
1414
result = await db_collection.find_one({"_id": id_task})
1515
if result:
1616
return ImageTask(**result)
1717
return None
1818

1919

20+
def find_image_task_by_id(id_task: str, db_collection):
21+
result = db_collection.find_one({"_id": id_task})
22+
if result:
23+
return ImageTask(**result)
24+
return None
25+
26+
2027
def update_image_task(image_task: ImageTask):
2128
client = None
2229
try:
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from pydantic import BaseModel, Field
2+
from ai_platform.domain.image_tasks.models import ImageTask
3+
4+
5+
class ImageTaskMesage(BaseModel):
6+
7+
image: ImageTask = Field(..., alias="imageTask")
8+
9+
class Config:
10+
allow_population_by_field_name = True
11+
fields = {
12+
"imageTask": "image"
13+
}
14+
json_encoders = {
15+
ImageTask: lambda v: v.dict(by_alias=True)
16+
}
17+
18+
19+
class TaskQueueMessage(BaseModel):
20+
21+
task: ImageTaskMesage = Field(..., alias="taskQueue")
22+
23+
class Config:
24+
allow_population_by_field_name = True
25+
fields = {
26+
"taskQueue": "task"
27+
}
28+
json_encoders = {
29+
ImageTaskMesage: lambda v: v.dict(by_alias=True)
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import logging
2+
from celery import Celery
3+
from ai_platform.domain.image_tasks.use_cases import ImageTask
4+
5+
6+
def send_generation_message(image_task: ImageTask, celery_app: Celery):
7+
result = celery_app.send_task(
8+
"generate_image",
9+
kwargs=image_task.dict(),
10+
queue="default"
11+
)
12+
logging.info(f"Task {result.id} sent to the queue")
13+

ai_platform/task_queue/__init__.py

Whitespace-only changes.

ai_platform/sandbox/images_creation.py renamed to ai_platform/task_queue/images_creation.py

+15-34
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,10 @@
55
import torch
66

77
from ai_platform.domain.image_tasks.models import ImageTask
8-
from ai_platform.domain.image_tasks.use_cases import update_image_task
9-
from ai_platform.domain.image_repository.use_cases import update_image_task
10-
11-
_pipe = None
128

139

1410
def startup_pipeline(only_download=False):
15-
global _pipe
16-
_pipe = DiffusionPipeline.from_pretrained(
11+
pipe = DiffusionPipeline.from_pretrained(
1712
"runwayml/stable-diffusion-v1-5",
1813
torch_dtype=torch.float16,
1914
variant="fp16",
@@ -22,35 +17,29 @@ def startup_pipeline(only_download=False):
2217
return
2318

2419
if 'macOS' in platform.platform():
25-
_pipe.to("mps")
20+
pipe.to("mps")
2621

27-
_pipe.safety_checker = None
28-
_pipe.requires_safety_checker = False
22+
pipe.safety_checker = None
23+
pipe.requires_safety_checker = False
2924

3025
prompt = "fake prompt to wwarmup the pipeline"
3126
logging.info("Warming up the pipeline")
3227
# First-time "warmup" pass if PyTorch version is 1.13
33-
_ = _pipe(prompt, num_inference_steps=1)
28+
_ = pipe(prompt, num_inference_steps=1)
3429
logging.info("Pipeline warmed up")
3530

31+
return pipe
32+
3633

37-
def generate_image(image: ImageTask):
34+
def create_image(pipe: DiffusionPipeline, image: ImageTask):
3835
# Results match those from the CPU device after the warmup pass.
39-
global _pipe
40-
try:
41-
result = _pipe(
42-
image.prompt,
43-
num_inference_steps=image.generation_steps
44-
)
45-
image_file = result.images[0]
46-
logging.info(f"Image generated: {image_file.shape}")
47-
image_file.save(image.image_path)
48-
image.set_completed()
49-
except Exception as e:
50-
image.set_failed(reason=repr(e))
51-
finally:
52-
update_image_task(image)
53-
logging.info(f"Process finished for image {image.id}")
36+
result = pipe(
37+
image.prompt,
38+
num_inference_steps=image.generation_steps
39+
)
40+
image_file = result.images[0]
41+
logging.info(f"Image generated: {image_file}")
42+
return image_file
5443

5544

5645
def parse_args():
@@ -79,14 +68,6 @@ def main():
7968
if args.download_model:
8069
return
8170

82-
# Testing the image generation
83-
image_task = ImageTask(
84-
prompt="A cat in the snow",
85-
generation_steps=10,
86-
)
87-
image_task.set_processing()
88-
generate_image(image_task)
89-
9071

9172
if __name__ == "__main__":
9273
main()

0 commit comments

Comments
 (0)