diff --git a/tools/infer/text/predict_from_yaml.py b/tools/infer/text/predict_from_yaml.py
index b6d7ccb17..cab1c1fed 100644
--- a/tools/infer/text/predict_from_yaml.py
+++ b/tools/infer/text/predict_from_yaml.py
@@ -9,7 +9,10 @@ import logging
 import os
 import sys
+from time import time
 
+import cv2
+import numpy as np
 import yaml
 from addict import Dict
 from PIL import Image
 
@@ -22,12 +25,16 @@ from mindspore import Tensor, get_context, set_auto_parallel_context, set_context
 from mindspore.communication import get_group_size, get_rank, init
 
+from deploy.py_infer.src.infer_args import str2bool  # noqa
 from mindocr.data import build_dataset
 from mindocr.data.transforms import create_transforms, run_transforms
 from mindocr.models import build_model
 from mindocr.postprocess import build_postprocess
+from mindocr.utils.logger import set_logger
 from mindocr.utils.visualize import draw_boxes, show_imgs
 from tools.arg_parser import _merge_options, _parse_options
+from tools.infer.text.predict_det import validate_det_res
+from tools.infer.text.utils import crop_text_region, get_image_paths
 from tools.modelarts_adapter.modelarts import modelarts_setup
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 
@@ -155,21 +162,7 @@ def predict_single_step(cfg, save_res=True):
     )
 
     # 3.Build model
-    amp_level = cfg.system.get("amp_level_infer", "O0")
-    if get_context("device_target") == "GPU" and amp_level == "O3":
-        logger.warning(
-            "Model evaluation does not support amp_level O3 on GPU currently. "
-            "The program has switched to amp_level O2 automatically."
-        )
-        amp_level = "O2"
-    cfg.model.backbone.pretrained = False
-    if cfg.predict.ckpt_load_path is None:
-        logger.warning(
-            f"No ckpt is available for {cfg.model.task}, "
-            "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
-        )
-    network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
-    network.set_train(False)
+    network = build_model_from_config(cfg)
 
     # 4.Build postprocessor for network output
     postprocessor = build_postprocess(cfg.postprocess)
 
@@ -230,72 +223,224 @@ def predict_single_step(cfg, save_res=True):
     return preds_list
 
 
-def predict_system(args, det_cfg, rec_cfg):
-    """Run predict for both det and rec task"""
-    # merge image_dir option in model config
-    det_cfg.predict.dataset.dataset_root = ""
-    det_cfg.predict.dataset.data_dir = args.image_dir
-    output_save_dir = det_cfg.predict.output_save_dir or "./output"
-
-    # get det result from predict
-    preds_list = predict_single_step(det_cfg, save_res=False)
-
-    # set amp level
-    amp_level = det_cfg.system.get("amp_level_infer", "O0")
+def build_model_from_config(cfg):
+    amp_level = cfg.system.get("amp_level_infer", "O0")
     if get_context("device_target") == "GPU" and amp_level == "O3":
         logger.warning(
             "Model evaluation does not support amp_level O3 on GPU currently. "
             "The program has switched to amp_level O2 automatically."
         )
         amp_level = "O2"
-
-    # create preprocess and postprocess for rec task
-    transforms = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
-    postprocessor = build_postprocess(rec_cfg.postprocess)
-
-    # build rec model from yaml
-    rec_cfg.model.backbone.pretrained = False
-    if rec_cfg.predict.ckpt_load_path is None:
+    cfg.model.backbone.pretrained = False
+    if cfg.predict.ckpt_load_path is None:
         logger.warning(
-            f"No ckpt is available for {rec_cfg.model.type}, "
+            f"No ckpt is available for {cfg.model.task}, "
             "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
         )
-    rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)
-
-    # start rec task
-    logger.info("Start rec")
-    img_list = []  # list of img_path
-    boxes_all = []  # list of boxes of all image
-    text_scores_all = []  # list of text and scores of all image
-    for preds_batch in tqdm(preds_list):
-        # preds_batch is a dictionary of det prediction output, which contains det information of a batch
-        preds_batch["texts"] = []
-        preds_batch["confs"] = []
-        for i, crops in enumerate(preds_batch["crops"]):
-            # A batch may contain multiple images
-            img_path = preds_batch["img_path"][i]
-            img_box = []
-            img_text_scores = []
-            for j, crop in enumerate(crops):
-                # For each image, it may contain several crops
-                data = {"image": crop}
-                data["image_ori"] = crop.copy()
-                data["image_shape"] = crop.shape
-                data = run_transforms(data, transforms[1:])
-                data = rec_network(Tensor(data["image"]).expand_dims(0))
-                out = postprocessor(data)
-                confs = out["confs"][0]
-                if confs > 0.5:
-                    # Keep text with a confidence greater than 0.5
-                    box = preds_batch["polys"][i][j]
-                    text = out["texts"][0]
-                    img_box.append(box)
-                    img_text_scores.append((text, confs))
-            # Each image saves its path, box and texts_scores
-            img_list.append(img_path)
-            boxes_all.append(img_box)
-            text_scores_all.append(img_text_scores)
-    save_res(boxes_all, text_scores_all, img_list, save_path=os.path.join(output_save_dir, "system_results.txt"))
+    network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
+    network.set_train(False)
+    return network
+
+
+def sort_polys(polys):
+    # sort boxes top-to-bottom, then left-to-right, by their first vertex
+    return sorted(polys, key=lambda points: (points[0][1], points[0][0]))
+
+
+def concat_crops(crops: list):
+    # resize all crops to the same (maximum) height, then concatenate them horizontally
+    max_height = max(crop.shape[0] for crop in crops)
+    resized_crops = []
+    for crop in crops:
+        h, w = crop.shape[:2]
+        new_h = max_height
+        new_w = int((w / h) * new_h)
+
+        resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+        resized_crops.append(resized_img)
+    crops = np.concatenate(resized_crops, axis=1)
+    return crops
+
+
+class Predict_System:
+    def __init__(self, det_cfg, rec_cfg, is_concat=False):
+        # keep the original image in the det pipeline so text regions can be cropped from it later
+        for transform in det_cfg.predict.dataset.transform_pipeline:
+            if "DecodeImage" in transform:
+                transform["DecodeImage"].update({"keep_ori": True})
+                break
+        self.det_transforms = create_transforms(det_cfg.predict.dataset.transform_pipeline)
+        self.det_model = build_model_from_config(det_cfg)
+        self.det_postprocess = build_postprocess(det_cfg.postprocess)
+
+        self.rec_batch_size = rec_cfg.predict.loader.batch_size
+        self.rec_preprocess = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
+        self.rec_model = build_model_from_config(rec_cfg)
+        self.rec_postprocess = build_postprocess(rec_cfg.postprocess)
+
+        self.box_type = det_cfg.postprocess.box_type
+        self.is_concat = is_concat
+
+    def predict_rec(self, crops: list):
+        """
+        Run text recognition on cropped text region images, batch by batch
+
+        Args:
+            crops: list of np.ndarray, cropped text region images (RGB values),
+                in the same order as the detected boxes
+
+        Return:
+            rec_res: list of tuple, where each tuple is (text, score) - the text recognition result for each
+                input crop, in order,
+                where text is the predicted text string and score is its confidence score,
+                e.g. [('apple', 0.9), ('bike', 1.0)]
+        """
+        rec_res = []
+        num_crops = len(crops)
+
+        for idx in range(0, num_crops, self.rec_batch_size):  # batch begin index
+            batch_begin = idx
+            batch_end = min(idx + self.rec_batch_size, num_crops)
+            logger.info(f"Rec img idx range: [{batch_begin}, {batch_end})")
+            # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch and update it for resizing,
+            # which may improve recognition accuracy in batch mode,
+            # especially for long text images. max_wh_ratio = max(max_wh_ratio, img_w / img_h).
+            # The short ones should be scaled with the aspect ratio unchanged and padded to the max width in batch.
+
+            # preprocess
+            # TODO: run in parallel with multiprocessing
+            img_batch = []
+            for j in range(batch_begin, batch_end):  # image index j
+                data = run_transforms({"image": crops[j]}, self.rec_preprocess[1:])
+                img_batch.append(data["image"])
+
+            img_batch = np.stack(img_batch) if len(img_batch) > 1 else np.expand_dims(img_batch[0], axis=0)
+
+            # infer
+            net_pred = self.rec_model(Tensor(img_batch))
+
+            # postprocess
+            batch_res = self.rec_postprocess(net_pred)
+            rec_res.extend(list(zip(batch_res["texts"], batch_res["confs"])))
+
+        return rec_res
+
+    def predict(self, img_path):
+        """
+        Detect and recognize texts in an image
+
+        Args:
+            img_path (str): path to the input image
+
+        Return:
+            boxes (list): detected text boxes, in shape [num_boxes, num_points, 2], where the point coordinate (x, y)
+                follows: x - horizontal (image width direction), y - vertical (image height direction)
+            texts (list[tuple]): list of (text, score), where text is the recognized text string for each box
+                and score is its confidence score
+            time_profile (dict): records the time cost of each sub-task
+        """
+
+        time_profile = {}
+        start = time()
+
+        # detect text regions in the image
+        data = {"img_path": img_path}
+        data = run_transforms(data, self.det_transforms)
+        input_np = np.expand_dims(data["image"], axis=0)
+        logits = self.det_model(Tensor(input_np))
+        pred = self.det_postprocess(logits, shape_list=np.expand_dims(data["shape_list"], axis=0))
+        polys = pred["polys"][0]
+        scores = pred["scores"][0]
+        pred = dict(polys=polys, scores=scores)
+        det_res = validate_det_res(pred, data["image_ori"].shape[:2], min_poly_points=3, min_area=3)
+        det_res["img_ori"] = data["image_ori"]
+
+        time_profile["det"] = time() - start
+        polys = det_res["polys"].copy()
+        if len(polys) == 0:
+            logger.warning(f"No text detected in {img_path}")
+            time_profile["rec"] = 0.0
+            time_profile["all"] = time_profile["det"]
+            return [], [], time_profile
+        polys = sort_polys(polys)
+        logger.info(f"Num detected text boxes: {len(polys)}\nDet time: {time_profile['det']}")
+        if self.is_concat:
+            logger.info("After concatenation, a single cropped image will be recognized.")
+
+        # crop text regions
+        crops = []
+        for i in range(len(polys)):
+            poly = polys[i].astype(np.float32)
+            cropped_img = crop_text_region(data["image_ori"], poly, box_type=self.box_type)
+            crops.append(cropped_img)
+
+            # if self.save_crop_res:
+            #     cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img)
+        # show_imgs(crops, is_bgr_img=False)
+
+        # recognize cropped images
+        rs = time()
+        if self.is_concat:
+            crops = [concat_crops(crops)]
+        rec_res_all_crops = self.predict_rec(crops)
+        time_profile["rec"] = time() - rs
+
+        logger.info(
+            "Recognized texts: \n"
+            + "\n".join([f"{text}\t{score}" for text, score in rec_res_all_crops])
+            + f"\nRec time: {time_profile['rec']}"
+        )
+
+        # filter out low-score texts and merge detection and recognition results
+        boxes, text_scores = [], []
+        for i in range(len(polys)):
+            box = det_res["polys"][i]
+            if self.is_concat:
+                text = rec_res_all_crops[0][0]
+                text_score = rec_res_all_crops[0][1]
+            else:
+                text = rec_res_all_crops[i][0]
+                text_score = rec_res_all_crops[i][1]
+
+            if text_score >= 0.5:
+                boxes.append(box)
+                text_scores.append((text, text_score))
+        time_profile["all"] = time() - start
+        return boxes, text_scores, time_profile
+
+
+def predict_both_step(args, det_cfg, rec_cfg):
+    # set up the logger and the execution context before building the det and rec models
+    set_logger(name="mindocr")
+    set_context(mode=det_cfg.system.mode)
+
+    pred_sys = Predict_System(det_cfg=det_cfg, rec_cfg=rec_cfg, is_concat=args.is_concat)
+    output_save_dir = det_cfg.predict.output_save_dir or "./output"
+    img_paths = get_image_paths(args.image_dir)
+
+    tot_time = {}  # accumulated time over all images: keys 'det', 'rec', 'all'
+    boxes_all, text_scores_all = [], []
+    for i, img_path in enumerate(img_paths):
+        logger.info(f"Inferring [{i+1}/{len(img_paths)}]: {img_path}")
+        boxes, text_scores, time_prof = pred_sys.predict(img_path)
+        boxes_all.append(boxes)
+        text_scores_all.append(text_scores)
+
+        for k in time_prof:
+            if k not in tot_time:
+                tot_time[k] = time_prof[k]
+            else:
+                tot_time[k] += time_prof[k]
+
+    fps = len(img_paths) / tot_time["all"]
+    logger.info(f"Total time: {tot_time['all']}")
+    logger.info(f"Average FPS: {fps}")
+    avg_time = {k: tot_time[k] / len(img_paths) for k in tot_time}
+    logger.info(f"Average time cost: {avg_time}")
+
+    # save result
+    save_res(boxes_all, text_scores_all, img_paths, save_path=os.path.join(output_save_dir, "system_results.txt"))
+    logger.info(f"Done! Results saved in {os.path.join(output_save_dir, 'system_results.txt')}")
 
 
 def create_parser():
@@ -314,6 +459,7 @@ def create_parser():
         default="configs/rec/crnn/crnn_resnet34.yaml",
         help='YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")',
     )
+    parser.add_argument("--is_concat", type=str2bool, default=False, help="concatenate text crops before recognition")
     parser.add_argument(
         "-o",
         "--opt",
@@ -323,7 +469,9 @@ def create_parser():
     )
     # modelarts
     group = parser.add_argument_group("modelarts")
-    group.add_argument("--enable_modelarts", type=bool, default=False, help="Run on modelarts platform (default=False)")
+    group.add_argument(
+        "--enable_modelarts", type=str2bool, default=False, help="Run on modelarts platform (default=False)"
+    )
     group.add_argument(
         "--device_target",
         type=str,
@@ -337,8 +485,6 @@ def create_parser():
     group.add_argument("--pretrain_url", type=str, default="", help="pre_train_model paths in obs")
    group.add_argument("--train_url", type=str, default="", help="model folder to save/load")
 
-    # args = parser.parse_args()
-
     return parser
 
 
@@ -378,4 +524,4 @@ def parse_args_and_config():
     elif args.task_mode == "system":
         rec_cfg = Dict(rec_cfg)
         det_cfg = Dict(det_cfg)
-        predict_system(args, det_cfg, rec_cfg)
+        predict_both_step(args, det_cfg, rec_cfg)
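
For reviewers, a minimal usage sketch (not part of the patch) showing how the new Predict_System class could be driven directly. It assumes two valid MindOCR predict YAMLs: the det config path and the image path below are illustrative placeholders, while the rec config path is the default from create_parser().

    import yaml
    from addict import Dict
    from mindspore import set_context

    from tools.infer.text.predict_from_yaml import Predict_System

    # hypothetical det config path; any MindOCR det predict YAML should work
    with open("path/to/det_config.yaml") as f:
        det_cfg = Dict(yaml.safe_load(f))
    # rec config default taken from create_parser()
    with open("configs/rec/crnn/crnn_resnet34.yaml") as f:
        rec_cfg = Dict(yaml.safe_load(f))

    set_context(mode=det_cfg.system.mode)  # set the execution mode before the models are built
    pred_sys = Predict_System(det_cfg, rec_cfg, is_concat=False)

    # illustrative image path; predict() returns only boxes whose text score is >= 0.5
    boxes, text_scores, time_profile = pred_sys.predict("path/to/image.jpg")
    for text, score in text_scores:
        print(text, score)
    print(time_profile)  # {'det': ..., 'rec': ..., 'all': ...}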