Skip to content

Commit

Permalink
transformer/face_detection: add support for webdataset style format
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek Gaikwad <[email protected]>
  • Loading branch information
gaikwadabhishek committed Dec 18, 2023
1 parent d733764 commit 1e23f04
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 96 deletions.
8 changes: 6 additions & 2 deletions transformers/face_detection/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
FROM python:slim
# Prior to building this image make you own kaggle_creds.json file
# containing kaggle keys to download dataset
FROM python:3.8-slim

WORKDIR /

# install packages needed for open-cv to work
RUN apt-get update && apt-get -y install gcc ffmpeg libsm6 libxext6 unzip
RUN apt-get update && apt-get -y install gcc ffmpeg libsm6 libxext6 unzip curl

# install python dependencies
COPY ./requirements.txt requirements.txt
Expand All @@ -29,4 +31,6 @@ COPY main.py main.py

ENV PYTHONUNBUFFERED 1

ENV LOG_LEVEL DEBUG

EXPOSE 8000
5 changes: 5 additions & 0 deletions transformers/face_detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Adjust the following parameters in the `pod.yaml` file as per your requirements:
|------------|---------------------------------------------------------------------|---------------|
| `FORMAT` | Image format for processing/storing (png, jpeg, etc.) | "jpeg" |
| `ARG_TYPE` | Local object reading (`fqn`) vs. HTTP request for object retrieval | "" |
| `FILE_FORMAT` | Configure as "tar" for processing datasets in the webdataset format or for handling batches of images packaged in a tarball | "" |

### Setting Up the Face Detection Transformer with AIStore CLI

Expand All @@ -42,6 +43,7 @@ cd transformers/face_detection
# Set FORMAT and ARG_TYPE environment variables
export FORMAT="jpeg"
export ARG_TYPE="" # Or use 'fqn' for local reading
export FILE_FORMAT="" # or use "tar", if using webdataset format

# Define communication type
export COMMUNICATION_TYPE="hpush://"
Expand All @@ -58,4 +60,7 @@ ais etl object <etl-name> ais://src/<image-name>.JPEG dst.JPEG

# For offline (bucket-to-bucket) transformation
ais etl bucket <etl-name> ais://src-bck ais://dst-bck --ext="{jpg:jpg}"

