Skip to content

1. Refactor the code of 'predict_from_yaml.py' 2. Add concat_crops function #663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 216 additions & 76 deletions tools/infer/text/predict_from_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import logging
import os
import sys
from time import time

import cv2
import numpy as np
import yaml
from addict import Dict
from PIL import Image
Expand All @@ -22,12 +25,14 @@
from mindspore import Tensor, get_context, set_auto_parallel_context, set_context
from mindspore.communication import get_group_size, get_rank, init

from deploy.py_infer.src.infer_args import str2bool # noqa
from mindocr.data import build_dataset
from mindocr.data.transforms import create_transforms, run_transforms
from mindocr.models import build_model
from mindocr.postprocess import build_postprocess
from mindocr.utils.visualize import draw_boxes, show_imgs
from tools.arg_parser import _merge_options, _parse_options
from tools.infer.text.utils import get_image_paths
from tools.modelarts_adapter.modelarts import modelarts_setup

__dir__ = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -155,21 +160,7 @@ def predict_single_step(cfg, save_res=True):
)

# 3.Build model
amp_level = cfg.system.get("amp_level_infer", "O0")
if get_context("device_target") == "GPU" and amp_level == "O3":
logger.warning(
"Model evaluation does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"
cfg.model.backbone.pretrained = False
if cfg.predict.ckpt_load_path is None:
logger.warning(
f"No ckpt is available for {cfg.model.task}, "
"please check your configuration of 'predict.ckpt_load_path' in the yaml file."
)
network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
network.set_train(False)
network = build_model_from_config(cfg)

# 4.Build postprocessor for network output
postprocessor = build_postprocess(cfg.postprocess)
Expand Down Expand Up @@ -230,72 +221,220 @@ def predict_single_step(cfg, save_res=True):
return preds_list


def predict_system(args, det_cfg, rec_cfg):
"""Run predict for both det and rec task"""
# merge image_dir option in model config
det_cfg.predict.dataset.dataset_root = ""
det_cfg.predict.dataset.data_dir = args.image_dir
output_save_dir = det_cfg.predict.output_save_dir or "./output"

# get det result from predict
preds_list = predict_single_step(det_cfg, save_res=False)

# set amp level
amp_level = det_cfg.system.get("amp_level_infer", "O0")
def build_model_from_config(cfg):
amp_level = cfg.system.get("amp_level_infer", "O0")
if get_context("device_target") == "GPU" and amp_level == "O3":
logger.warning(
"Model evaluation does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"

# create preprocess and postprocess for rec task
transforms = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
postprocessor = build_postprocess(rec_cfg.postprocess)

# build rec model from yaml
rec_cfg.model.backbone.pretrained = False
if rec_cfg.predict.ckpt_load_path is None:
cfg.model.backbone.pretrained = False
if cfg.predict.ckpt_load_path is None:
logger.warning(
f"No ckpt is available for {rec_cfg.model.type}, "
f"No ckpt is available for {cfg.model.task}, "
"please check your configuration of 'predict.ckpt_load_path' in the yaml file."
)
rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)

# start rec task
logger.info("Start rec")
img_list = [] # list of img_path
boxes_all = [] # list of boxes of all image
text_scores_all = [] # list of text and scores of all image
for preds_batch in tqdm(preds_list):
# preds_batch is a dictionary of det prediction output, which contains det information of a batch
preds_batch["texts"] = []
preds_batch["confs"] = []
for i, crops in enumerate(preds_batch["crops"]):
# A batch may contain multiple images
img_path = preds_batch["img_path"][i]
img_box = []
img_text_scores = []
for j, crop in enumerate(crops):
# For each image, it may contain several crops
data = {"image": crop}
data["image_ori"] = crop.copy()
data["image_shape"] = crop.shape
data = run_transforms(data, transforms[1:])
data = rec_network(Tensor(data["image"]).expand_dims(0))
out = postprocessor(data)
confs = out["confs"][0]
if confs > 0.5:
# Keep text with a confidence greater than 0.5
box = preds_batch["polys"][i][j]
text = out["texts"][0]
img_box.append(box)
img_text_scores.append((text, confs))
# Each image saves its path, box and texts_scores
img_list.append(img_path)
boxes_all.append(img_box)
text_scores_all.append(img_text_scores)
save_res(boxes_all, text_scores_all, img_list, save_path=os.path.join(output_save_dir, "system_results.txt"))
network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
network.set_train(False)
return network


def sort_polys(polys):
    """Return a new list of polygons ordered by their first point.

    Polygons are ordered by the y coordinate of their first point and, for equal y, by its x coordinate.
    The input sequence is left untouched and ties keep their original relative order.

    Args:
        polys: iterable of polygons, each indexable as poly[0][0] (x) and poly[0][1] (y).

    Returns:
        list: the polygons in sorted order.
    """

    def _first_point_yx(poly):
        first_point = poly[0]
        return (first_point[1], first_point[0])

    ordered = list(polys)
    ordered.sort(key=_first_point_yx)
    return ordered


def concat_crops(crops: list):
    """Concatenate cropped images side by side into a single image.

    Each crop is resized to the height of the tallest crop while keeping its aspect ratio, then the
    resized crops are joined along the width axis in the order given.

    Args:
        crops: non-empty list of images as numpy arrays, either (H, W, C) or (H, W). All crops must
            have the same number of dimensions (and channels) so that they can be concatenated.

    Returns:
        np.ndarray: the concatenated image, whose height is the maximum crop height.

    Raises:
        ValueError: if `crops` is empty.
    """
    if len(crops) == 0:
        raise ValueError("concat_crops expects at least one crop, but got an empty list.")

    max_height = max(crop.shape[0] for crop in crops)
    resized_crops = []
    for crop in crops:
        # Only height and width are needed; slicing also accepts 2-D (grayscale) crops.
        h, w = crop.shape[:2]
        # Integer arithmetic keeps the scaled width exact. The float form int((w / h) * new_h) can
        # land just below the true value (e.g. w=1, h=49 gives 0), which is an invalid resize width.
        new_w = (w * max_height) // h
        resized_crops.append(cv2.resize(crop, (new_w, max_height), interpolation=cv2.INTER_LINEAR))
    return np.concatenate(resized_crops, axis=1)


class Predict_System:
    """Text detection followed by text recognition on a single image, built from two yaml configs.

    Args:
        det_cfg: detection config; its `predict.dataset.transform_pipeline`, `postprocess` and model
            settings are used to build the detection stage.
        rec_cfg: recognition config; its `predict.loader.batch_size`, `predict.dataset.transform_pipeline`,
            `postprocess` and model settings are used to build the recognition stage.
        is_concat: if True, all cropped text regions of an image are concatenated into one image
            and recognized in a single pass.
    """

    def __init__(self, det_cfg, rec_cfg, is_concat=False):
        # NOTE(review): `keep_ori` presumably makes DecodeImage expose the original image as
        # data["image_ori"], which predict() reads - confirm against the transform implementation.
        for transform in det_cfg.predict.dataset.transform_pipeline:
            if "DecodeImage" in transform:
                transform["DecodeImage"].update({"keep_ori": True})
                break
        self.det_transforms = create_transforms(det_cfg.predict.dataset.transform_pipeline)
        self.det_model = build_model_from_config(det_cfg)
        self.det_postprocess = build_postprocess(det_cfg.postprocess)
        # Stored here because predict() needs it and `det_cfg` is not in scope inside that method.
        self.box_type = det_cfg.postprocess.box_type

        self.rec_batch_size = rec_cfg.predict.loader.batch_size
        self.rec_preprocess = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
        self.rec_model = build_model_from_config(rec_cfg)
        self.rec_postprocess = build_postprocess(rec_cfg.postprocess)

        self.is_concat = is_concat

    def predict_rec(self, crops: list):
        """
        Run text recognition on cropped images, in batches of `self.rec_batch_size`

        Args:
            crops: list of np.ndarray, the cropped text images to recognize

        Return:
            rec_res: list of tuple, where each tuple is (text, score) - text recognition result for each input crop
                in order.
                where text is the predicted text string, score is its confidence score.
                e.g. [('apple', 0.9), ('bike', 1.0)]
        """
        rec_res = []
        num_crops = len(crops)

        for batch_begin in range(0, num_crops, self.rec_batch_size):
            batch_end = min(batch_begin + self.rec_batch_size, num_crops)
            logger.info(f"Rec img idx range: [{batch_begin}, {batch_end})")
            # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch. and update it for resize,
            #   which may improve recognition accuracy in batch-mode
            #   especially for long text image. max_wh_ratio=max(max_wh_ratio, img_w / img_h).
            #   The short ones should be scaled with a.r. unchanged and padded to max width in batch.

            # preprocess
            # TODO: run in parallel with multiprocessing
            # The first transform of the pipeline is skipped - presumably the image decoding step,
            # since the crops are already in-memory arrays (TODO confirm against the rec yaml).
            img_batch = [
                run_transforms({"image": crop}, self.rec_preprocess[1:])["image"]
                for crop in crops[batch_begin:batch_end]
            ]
            # np.stack adds the leading batch axis, also for a single-image batch.
            img_batch = np.stack(img_batch)

            # infer
            net_pred = self.rec_model(Tensor(img_batch))

            # postprocess
            batch_res = self.rec_postprocess(net_pred)
            rec_res.extend(zip(batch_res["texts"], batch_res["confs"]))

        return rec_res

    def predict(self, img_path):
        """
        Detect and recognize texts in an image

        Args:
            img_path: path of the image, passed to the detection transforms as data["img_path"]

        Return:
            boxes (list): detected text boxes whose recognized text has a score of at least 0.5, ordered by
                `sort_polys`. Point coordinates are presumably (x, y), with x along the image width and y along
                the image height - confirm against the detection postprocess.
            texts (list[tuple]): list of (text, score) where text is the recognized text string for each box,
                and score is the confidence score. When `is_concat` is True every box carries the text and
                score recognized from the single concatenated image.
            time_profile (dict): time cost in seconds of each sub-task, with keys "det", "rec" and "all".
        """
        time_profile = {}
        start = time()

        # detect text regions on an image
        data = run_transforms({"img_path": img_path}, self.det_transforms)
        input_np = np.expand_dims(data["image"], axis=0)
        logits = self.det_model(Tensor(input_np))
        pred = self.det_postprocess(logits, shape_list=np.expand_dims(data["shape_list"], axis=0))
        # Batch size is 1, so keep only the first (and only) image's results.
        pred = dict(polys=pred["polys"][0], scores=pred["scores"][0])
        # NOTE(review): `validate_det_res` and `crop_text_region` are not among the imports visible
        # in this part of the file - confirm they are imported or defined elsewhere in the module.
        det_res = validate_det_res(pred, data["image_ori"].shape[:2], min_poly_points=3, min_area=3)
        det_res["img_ori"] = data["image_ori"]

        time_profile["det"] = time() - start
        polys = det_res["polys"].copy()
        if len(polys) == 0:
            logger.warning(f"No text detected in {img_path}")
            time_profile["rec"] = 0.0
            time_profile["all"] = time_profile["det"]
            return [], [], time_profile
        polys = sort_polys(polys)
        logger.info(f"Num detected text boxes: {len(polys)}\nDet time: {time_profile['det']}")
        if self.is_concat:
            logger.info("After concatenating, 1 croped image will be recognized.")

        # crop text regions, in the sorted order of `polys`
        crops = [
            crop_text_region(data["image_ori"], poly.astype(np.float32), box_type=self.box_type) for poly in polys
        ]

        # recognize cropped images
        rs = time()
        if self.is_concat:
            crops = [concat_crops(crops)]
        rec_res_all_crops = self.predict_rec(crops)
        time_profile["rec"] = time() - rs

        logger.info(
            "Recognized texts: \n"
            + "\n".join([f"{text}\t{score}" for text, score in rec_res_all_crops])
            + f"\nRec time: {time_profile['rec']}"
        )

        # filter out low-score texts and merge detection and recognition results
        boxes, text_scores = [], []
        # Iterate the sorted polys: crops (and so the recognition results) follow that order, so
        # indexing the unsorted det_res["polys"] here would pair boxes with the wrong texts.
        for i, box in enumerate(polys):
            # In concat mode there is a single recognition result shared by all boxes.
            text, text_score = rec_res_all_crops[0] if self.is_concat else rec_res_all_crops[i]

            if text_score >= 0.5:
                boxes.append(box)
                text_scores.append((text, text_score))
        time_profile["all"] = time() - start
        return boxes, text_scores, time_profile


def predict_both_step(args, det_cfg, rec_cfg):
    """Run detection + recognition on every image under `args.image_dir` and save the results.

    Args:
        args: parsed command line arguments; `image_dir` and `is_concat` are read.
        det_cfg: detection config; also provides `system.mode` and `predict.output_save_dir`.
        rec_cfg: recognition config.

    The boxes and (text, score) pairs of all images are passed to `save_res`, which is given the path
    "system_results.txt" inside the output directory ("./output" if none is configured). Nothing is
    saved when no image is found.
    """
    set_logger(name="mindocr")
    # Set the execution mode before the networks are built: MindSpore advises against changing
    # the mode after a net has been initialized.
    set_context(mode=det_cfg.system.mode)

    pred_sys = Predict_System(det_cfg=det_cfg, rec_cfg=rec_cfg, is_concat=args.is_concat)
    output_save_dir = det_cfg.predict.output_save_dir or "./output"
    results_path = os.path.join(output_save_dir, "system_results.txt")
    img_paths = get_image_paths(args.image_dir)
    num_images = len(img_paths)

    if num_images == 0:
        # Without this guard the timing summary below fails on the missing "all" entry.
        logger.warning(f"No image found in {args.image_dir}, nothing to predict.")
        return

    tot_time = {}  # accumulated seconds per sub-task, e.g. {'det': 0, 'rec': 0, 'all': 0}
    boxes_all, text_scores_all = [], []
    for i, img_path in enumerate(img_paths):
        logger.info(f"Infering [{i+1}/{num_images}]: {img_path}")
        boxes, text_scores, time_prof = pred_sys.predict(img_path)
        boxes_all.append(boxes)
        text_scores_all.append(text_scores)

        for k, cost in time_prof.items():
            tot_time[k] = tot_time.get(k, 0) + cost

    fps = num_images / tot_time["all"]
    logger.info(f"Total time:{tot_time['all']}")
    logger.info(f"Average FPS: {fps}")
    avg_time = {k: total / num_images for k, total in tot_time.items()}
    logger.info(f"Averge time cost: {avg_time}")

    # save result
    save_res(boxes_all, text_scores_all, img_paths, save_path=results_path)
    logger.info(f"Done! Results saved in {results_path}")


def create_parser():
Expand All @@ -314,6 +453,7 @@ def create_parser():
default="configs/rec/crnn/crnn_resnet34.yaml",
help='YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")',
)
parser.add_argument("--is_concat", type=str2bool, default=False, help="image path or image directory")
parser.add_argument(
"-o",
"--opt",
Expand All @@ -323,7 +463,9 @@ def create_parser():
)
# modelarts
group = parser.add_argument_group("modelarts")
group.add_argument("--enable_modelarts", type=bool, default=False, help="Run on modelarts platform (default=False)")
group.add_argument(
"--enable_modelarts", type=str2bool, default=False, help="Run on modelarts platform (default=False)"
)
group.add_argument(
"--device_target",
type=str,
Expand All @@ -337,8 +479,6 @@ def create_parser():
group.add_argument("--pretrain_url", type=str, default="", help="pre_train_model paths in obs")
group.add_argument("--train_url", type=str, default="", help="model folder to save/load")

# args = parser.parse_args()

return parser


Expand Down Expand Up @@ -378,4 +518,4 @@ def parse_args_and_config():
elif args.task_mode == "system":
rec_cfg = Dict(rec_cfg)
det_cfg = Dict(det_cfg)
predict_system(args, det_cfg, rec_cfg)
predict_both_step(args, det_cfg, rec_cfg)