Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update draw_box_utils.py #789

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 21 additions & 116 deletions pytorch_object_detection/mask_rcnn/draw_box_utils.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,44 @@
from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from PIL import ImageColor
from PIL import Image, ImageDraw, ImageFont, ImageColor
import numpy as np

# Palette of colour names used to colour detections. draw_objs picks
# STANDARD_COLORS[cls % len(STANDARD_COLORS)] and resolves the name with
# ImageColor.getrgb, so every entry must be a name PIL's ImageColor accepts
# and the order of entries determines which class gets which colour.
STANDARD_COLORS = [
'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
# NOTE(review): the change under review put a "rest of the colors" placeholder
# comment at this position. Confirm that no colour names were dropped from the
# list, since removing entries changes the class-to-colour mapping.
'WhiteSmoke', 'Yellow', 'YellowGreen'
]


def draw_text(draw,
              box: list,
              cls: int,
              score: float,
              category_index: dict,
              color,
              font: str = 'arial.ttf',
              font_size: int = 24):
    """Draw a "<class name>: <score>%" label next to one bounding box.

    The label is written as a single string on a filled background. It is
    placed directly above the box; when the top edge of the box is not
    further down than the label is tall, it is placed just below the box
    instead.

    Args:
        draw: Drawing context exposing ``rectangle`` and ``text``
            (draw_objs passes the result of ``ImageDraw.Draw(image)``).
        box: ``(left, top, right, bottom)`` coordinates of the box.
        cls: Class id; looked up in ``category_index`` as ``str(cls)``.
        score: Confidence, rendered as a truncated integer percentage.
        category_index: Mapping from class id (as a string) to display name.
        color: Fill used for the label background (draw_objs passes an RGB
            tuple produced by ``ImageColor.getrgb``).
        font: Path or name of a TrueType font; PIL's built-in default font
            is used when it cannot be loaded.
        font_size: Size passed to ``ImageFont.truetype``.

    Raises:
        KeyError: If ``str(cls)`` is not a key of ``category_index``.
    """
    try:
        label_font = ImageFont.truetype(font, font_size)
    except OSError:
        # Font file missing or unreadable: fall back to the bundled font
        # rather than failing the whole drawing pass.
        label_font = ImageFont.load_default()

    left, top, right, bottom = box
    display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"

    # font.getsize()/draw.textsize() were removed in Pillow 10; getbbox() is
    # the replacement. Keep the getsize() fallback for older Pillow font
    # objects that do not provide getbbox().
    if hasattr(label_font, "getbbox"):
        _, _, text_width, text_height = label_font.getbbox(display_str)
    else:
        text_width, text_height = label_font.getsize(display_str)
    margin = np.ceil(0.05 * text_width)

    # Put the label above the box when it fits, otherwise below it.
    text_top = top - text_height if top > text_height else bottom

    draw.rectangle(
        [(left, text_top),
         (left + text_width + 2 * margin, text_top + text_height)],
        fill=color)
    draw.text((left + margin, text_top), display_str, fill='black', font=label_font)

def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
    """Blend one solid colour per mask over ``image`` and return a new image.

    Each mask is binarised with ``mask > thresh``; the selected pixels of a
    copy of the image are painted with the matching colour (later masks
    overwrite earlier ones where they overlap), and the painted copy is
    mixed with the original as ``original * (1 - alpha) + painted * alpha``.

    Args:
        image: Image convertible with ``np.array(image)``.
        masks: Array of per-object mask scores, iterated along its first
            axis. Presumably shaped (num_objects, height, width) to match
            the image; confirm against the caller.
        colors: One colour per mask, assignable to the selected pixels
            (draw_objs passes RGB tuples).
        thresh: Scores strictly greater than this count as foreground.
        alpha: Weight of the painted copy in the final blend.

    Returns:
        A new PIL image with dtype ``uint8``; ``image`` itself is not
        modified.
    """
    # Imported locally because the module-level name ``Image`` may be either
    # the PIL.Image module or the PIL.Image.Image class depending on which
    # import line is in effect; ``fromarray`` is unambiguous.
    from PIL.Image import fromarray

    base = np.array(image)
    overlay = base.copy()
    for mask, color in zip(masks > thresh, colors):
        overlay[mask] = color

    blended = base * (1 - alpha) + overlay * alpha
    return fromarray(blended.astype(np.uint8))


def draw_objs(image,
              boxes: np.ndarray = None,
              classes: np.ndarray = None,
              scores: np.ndarray = None,
              masks: np.ndarray = None,
              category_index: dict = None,
              box_thresh: float = 0.1,
              mask_thresh: float = 0.5,
              line_thickness: int = 8,
              font: str = 'arial.ttf',
              font_size: int = 24,
              draw_boxes_on_image: bool = True,
              draw_masks_on_image: bool = True):
    """Draw bounding boxes, class labels and masks of detections on an image.

    Detections whose score is not greater than ``box_thresh`` are dropped
    first. Each remaining detection gets a colour chosen from
    ``STANDARD_COLORS`` by its class id (modulo the palette length).

    Args:
        image: PIL image to draw on. Boxes and labels are drawn onto this
            object in place.
        boxes: Array of ``(left, top, right, bottom)`` rows, one per detection.
        classes: Integer class ids, one per detection.
        scores: Confidence scores, one per detection.
        masks: Optional per-detection masks, indexed like ``boxes``.
        category_index: Mapping from class id (as a string) to display name.
        box_thresh: Score threshold used to filter detections.
        mask_thresh: Threshold passed to ``draw_masks`` to binarise masks.
        line_thickness: Outline width of the boxes.
        font: Font passed to ``draw_text``.
        font_size: Font size passed to ``draw_text``.
        draw_boxes_on_image: Whether to draw boxes and labels.
        draw_masks_on_image: Whether to overlay masks (needs ``masks``).

    Returns:
        The image with the drawings applied. When masks are overlaid this is
        the new image returned by ``draw_masks``; otherwise it is the input
        image object.
    """
    # The detection arrays default to None; without them there is nothing to
    # filter or draw, so hand the image back instead of failing in NumPy.
    if boxes is None or classes is None or scores is None:
        return image

    # Keep only detections scoring above the threshold.
    keep = np.greater(scores, box_thresh)
    boxes, classes, scores = boxes[keep], classes[keep], scores[keep]
    if masks is not None:
        masks = masks[keep]

    if len(boxes) == 0:
        # NOTE(review): the body of this branch was collapsed in the diff
        # view this block was recovered from; returning the image unchanged
        # is assumed. Confirm against the full file.
        return image

    colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]

    if draw_boxes_on_image:
        draw = ImageDraw.Draw(image)
        for box, cls, score, color in zip(boxes, classes, scores, colors):
            # Plain Python floats for PIL rather than NumPy scalars.
            corners = box.tolist()
            draw.rectangle(corners, outline=color, width=line_thickness)
            draw_text(draw, corners, int(cls), float(score), category_index, color, font, font_size)

    if draw_masks_on_image and masks is not None:
        image = draw_masks(image, masks, colors, mask_thresh)

    return image