# or, if using webdataset style format
# ais etl bucket <etl-name> ais://src-bck ais://dst-bck --ext="{tar:tar}"
```
98 changes: 82 additions & 16 deletions transformers/face_detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,33 @@
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""
# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring, broad-exception-caught
# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring, broad-exception-caught, unused-import
import os
import urllib.parse
import io
import logging

import aiofiles
from fastapi import FastAPI, Request, Depends, Response
from fastapi.logger import logger
import aiohttp # async
import cv2
import numpy as np
import webdataset as wds
from PIL import Image

# logging
gunicorn_logger = logging.getLogger("gunicorn.error")
logger.handlers = gunicorn_logger.handlers
logger.setLevel(logging.DEBUG)

app = FastAPI()

# env vars
host_target = os.environ["AIS_TARGET_URL"]
FORMAT = os.environ["FORMAT"]
arg_type = os.getenv("ARG_TYPE", "")
file_format = os.getenv("FILE_FORMAT", "image")


class HttpClient:
Expand Down Expand Up @@ -55,12 +67,27 @@ async def health():
return b"Running"


cvNet = cv2.dnn.readNetFromCaffe(
MODEL = cv2.dnn.readNetFromCaffe(
"./model/architecture.txt", "./model/weights.caffemodel"
)


async def transform_image(image_bytes: bytes) -> bytes:
def transform_tar(obj_url: str) -> bytes:
dataset = wds.WebDataset(obj_url)
processed_shard = dataset.map_dict(**{f"{FORMAT}": transform_image})

# Write the output to a memory buffer and return the value
buffer = io.BytesIO()
with wds.TarWriter(fileobj=buffer) as dst:
for sample in processed_shard:
dst.write(sample)
buffer.seek(0)
data = buffer.read()
buffer.close()
return data


def transform_image(image_bytes: bytes) -> bytes:
image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
image_height, image_width, _ = image.shape
output_image = image.copy()
Expand All @@ -72,8 +99,8 @@ async def transform_image(image_bytes: bytes) -> bytes:
swapRB=False,
crop=False,
)
cvNet.setInput(preprocessed_image)
results = cvNet.forward()
MODEL.setInput(preprocessed_image)
results = MODEL.forward()

for face in results[0][0]:
face_confidence = face[2]
Expand Down Expand Up @@ -109,18 +136,36 @@ async def get_handler(
# Fetch object from AIS target based on the destination/name
# Transform the bytes
# Return the transformed bytes

logger.info("TRANSFORMATION STARTED :: %s", full_path)
if arg_type.lower() == "fqn":
async with aiofiles.open(full_path, "rb") as file:
body = await file.read()
if (
file_format.lower() == "tar"
or file_format.lower() == "wds"
or file_format.lower() == "webdataset"
):
result = transform_tar(full_path)
else:
async with aiofiles.open(full_path, "rb") as file:
body = await file.read()
result = transform_image(body)
else:
object_path = urllib.parse.quote(full_path, safe="@")
object_url = f"{host_target}/{object_path}"
resp = await client.get(object_url)
body = await resp.read()

logger.info("object_url: %s", object_url)
if (
file_format.lower() == "tar"
or file_format.lower() == "wds"
or file_format.lower() == "webdataset"
):
result = transform_tar(object_url)
else:
resp = await client.get(object_url)
body = await resp.read()
result = transform_image(body)

logger.info("TRANSFORMATION COMPLETED :: %s", full_path)
return Response(
content=await transform_image(body),
content=result,
media_type="application/octet-stream",
)

Expand All @@ -134,15 +179,36 @@ async def put_handler(request: Request, full_path: str):
and returns the modified bytes.
"""
# Read bytes from request (request.body)
# Transform the bytes
logger.info("TRANSFORMATION STARTED :: %s", full_path)
if arg_type.lower() == "fqn":
async with aiofiles.open(full_path, "rb") as file:
body = await file.read()
if (
file_format.lower() == "tar"
or file_format.lower() == "wds"
or file_format.lower() == "webdataset"
):
result = transform_tar(full_path)
else:
async with aiofiles.open(full_path, "rb") as file:
body = await file.read()
result = transform_image(body)
else:
if (
file_format.lower() == "tar"
or file_format.lower() == "wds"
or file_format.lower() == "webdataset"
):
# no way to find the url of the object
raise ValueError(
'FILE_FORMAT "tar" requires comm_type=hpush or arg_type=fqn'
)

body = await request.body()
result = transform_image(body)

# Transform the bytes
# Return the transformed bytes
logger.info("TRANSFORMATION COMPLETED :: %s", full_path)
return Response(
content=await transform_image(body),
content=result,
media_type="application/octet-stream",
)
4 changes: 3 additions & 1 deletion transformers/face_detection/pod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ spec:
ports:
- name: default
containerPort: 8000
command: ["gunicorn", "main:app", "--workers", "5", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]
command: ["gunicorn", "main:app", "--workers", "5", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--timeout", "300"]
env:
- name: FORMAT
# Expected Values - png, jpeg, etc.
value: "${FORMAT}"
- name: ARG_TYPE
value: "${ARG_TYPE}"
- name: FILE_FORMAT
value: "${FILE_FORMAT}"
# This is a health check endpoint which one should specify
# for aistore to determine the health of the ETL container.
readinessProbe:
Expand Down
4 changes: 3 additions & 1 deletion transformers/face_detection/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ aiohttp
numpy
opencv-python
aiofiles
kaggle
kaggle==1.5.16
webdataset==0.2.86
Pillow==10.0.0
39 changes: 1 addition & 38 deletions transformers/tests/test_face_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV
import cv2

# from aistore.sdk.etl_templates import FACE_DETECTION_TRANSFORMER
from aistore.sdk.etl_templates import FACE_DETECTION_TRANSFORMER
from tests.utils import git_test_mode_format_image_tag_test
from tests.base import TestBase

Expand All @@ -18,43 +18,6 @@
)
logger = logging.getLogger(__name__)

# TODO: move var to aistore.sdk.etl_templates after merge
FACE_DETECTION_TRANSFORMER = """
apiVersion: v1
kind: Pod
metadata:
name: transformer-face-detection
annotations:
communication_type: "{communication_type}://"
wait_timeout: 5m
spec:
containers:
- name: server
image: aistorage/transformer_face_detection:latest
imagePullPolicy: Always
ports:
- name: default
containerPort: 8000
command: ["gunicorn", "main:app", "--workers", "20", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]
readinessProbe:
httpGet:
path: /health
port: default
env:
- name: FORMAT
value: "{format}"
- name: ARG_TYPE
value: "{arg_type}"
volumeMounts:
- name: ais
mountPath: /tmp/ais
volumes:
- name: ais
hostPath:
path: /tmp/ais
type: Directory
"""


class TestTransformers(TestBase):
def setUp(self):
Expand Down
39 changes: 1 addition & 38 deletions transformers/tests/test_face_detection_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime
from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV

# from aistore.sdk.etl_templates import FACE_DETECTION_TRANSFORMER
from aistore.sdk.etl_templates import FACE_DETECTION_TRANSFORMER
from tests.base import TestBase
from tests.utils import git_test_mode_format_image_tag_test

Expand All @@ -19,43 +19,6 @@
)
logger = logging.getLogger(__name__)

# TODO: move var to aistore.sdk.etl_templates after merge
FACE_DETECTION_TRANSFORMER = """
apiVersion: v1
kind: Pod
metadata:
name: transformer-face-detection
annotations:
communication_type: "{communication_type}://"
wait_timeout: 5m
spec:
containers:
- name: server
image: aistorage/transformer_face_detection:latest
imagePullPolicy: Always
ports:
- name: default
containerPort: 8000
command: ["gunicorn", "main:app", "--workers", "20", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]
readinessProbe:
httpGet:
path: /health
port: default
env:
- name: FORMAT
value: "{format}"
- name: ARG_TYPE
value: "{arg_type}"
volumeMounts:
- name: ais
mountPath: /tmp/ais
volumes:
- name: ais
hostPath:
path: /tmp/ais
type: Directory
"""


class TestFaceDetectionStress(TestBase):
def setUp(self):
Expand Down

0 comments on commit 1e23f04

Please sign in to comment.