From 051bfda38cb9bc1dc5af7fa09eae97d516a30d91 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Tue, 16 May 2023 12:02:56 +0800 Subject: [PATCH 1/8] add text direction classification model (mobilenetv3_small_100) add text direction classification model (mobilenetv3_small_100) add text direction classification model (mobilenetv3_small_100) add text direction classification model (mobilenetv3_small_100) add text direction classification model (mobilenetv3_small_100) add text direction classification model (mobilenetv3_small_100) add cls_mv3 update typo in clspostprocess update cls data converter cls run on grpah mode fix rebase change update cls data converters update config files tidy up code tidy up code update cls model class name update cls postprocess and metric rm cls data converter, update in another PR fix CI failure tidy up code fix bug update yaml --- mindocr/models/__init__.py | 1 + mindocr/utils/callbacks.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/mindocr/models/__init__.py b/mindocr/models/__init__.py index ae01ae83e..fd0f7e63f 100644 --- a/mindocr/models/__init__.py +++ b/mindocr/models/__init__.py @@ -8,6 +8,7 @@ from .rec_crnn import * from .rec_rare import * from .rec_svtr import * +from .cls_mv3 import * __all__ = [] __all__.extend(builder.__all__) diff --git a/mindocr/utils/callbacks.py b/mindocr/utils/callbacks.py index 34f1d870a..dff2dcc74 100644 --- a/mindocr/utils/callbacks.py +++ b/mindocr/utils/callbacks.py @@ -20,7 +20,6 @@ from mindspore import jit else: from mindspore import ms_function - jit = ms_function From 50b8034557aaae20b053f9edca8b59885c4a5e81 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Tue, 30 May 2023 17:17:58 +0800 Subject: [PATCH 2/8] add predict_cls, predict_system add predict_cls add predict_cls update online predict update online prediction fix predict_cls bug, chw->hwc format code format code2 --- mindocr/models/__init__.py | 1 - tools/infer/text/config.py | 34 ++++ tools/infer/text/postprocess.py | 9 +- tools/infer/text/predict_cls.py | 264 +++++++++++++++++++++++++++++ tools/infer/text/predict_det.py | 3 +- tools/infer/text/predict_rec.py | 7 +- tools/infer/text/predict_system.py | 17 +- tools/infer/text/preprocess.py | 33 ++-- 8 files changed, 343 insertions(+), 25 deletions(-) create mode 100644 tools/infer/text/predict_cls.py diff --git a/mindocr/models/__init__.py b/mindocr/models/__init__.py index fd0f7e63f..ae01ae83e 100644 --- a/mindocr/models/__init__.py +++ b/mindocr/models/__init__.py @@ -8,7 +8,6 @@ from .rec_crnn import * from .rec_rare import * from .rec_svtr import * -from .cls_mv3 import * __all__ = [] __all__.extend(builder.__all__) diff --git a/tools/infer/text/config.py b/tools/infer/text/config.py index 4076cd4f6..43c1f4109 100644 --- a/tools/infer/text/config.py +++ b/tools/infer/text/config.py @@ -37,6 +37,7 @@ def create_parser(): # parser.add_argument("--gpu_id", type=int, default=0) parser.add_argument("--det_model_config", type=str, help="path to det model yaml config") # added + parser.add_argument("--cls_model_config", type=str, help="path to cls model yaml config") # added parser.add_argument("--rec_model_config", type=str, help="path to rec model yaml config") # added # params for text detector @@ -90,6 +91,39 @@ def create_parser(): parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") + # params for text direction classification + parser.add_argument( + "--cls_model_dir", + type=str, + help="directory containing the text direction classification model checkpoint best.ckpt, " + "or path to a specific checkpoint file.", + ) # determine the network weights + parser.add_argument( + "--cls_batch_mode", + type=str2bool, + default=True, + help="Whether to run text direction classification inference in batch-mode, " + "which is faster but may degrade the accraucy due to padding or resizing to the same shape.", + ) # added + parser.add_argument("--cls_batch_num", type=int, default=8) + parser.add_argument("--cls_algorithm", type=str, choices=["MV3"]) + parser.add_argument( + "--cls_rotate_thre", + type=float, + default=0.9, + help="Rotate the image when text direction classification score is larger than this threshold.", + ) + parser.add_argument( + "--cls_image_shape", + type=str, + default="3, 48, 192", + help="C, H, W for taget image shape. max_wh_ratio=W/H will be used to control the maximum width " + "after 'aspect-ratio-kept' resizing. Set W larger for longer text.", + ) + parser.add_argument( + "--cls_label_list", type=str, nargs="+", default=["0", "180"], choices=[["0", "180"], ["0", "90", "180", "270"]] + ) + # params for text recognizer parser.add_argument( "--rec_algorithm", diff --git a/tools/infer/text/postprocess.py b/tools/infer/text/postprocess.py index d2524c47f..552bed06a 100644 --- a/tools/infer/text/postprocess.py +++ b/tools/infer/text/postprocess.py @@ -35,7 +35,7 @@ def __init__(self, task="det", algo="DB", **kwargs): raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.") self.rescale_internally = True self.round = True - elif task == "rec": + elif task in ("rec", "cls"): # TODO: update character_dict_path and use_space_char after CRNN trained using en_dict.txt released if algo.startswith("CRNN") or algo.startswith("SVTR"): # TODO: allow users to input char dict path @@ -52,7 +52,10 @@ def __init__(self, task="det", algo="DB", **kwargs): character_dict_path=dict_path, use_space_char=False, ) - + elif algo.startswith("MV"): + postproc_cfg = dict( + name="ClsPostprocess", + ) else: raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.") @@ -108,6 +111,6 @@ def __call__(self, pred, data=None): det_res = dict(polys=polys, scores=scores) return det_res - elif self.task == "rec": + elif self.task in ("rec", "cls"): output = self.postprocess(pred) return output diff --git a/tools/infer/text/predict_cls.py b/tools/infer/text/predict_cls.py new file mode 100644 index 000000000..1602e5d95 --- /dev/null +++ b/tools/infer/text/predict_cls.py @@ -0,0 +1,264 @@ +""" +Text classification inference + +Example: + $ python tools/infer/text/predict_cls.py --image_dir {img_dir_or_img_path} --cls_algorithm MV3 +""" +import os +import sys +from time import time + +import numpy as np +from config import parse_args +from postprocess import Postprocessor +from preprocess import Preprocessor +from tqdm import tqdm +from utils import get_ckpt_file, get_image_paths + +import mindspore as ms + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../"))) + +from mindocr import build_model # noqa +from mindocr.utils.visualize import show_imgs # noqa + +# map algorithm name to model name (which can be checked by `mindocr.list_models()`) +# NOTE: Modify it to add new model for inference. +algo_to_model_name = {"MV3": "cls_mobilenet_v3_small_100_model"} + + +class DirectionClassifier(object): + def __init__(self, args): + self.batch_num = args.cls_batch_num + self.batch_mode = args.cls_batch_mode + self.rotate_thre = args.cls_rotate_thre + self.visualize_output = args.visualize_output + # self.batch_mode = args.cls_batch_mode and (self.batch_num > 1) + print( + "INFO: recognize in {} mode {}".format( + "batch" if self.batch_mode else "serial", + "batch_size: " + str(self.batch_num) if self.batch_mode else "", + ) + ) + + # build model for algorithm with pretrained weights or local checkpoint + ckpt_dir = args.cls_model_dir + if ckpt_dir is None: + pretrained = True + ckpt_load_path = None + else: + ckpt_load_path = get_ckpt_file(ckpt_dir) + pretrained = False + assert args.cls_algorithm in algo_to_model_name, f"Invalid cls_algorithm {args.cls_algorithm}. " + f"Supported classification algorithms are {list(algo_to_model_name.keys())}." + + model_name = algo_to_model_name[args.cls_algorithm] + self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path) + self.model.set_train(False) + print( + "INFO: Init text direction classification model: {} --> {}. Model weights loaded from {}".format( + args.cls_algorithm, model_name, "pretrained url" if pretrained else ckpt_load_path + ) + ) + + # build preprocess and postprocess + # NOTE: most process hyper-params should be set optimally for the pick algo. + self.preprocess = Preprocessor( + task="cls", + algo=args.cls_algorithm, + cls_image_shape=args.cls_image_shape, + cls_batch_mode=self.batch_mode, + cls_batch_num=self.batch_num, + ) + + # TODO: try GeneratorDataset to wrap preprocess transform on batch for possible speed-up. + # if use_ms_dataset: ds = ms.dataset.GeneratorDataset(wrap_preprocess, ) in run_batchwise + + self.postprocess = Postprocessor(task="cls", algo=args.cls_algorithm, label_list=args.cls_label_list) + + self.vis_dir = args.draw_img_save_dir + os.makedirs(self.vis_dir, exist_ok=True) + + def __call__(self, img_or_path_list: list): + """ + Run text direction classification serially for input images + + Args: + img_or_path_list: list of str for img path or np.array for RGB image + + Return: + list of dict, each contains the follow keys for text direction classification result. + e.g. [{'texts': 'abc', 'confs': 0.9}, {'texts': 'cd', 'confs': 1.0}] + - texts: text string + - confs: prediction confidence + """ + + assert isinstance( + img_or_path_list, list + ), "Input for text direction classification must be list of images or image paths." + print("INFO: num images for cls: ", len(img_or_path_list)) + if self.batch_mode: + cls_res_all_crops, all_rotated_imgs = self.run_batchwise(img_or_path_list) + else: + cls_res_all_crops = [] + for i, img_or_path in enumerate(img_or_path_list): + cls_res = self.run_single(img_or_path, i) + cls_res_all_crops.append(cls_res) + + # TODO: 加vis和save功能 + + return cls_res_all_crops, all_rotated_imgs + + def run_batchwise(self, img_or_path_list: list): + """ + Run text direction classification serially for input images + + Args: + img_or_path_list: list of str for img path or np.array for RGB image + + Return: + cls_res: list of tuple, where each tuple is (text, score) - text direction classification result for + each input image in order. where text is the predicted text string, score is its confidence score. + e.g. [('apple', 0.9), ('bike', 1.0)] + """ + cls_res, all_rotated_imgs = [], [] + num_imgs = len(img_or_path_list) + + for idx in tqdm(range(0, num_imgs, self.batch_num)): # batch begin index i + batch_begin = idx + batch_end = min(idx + self.batch_num, num_imgs) + # print(f"Cls img idx range: [{batch_begin}, {batch_end})") + # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch. and update it for resize, + # TODO: which may improve text direction classification accuracy in batch-mode + # TODO: especially for long text image. max_wh_ratio=max(max_wh_ratio, img_w / img_h). + # TODO: The short ones should be scaled with a.r. unchanged and padded to max width in batch. + + # preprocess + # TODO: run in parallel with multiprocessing + img_batch = [] + for j in range(batch_begin, batch_end): # image index j + data = self.preprocess(img_or_path_list[j]) + img_batch.append(data["image"]) # c,h,w + if self.visualize_output: + fn = os.path.basename(data.get("img_path", f"crop_{j}.png")).split(".")[0] + show_imgs( + [data["image"]], + title=fn + "_cls_preprocessed", + mean_rgb=[127.0, 127.0, 127.0], + std_rgb=[127.0, 127.0, 127.0], + is_chw=True, + show=False, + save_path=os.path.join(self.vis_dir, fn + "_cls_preproc.png"), + ) + + img_batch = np.stack(img_batch) if len(img_batch) > 1 else np.expand_dims(img_batch[0], axis=0) + # infer + net_pred = self.model(ms.Tensor(img_batch)) + + # postprocess + batch_res = self.postprocess(net_pred) + img_batch = self.rotate(img_batch, batch_res) + + cls_res.extend(list(zip(batch_res["angles"], batch_res["scores"]))) + all_rotated_imgs.extend(img_batch) + + return cls_res, all_rotated_imgs + + def rotate(self, img_batch, batch_res): + rotated_img_batch = [] + for i, score in enumerate(batch_res["scores"]): + tmp_img = img_batch[i] + if score > self.rotate_thre: + tmp_img = np.rot90(tmp_img, k=int(int(batch_res["angles"][i]) / 90)) + rotated_img_batch.append(tmp_img.transpose(1, 2, 0)) # c, h, w --> h, w, c for saving and visualization + return rotated_img_batch + + def run_single(self, img_or_path, crop_idx=0): + """ + Text direction classification inference on a single image + Args: + img_or_path: str for image path or np.array for image rgb value + + Return: + dict with keys: + - texts (str): preditive text string + - confs (int): confidence of the prediction + """ + # preprocess + data = self.preprocess(img_or_path) + + # visualize preprocess result + if self.visualize_output: + # show_imgs([data['image_ori']], is_bgr_img=False, title=f'origin_{i}') + fn = os.path.basename(data.get("img_path", f"crop_{crop_idx}.png")).split(".")[0] + show_imgs( + [data["image"]], + title=fn + "_cls_preprocessed", + mean_rgb=[127.0, 127.0, 127.0], + std_rgb=[127.0, 127.0, 127.0], + is_chw=True, + show=False, + save_path=os.path.join(self.vis_dir, fn + "_cls_preproc.png"), + ) + print("Origin image shape: ", data["image_ori"].shape) + print("Preprocessed image shape: ", data["image"].shape) + + # infer + input_np = data["image"] + if len(input_np.shape) == 3: + net_input = np.expand_dims(input_np, axis=0) + + net_output = self.model(ms.Tensor(net_input)) + + # postprocess + cls_res = self.postprocess(net_output) + + cls_res = (cls_res["angles"][0], cls_res["scores"][0]) + + print(f"Crop {crop_idx} cls result:", cls_res) + + return cls_res + + +def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_results.txt"): + lines = [] + for i, cls_res in enumerate(cls_res_all): + if include_score: + img_pred = os.path.basename(img_paths[i]) + "\t" + str(list(cls_res)) + "\n" + else: + img_pred = os.path.basename(img_paths[i]) + "\t" + cls_res[0] + "\n" + lines.append(img_pred) + + with open(save_path, "w") as f: + f.writelines(lines) + + return lines + + +if __name__ == "__main__": + # parse args + args = parse_args() + save_dir = args.draw_img_save_dir + img_paths = get_image_paths(args.image_dir) + # uncomment it to quick test the infer FPS + # img_paths = img_paths[:250] + + ms.set_context(mode=args.mode) + ms.set_context(device_id=5) + + # init classifier + classifier = DirectionClassifier(args) + + # TODO: warmup + + # run for each image + start = time() + cls_res_all = classifier(img_paths) + t = time() - start + # save all results in a txt file + save_fp = os.path.join(save_dir, "cls_results.txt" if args.cls_batch_mode else "cls_results_serial.txt") + save_cls_res(cls_res_all, img_paths, include_score=True, save_path=save_fp) + # print('All cls res: ', cls_res_all) + print("Done! Text direction classification results saved in ", save_dir) + print("Time cost: ", t, "FPS: ", len(img_paths) / t) diff --git a/tools/infer/text/predict_det.py b/tools/infer/text/predict_det.py index f8c6cdd62..0a1730af9 100644 --- a/tools/infer/text/predict_det.py +++ b/tools/infer/text/predict_det.py @@ -2,7 +2,7 @@ Text detection inference Example: - $ python tools/infer/text/predict_det.py --image_dir {path_to_img} --rec_algorithm DB++ + $ python tools/infer/text/predict_det.py --image_dir {img_dir_or_img_path} --rec_algorithm DB++ """ import json @@ -222,7 +222,6 @@ def save_det_res(det_res_all: List[dict], img_paths: List[str], include_score=Fa with open(save_path, "w") as f: f.writelines(lines) - f.close() if __name__ == "__main__": diff --git a/tools/infer/text/predict_rec.py b/tools/infer/text/predict_rec.py index e7251ef40..0da119b6b 100644 --- a/tools/infer/text/predict_rec.py +++ b/tools/infer/text/predict_rec.py @@ -2,8 +2,8 @@ Text recognition inference Example: - $ python tools/infer/text/predict_rec.py --image_dir {path_to_img} --rec_algorithm CRNN - $ python tools/infer/text/predict_rec.py --image_dir {path_to_img} --rec_algorithm CRNN_CH + $ python tools/infer/text/predict_rec.py --image_dir {img_dir_or_img_path} --rec_algorithm CRNN + $ python tools/infer/text/predict_rec.py --image_dir {img_dir_or_img_path} --rec_algorithm CRNN_CH """ import os import sys @@ -246,14 +246,13 @@ def save_rec_res(rec_res_all, img_paths, include_score=False, save_path="./rec_r lines = [] for i, rec_res in enumerate(rec_res_all): if include_score: - img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\t" + rec_res[1] + "\n" + img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\t" + str(rec_res[1]) + "\n" else: img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\n" lines.append(img_pred) with open(save_path, "w") as f: f.writelines(lines) - f.close() return lines diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index 3a76e6d4c..581a02306 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -2,10 +2,12 @@ Text detection and recognition inference Example: - $ python tools/infer/text/predict_system.py --image_dir {path_to_img_file} --det_algorithm DB++ \ + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ --rec_algorithm CRNN - $ python tools/infer/text/predict_system.py --image_dir {path_to_img_dir} --det_algorithm DB++ \ + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ --rec_algorithm CRNN_CH + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ + --cls_algorithm MV3 --rec_algorithm CRNN_CH """ import json @@ -17,6 +19,7 @@ import cv2 import numpy as np from config import parse_args +from predict_cls import DirectionClassifier from predict_det import TextDetector from predict_rec import TextRecognizer from utils import crop_text_region, get_image_paths @@ -36,7 +39,10 @@ class TextSystem(object): def __init__(self, args): self.text_detect = TextDetector(args) self.text_recognize = TextRecognizer(args) + if args.cls_algorithm: + self.direction_classify = DirectionClassifier(args) + self.cls_algorithm = args.cls_algorithm self.box_type = args.det_box_type self.drop_score = args.drop_score self.save_crop_res = args.save_crop_res @@ -86,6 +92,12 @@ def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True): cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img) # show_imgs(crops, is_bgr_img=False) + # classify the text direction of cropped images + # TODO: support cls run_single, and then put cls into crop process + if self.cls_algorithm: + _, crops = self.direction_classify(crops) + # import pdb; pdb.set_trace() + # recognize cropped images rs = time() rec_res_all_crops = self.text_recognize(crops, do_visualize=False) @@ -162,6 +174,7 @@ def main(): # img_paths = img_paths[:10] ms.set_context(mode=args.mode) + ms.set_context(device_id=5) # init text system with detector and recognizer text_spot = TextSystem(args) diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py index 219559232..b5be148e8 100644 --- a/tools/infer/text/preprocess.py +++ b/tools/infer/text/preprocess.py @@ -54,7 +54,7 @@ def __init__(self, task="det", algo="DB", **kwargs): # TODO: modify the base pipeline for non-DBNet network if needed # if algo == 'DB++': # pipeline[1]['DetResize']['limit_side_len'] = 1152 - elif task == "rec": + elif task in ("rec", "cls"): # defalut value if not claim in optim_hparam DEFAULT_PADDING = True DEFAULT_KEEP_RATIO = True @@ -64,29 +64,35 @@ def __init__(self, task="det", algo="DB", **kwargs): # register optimal hparam for each model optimal_hparam = { - # 'CRNN': dict(target_height=32, target_width=100, padding=True, keep_ratio=True, norm_before_pad=True), + # "CRNN": dict(target_height=32, target_width=100, padding=True, keep_ratio=True, norm_before_pad=True), "CRNN": dict(target_height=32, target_width=100, padding=False, keep_ratio=False), - "CRNN_CH": dict(target_height=32, taget_width=320, padding=True, keep_ratio=True), + "CRNN_CH": dict(target_height=32, target_width=320, padding=True, keep_ratio=True), "RARE": dict(target_height=32, target_width=100, padding=False, keep_ratio=False), "RARE_CH": dict(target_height=32, target_width=320, padding=True, keep_ratio=True), "SVTR": dict(target_height=64, target_width=256, padding=False, keep_ratio=False), + "MV3": dict(target_height=48, target_width=192, padding=False, keep_ratio=False), } - # get hparam by combining default value, optimal value, and arg parser value. Prior: optimal value -> - # parser value -> default value - parsed_img_shape = kwargs.get("rec_image_shape", "3, 32, 320").split(",") + # get hparam by combining default value, optimal value, and arg parser value. + # priority: optimal value -> parser value -> default value + if task == "cls": + parsed_img_shape = kwargs.get("cls_image_shape", "3, 48, 192").split(",") + batch_mode = kwargs.get("cls_batch_mode", False) # and (batch_num > 1) + else: + parsed_img_shape = kwargs.get("rec_image_shape", "3, 32, 100").split(",") + batch_mode = kwargs.get("rec_batch_mode", False) # and (batch_num > 1) + parsed_height, parsed_width = int(parsed_img_shape[1]), int(parsed_img_shape[2]) if algo in optimal_hparam: target_height = optimal_hparam[algo]["target_height"] + norm_before_pad = optimal_hparam[algo].get("norm_before_pad", DEFAULT_NORM_BEFORE_PAD) else: target_height = parsed_height - - norm_before_pad = optimal_hparam[algo].get("norm_before_pad", DEFAULT_NORM_BEFORE_PAD) + norm_before_pad = DEFAULT_NORM_BEFORE_PAD # TODO: update max_wh_ratio for each batch # max_wh_ratio = parsed_width / float(parsed_height) # batch_num = kwargs.get('rec_batch_num', 1) - batch_mode = kwargs.get("rec_batch_mode", False) # and (batch_num > 1) if not batch_mode: # For single infer, the optimal choice is to resize the image to target height while keeping # aspect ratio, no padding. limit the max width. @@ -106,13 +112,14 @@ def __init__(self, task="det", algo="DB", **kwargs): if (target_height != parsed_height) or (target_width != parsed_width): _logger.warning( - f"`rec_image_shape` {parsed_img_shape[1:]} dose not meet the network input requirement or " - f"is not optimal, which should be [{target_height}, {target_width}] under batch mode = {batch_mode}" + f"`{task}_image_shape` {parsed_img_shape[1:]} dose not meet the network input requirement " + f"or is not optimal, which should be [{target_height}, {target_width}] under " + f"batch mode = {batch_mode}." ) _logger.info( - f"Pick optimal preprocess hyper-params for rec algo {algo}:\n" - + "\n".join( + f"Pick optimal preprocess hyper-params for {task} algo {algo}:\n", + "\n".join( [ f"{k}:\t{str(v)}" for k, v in dict( From 8d0861d24fde5956907ef4da4808bb4db7871cd5 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 00:27:23 +0800 Subject: [PATCH 3/8] update predict_cls, predict_system --- .flake8 | 1 + mindocr/utils/callbacks.py | 1 + tools/infer/text/predict_cls.py | 55 ++++++++++++++++-------------- tools/infer/text/predict_rec.py | 1 + tools/infer/text/predict_system.py | 8 ++--- tools/infer/text/preprocess.py | 2 +- 6 files changed, 37 insertions(+), 31 deletions(-) diff --git a/.flake8 b/.flake8 index f59d3e5f6..f98e9e526 100644 --- a/.flake8 +++ b/.flake8 @@ -18,6 +18,7 @@ per-file-ignores = tools/infer/text/parallel/base_predict.py:E402 tools/infer/text/parallel/predict_system.py:E402 tools/infer/text/predict_system.py:E402 + tools/infer/text/predict_cls.py:E402 tools/infer/text/predict_rec.py:E402 tools/dataset_converters/convert.py:F401,F403 mindocr/data/transforms/transforms_factory.py:F401,F403 diff --git a/mindocr/utils/callbacks.py b/mindocr/utils/callbacks.py index dff2dcc74..34f1d870a 100644 --- a/mindocr/utils/callbacks.py +++ b/mindocr/utils/callbacks.py @@ -20,6 +20,7 @@ from mindspore import jit else: from mindspore import ms_function + jit = ms_function diff --git a/tools/infer/text/predict_cls.py b/tools/infer/text/predict_cls.py index 1602e5d95..f34730de5 100644 --- a/tools/infer/text/predict_cls.py +++ b/tools/infer/text/predict_cls.py @@ -20,12 +20,14 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../"))) -from mindocr import build_model # noqa -from mindocr.utils.visualize import show_imgs # noqa +from mindocr import build_model +from mindocr.utils.logger import Logger +from mindocr.utils.visualize import show_imgs # map algorithm name to model name (which can be checked by `mindocr.list_models()`) # NOTE: Modify it to add new model for inference. algo_to_model_name = {"MV3": "cls_mobilenet_v3_small_100_model"} +_logger = Logger("mindocr") class DirectionClassifier(object): @@ -35,8 +37,8 @@ def __init__(self, args): self.rotate_thre = args.cls_rotate_thre self.visualize_output = args.visualize_output # self.batch_mode = args.cls_batch_mode and (self.batch_num > 1) - print( - "INFO: recognize in {} mode {}".format( + _logger.info( + "recognize in {} mode {}".format( "batch" if self.batch_mode else "serial", "batch_size: " + str(self.batch_num) if self.batch_mode else "", ) @@ -56,8 +58,8 @@ def __init__(self, args): model_name = algo_to_model_name[args.cls_algorithm] self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path) self.model.set_train(False) - print( - "INFO: Init text direction classification model: {} --> {}. Model weights loaded from {}".format( + _logger.info( + "Init text direction classification model: {} --> {}. Model weights loaded from {}".format( args.cls_algorithm, model_name, "pretrained url" if pretrained else ckpt_load_path ) ) @@ -97,18 +99,28 @@ def __call__(self, img_or_path_list: list): assert isinstance( img_or_path_list, list ), "Input for text direction classification must be list of images or image paths." - print("INFO: num images for cls: ", len(img_or_path_list)) + _logger.info("num images for cls: ", len(img_or_path_list)) if self.batch_mode: - cls_res_all_crops, all_rotated_imgs = self.run_batchwise(img_or_path_list) + cls_res_all, all_rotated_imgs = self.run_batchwise(img_or_path_list) else: - cls_res_all_crops = [] + cls_res_all, all_rotated_imgs = [], [] for i, img_or_path in enumerate(img_or_path_list): - cls_res = self.run_single(img_or_path, i) - cls_res_all_crops.append(cls_res) + cls_res, rotated_imgs = self.run_single(img_or_path, i) + cls_res_all.append(cls_res) + all_rotated_imgs.extend(rotated_imgs) - # TODO: 加vis和save功能 + # TODO: add vis and save function + return cls_res_all, all_rotated_imgs - return cls_res_all_crops, all_rotated_imgs + def rotate(self, img_batch, batch_res): + rotated_img_batch = [] + for i, score in enumerate(batch_res["scores"]): + tmp_img = img_batch[i] + if int(batch_res["angles"][i]) != 0 and score > self.rotate_thre: + tmp_img = np.rot90(tmp_img, k=int(int(batch_res["angles"][i]) / 90)) + _logger.info(f"After text direction classification, image is rotated {batch_res['angles'][i]} degree.") + rotated_img_batch.append(tmp_img.transpose(1, 2, 0)) # c, h, w --> h, w, c for saving and visualization + return rotated_img_batch def run_batchwise(self, img_or_path_list: list): """ @@ -165,15 +177,6 @@ def run_batchwise(self, img_or_path_list: list): return cls_res, all_rotated_imgs - def rotate(self, img_batch, batch_res): - rotated_img_batch = [] - for i, score in enumerate(batch_res["scores"]): - tmp_img = img_batch[i] - if score > self.rotate_thre: - tmp_img = np.rot90(tmp_img, k=int(int(batch_res["angles"][i]) / 90)) - rotated_img_batch.append(tmp_img.transpose(1, 2, 0)) # c, h, w --> h, w, c for saving and visualization - return rotated_img_batch - def run_single(self, img_or_path, crop_idx=0): """ Text direction classification inference on a single image @@ -213,12 +216,12 @@ def run_single(self, img_or_path, crop_idx=0): # postprocess cls_res = self.postprocess(net_output) - + net_input_rot = self.rotate(net_input, cls_res) cls_res = (cls_res["angles"][0], cls_res["scores"][0]) print(f"Crop {crop_idx} cls result:", cls_res) - return cls_res + return cls_res, net_input_rot def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_results.txt"): @@ -245,7 +248,7 @@ def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_r # img_paths = img_paths[:250] ms.set_context(mode=args.mode) - ms.set_context(device_id=5) + ms.set_context(device_id=7) # init classifier classifier = DirectionClassifier(args) @@ -254,7 +257,7 @@ def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_r # run for each image start = time() - cls_res_all = classifier(img_paths) + cls_res_all, _ = classifier(img_paths) t = time() - start # save all results in a txt file save_fp = os.path.join(save_dir, "cls_results.txt" if args.cls_batch_mode else "cls_results_serial.txt") diff --git a/tools/infer/text/predict_rec.py b/tools/infer/text/predict_rec.py index 0da119b6b..1a7471b08 100644 --- a/tools/infer/text/predict_rec.py +++ b/tools/infer/text/predict_rec.py @@ -266,6 +266,7 @@ def save_rec_res(rec_res_all, img_paths, include_score=False, save_path="./rec_r # img_paths = img_paths[:250] ms.set_context(mode=args.mode) + ms.set_context(device_id=6) # init detector text_recognize = TextRecognizer(args) diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index 581a02306..3576567ae 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -55,7 +55,7 @@ def __init__(self, args): def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True): """ - Detect and recognize texts in an image + Detect, (classify direction) and recognize texts in an image Args: img_or_path (str or np.ndarray): path to image or image rgb values as a numpy array @@ -93,10 +93,10 @@ def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True): # show_imgs(crops, is_bgr_img=False) # classify the text direction of cropped images - # TODO: support cls run_single, and then put cls into crop process if self.cls_algorithm: + cls_start = time() _, crops = self.direction_classify(crops) - # import pdb; pdb.set_trace() + time_profile["cls"] = time() - cls_start # recognize cropped images rs = time() @@ -185,7 +185,7 @@ def main(): text_spot(img_paths[0], do_visualize=False) # run - tot_time = {} # {'det': 0, 'rec': 0, 'all': 0} + tot_time = {} # {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} boxes_all, text_scores_all = [], [] for i, img_path in enumerate(img_paths): print(f"\nINFO: Infering [{i+1}/{len(img_paths)}]: ", img_path) diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py index b5be148e8..dbd64d019 100644 --- a/tools/infer/text/preprocess.py +++ b/tools/infer/text/preprocess.py @@ -130,7 +130,7 @@ def __init__(self, task="det", algo="DB", **kwargs): norm_before_pad=norm_before_pad, ).items() ] - ) + ), ) pipeline = [ From 116180f7854a5eca5f1b4def5a5b0af7b05e32ff Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 00:52:22 +0800 Subject: [PATCH 4/8] update RecResizeNormForInfer to support padding=False, keep_ratio=True, target_width=None when online predict --- mindocr/data/transforms/rec_transforms.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mindocr/data/transforms/rec_transforms.py b/mindocr/data/transforms/rec_transforms.py index dfe697b5f..7be0bdd97 100644 --- a/mindocr/data/transforms/rec_transforms.py +++ b/mindocr/data/transforms/rec_transforms.py @@ -411,14 +411,16 @@ def __call__(self, data): # tar_h, tar_w = self.targt_shape resize_h = self.tar_h - max_wh_ratio = self.tar_w / float(self.tar_h) - if not self.keep_ratio: assert self.tar_w is not None, "Must specify target_width if keep_ratio is False" resize_w = self.tar_w # if self.tar_w is not None else resized_h * self.max_wh_ratio else: src_wh_ratio = w / float(h) - resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h) + if self.tar_w is not None: + max_wh_ratio = self.tar_w / float(self.tar_h) + resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h) + else: + resize_w = math.ceil(src_wh_ratio * resize_h) resized_img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interpolation) # TODO: norm before padding From fc87614a3b870ca955fd3e6e5ef1fe95fba177d6 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 00:55:24 +0800 Subject: [PATCH 5/8] remove device_id setting in online predict --- tools/infer/text/predict_cls.py | 1 - tools/infer/text/predict_rec.py | 1 - tools/infer/text/predict_system.py | 1 - 3 files changed, 3 deletions(-) diff --git a/tools/infer/text/predict_cls.py b/tools/infer/text/predict_cls.py index f34730de5..d070b8afc 100644 --- a/tools/infer/text/predict_cls.py +++ b/tools/infer/text/predict_cls.py @@ -248,7 +248,6 @@ def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_r # img_paths = img_paths[:250] ms.set_context(mode=args.mode) - ms.set_context(device_id=7) # init classifier classifier = DirectionClassifier(args) diff --git a/tools/infer/text/predict_rec.py b/tools/infer/text/predict_rec.py index 1a7471b08..0da119b6b 100644 --- a/tools/infer/text/predict_rec.py +++ b/tools/infer/text/predict_rec.py @@ -266,7 +266,6 @@ def save_rec_res(rec_res_all, img_paths, include_score=False, save_path="./rec_r # img_paths = img_paths[:250] ms.set_context(mode=args.mode) - ms.set_context(device_id=6) # init detector text_recognize = TextRecognizer(args) diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index 3576567ae..a5404b4a8 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -174,7 +174,6 @@ def main(): # img_paths = img_paths[:10] ms.set_context(mode=args.mode) - ms.set_context(device_id=5) # init text system with detector and recognizer text_spot = TextSystem(args) From 010802d055378eb556c810e63cfcce64069f4d63 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 01:14:21 +0800 Subject: [PATCH 6/8] add test case for predict_cls and det+cls+rec predict --- tests/st/test_online_infer.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/st/test_online_infer.py b/tests/st/test_online_infer.py index 1e17cb383..5b92e88b4 100644 --- a/tests/st/test_online_infer.py +++ b/tests/st/test_online_infer.py @@ -42,6 +42,7 @@ def _gen_text_image(texts=TEXTS_2, boxes=BOXES_2, save_fp="gen_img.jpg"): det_img_fp = _gen_text_image(save_fp="gen_det_input.jpg") rec_img_fp = _gen_text_image([TEXTS_2[0]], [BOXES_2[0]], "gen_rec_input.jpg") +cls_img_fp = rec_img_fp def test_det_infer(): @@ -55,6 +56,17 @@ def test_det_infer(): assert ret == 0, "Det inference fails" +def test_cls_infer(): + algo = "MV3" + cmd = ( + f"python tools/infer/text/predict_cls.py --image_dir {cls_img_fp} --cls_algorithm {algo} " + f"--draw_img_save_dir ./infer_test" + ) + print(f"Running command: \n{cmd}") + ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr) + assert ret == 0, "Cls inference fails" + + def test_rec_infer(): algo = "CRNN" cmd = ( @@ -68,10 +80,12 @@ def test_rec_infer(): def test_system_infer(): det_algo = "DB" + cls_algo = "MV3" rec_algo = "CRNN_CH" cmd = ( f"python tools/infer/text/predict_system.py --image_dir {det_img_fp} --det_algorithm {det_algo} " - f"--rec_algorithm {rec_algo} --draw_img_save_dir ./infer_test --visualize_output True" + f"--cls_algorithm {cls_algo} --rec_algorithm {rec_algo} " + f"--draw_img_save_dir ./infer_test --visualize_output True" ) print(f"Running command: \n{cmd}") ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr) From 2a748740ad55ab2dfbd1108b2f3df9b63d4cea89 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 01:30:01 +0800 Subject: [PATCH 7/8] update command example in predict_system --- tools/infer/text/predict_system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index a5404b4a8..80b90029b 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -7,7 +7,7 @@ $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ --rec_algorithm CRNN_CH $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ - --cls_algorithm MV3 --rec_algorithm CRNN_CH + --cls_algorithm MV3 --rec_algorithm CRNN_CH --visualize_output True """ import json From 485a5c844cb432fbef1b6a231b7816bf9e1e32d6 Mon Sep 17 00:00:00 2001 From: HaoyangLI <417493727@qq.com> Date: Wed, 21 Jun 2023 10:16:16 +0800 Subject: [PATCH 8/8] small fix --- tools/infer/text/predict_cls.py | 6 +++--- tools/infer/text/predict_system.py | 1 + tools/infer/text/preprocess.py | 26 ++++++++++++-------------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/tools/infer/text/predict_cls.py b/tools/infer/text/predict_cls.py index d070b8afc..576cbdc61 100644 --- a/tools/infer/text/predict_cls.py +++ b/tools/infer/text/predict_cls.py @@ -38,7 +38,7 @@ def __init__(self, args): self.visualize_output = args.visualize_output # self.batch_mode = args.cls_batch_mode and (self.batch_num > 1) _logger.info( - "recognize in {} mode {}".format( + "classify text direction in {} mode {}".format( "batch" if self.batch_mode else "serial", "batch_size: " + str(self.batch_num) if self.batch_mode else "", ) @@ -99,7 +99,7 @@ def __call__(self, img_or_path_list: list): assert isinstance( img_or_path_list, list ), "Input for text direction classification must be list of images or image paths." - _logger.info("num images for cls: ", len(img_or_path_list)) + _logger.info(f"num images for cls: {len(img_or_path_list)}") if self.batch_mode: cls_res_all, all_rotated_imgs = self.run_batchwise(img_or_path_list) else: @@ -263,4 +263,4 @@ def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_r save_cls_res(cls_res_all, img_paths, include_score=True, save_path=save_fp) # print('All cls res: ', cls_res_all) print("Done! Text direction classification results saved in ", save_dir) - print("Time cost: ", t, "FPS: ", len(img_paths) / t) + print("CLS time: ", t, "FPS: ", len(img_paths) / t) diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index 80b90029b..4479b89cf 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -97,6 +97,7 @@ def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True): cls_start = time() _, crops = self.direction_classify(crops) time_profile["cls"] = time() - cls_start + _logger.info(f"\nCls time: {time_profile['cls']}") # recognize cropped images rs = time() diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py index dbd64d019..ba4cb1ce7 100644 --- a/tools/infer/text/preprocess.py +++ b/tools/infer/text/preprocess.py @@ -117,21 +117,19 @@ def __init__(self, task="det", algo="DB", **kwargs): f"batch mode = {batch_mode}." ) - _logger.info( - f"Pick optimal preprocess hyper-params for {task} algo {algo}:\n", - "\n".join( - [ - f"{k}:\t{str(v)}" - for k, v in dict( - target_height=target_height, - target_width=target_width, - padding=padding, - keep_ratio=keep_ratio, - norm_before_pad=norm_before_pad, - ).items() - ] - ), + hparam = "\n".join( + [ + f"{k}:\t{str(v)}" + for k, v in dict( + target_height=target_height, + target_width=target_width, + padding=padding, + keep_ratio=keep_ratio, + norm_before_pad=norm_before_pad, + ).items() + ] ) + _logger.info(f"Pick optimal preprocess hyper-params for {task} algo {algo}:\n{hparam}") pipeline = [ {"DecodeImage": {"img_mode": "RGB", "keep_ori": True, "to_float32": False}},