Skip to content

Commit

Permalink
target size arg added into extract faces
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Nov 10, 2024
1 parent 24f57ec commit 88b5580
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 3 deletions.
23 changes: 20 additions & 3 deletions retinaface/RetinaFace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
import logging
from typing import Union, Any, Optional, Dict
from typing import Union, Any, Optional, Dict, Tuple, List

# this has to be set before importing tf
os.environ["TF_USE_LEGACY_KERAS"] = "1"
Expand Down Expand Up @@ -220,7 +220,9 @@ def extract_faces(
align: bool = True,
allow_upscaling: bool = True,
expand_face_area: int = 0,
) -> list:
target_size: Optional[Tuple[int, int]] = None,
min_max_norm: bool = True,
) -> List[np.ndarray]:
"""
Extract detected and aligned faces
Args:
Expand All @@ -230,6 +232,13 @@ def extract_faces(
align (bool): enable or disable alignment
allow_upscaling (bool): allowing up-scaling
expand_face_area (int): expand detected facial area with a percentage
target_size (optional tuple): resize the image by padding it with black pixels
to fit the specified dimensions. default is None
min_max_norm (bool): set this to True if you want to normalize image in [0, 1].
this is only running when target_size is not none.
for instance, matplotlib expects inputs in this scale. (default is True)
Returns:
result (List[np.ndarray]): list of extracted faces
"""
resp = []

Expand Down Expand Up @@ -289,6 +298,14 @@ def extract_faces(
int(rotated_y1) : int(rotated_y2), int(rotated_x1) : int(rotated_x2)
]

resp.append(facial_img[:, :, ::-1])
if target_size is not None:
facial_img = postprocess.resize_image(
img=facial_img, target_size=target_size, min_max_norm=min_max_norm
)

# to rgb
facial_img = facial_img[:, :, ::-1]

resp.append(facial_img)

return resp
63 changes: 63 additions & 0 deletions retinaface/commons/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
# built-in dependencies
import math
from typing import Union, Tuple

# 3rd party dependencies
import numpy as np
from PIL import Image
import cv2
import tensorflow as tf

tf_major_version = int(tf.__version__.split(".", maxsplit=1)[0])
if tf_major_version == 1:
from keras.preprocessing import image
else:
from tensorflow.keras.preprocessing import image


# pylint: disable=unused-argument
Expand Down Expand Up @@ -143,6 +154,58 @@ def rotate_facial_area(
return (x1, y1, x2, y2)


def resize_image(
img: np.ndarray, target_size: Tuple[int, int], min_max_norm: bool = True
) -> np.ndarray:
"""
Resize an image to expected size of a ml model with adding black pixels.
Ref: github.com/serengil/deepface/blob/master/deepface/modules/preprocessing.py
Args:
img (np.ndarray): pre-loaded image as numpy array
target_size (tuple): input shape of ml model
min_max_norm (bool): set this to True if you want to normalize image in [0, 1].
this is only running when target_size is not none.
for instance, matplotlib expects inputs in this scale. (default is True)
Returns:
img (np.ndarray): resized input image
"""
factor_0 = target_size[0] / img.shape[0]
factor_1 = target_size[1] / img.shape[1]
factor = min(factor_0, factor_1)

dsize = (
int(img.shape[1] * factor),
int(img.shape[0] * factor),
)
img = cv2.resize(img, dsize)

diff_0 = target_size[0] - img.shape[0]
diff_1 = target_size[1] - img.shape[1]

# Put the base image in the middle of the padded image
img = np.pad(
img,
(
(diff_0 // 2, diff_0 - diff_0 // 2),
(diff_1 // 2, diff_1 - diff_1 // 2),
(0, 0),
),
"constant",
)

# double check: if target image is not still the same size with target.
if img.shape[0:2] != target_size:
img = cv2.resize(img, target_size)

# make it 4-dimensional how ML models expect
img = image.img_to_array(img)

if min_max_norm is True and img.max() > 1:
img = (img.astype(np.float32) / 255.0).astype(np.float32)

return img


def bbox_pred(boxes, box_deltas):
"""
This function is copied from the following code snippet:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ def test_different_expanding_ratios():
plt.imshow(face)
plt.axis("off")
plt.show()


def test_resize():
faces = RetinaFace.extract_faces(img_path="tests/dataset/img11.jpg", target_size=(224, 224))
for face in faces:
assert face.shape == (224, 224, 3)
if do_plotting is True:
plt.imshow(face)
plt.show()
logger.info("✅ resize test done")

0 comments on commit 88b5580

Please sign in to comment.