
Commit e282060 (1 parent: 4bca517)

1. Refactor the code of 'predict_from_yaml.py'. 2. Add a concat_crops function.

1 file changed: tools/infer/text/predict_from_yaml.py (+216 −76 lines)
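Before the diff, a note on the second change: concat_crops rescales every detected text crop to a common height (preserving aspect ratio) and stitches the results side by side, so that when --is_concat is enabled the recognizer runs once over a single wide image instead of once per crop. A minimal, self-contained sketch of the idea (the dummy uint8 arrays are illustrative only, not from the repository):

    import cv2
    import numpy as np

    def concat_crops(crops):
        # resize each crop to the tallest crop's height, keeping its aspect ratio
        max_height = max(crop.shape[0] for crop in crops)
        resized = []
        for crop in crops:
            h, w, _ = crop.shape
            new_w = int((w / h) * max_height)
            resized.append(cv2.resize(crop, (new_w, max_height), interpolation=cv2.INTER_LINEAR))
        # stack horizontally into one wide image
        return np.concatenate(resized, axis=1)

    # illustrative usage: two dummy "crops" of different sizes
    a = np.zeros((32, 100, 3), dtype=np.uint8)
    b = np.zeros((48, 80, 3), dtype=np.uint8)
    print(concat_crops([a, b]).shape)  # (48, 230, 3): widths become 150 and 80

The trade-off is that one concatenated image yields a single (text, score) pair, which predict() then assigns to every detected box, as visible in the diff below.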
@@ -9,7 +9,10 @@
 import logging
 import os
 import sys
+from time import time

+import cv2
+import numpy as np
 import yaml
 from addict import Dict
 from PIL import Image
@@ -22,12 +25,14 @@
 from mindspore import Tensor, get_context, set_auto_parallel_context, set_context
 from mindspore.communication import get_group_size, get_rank, init

+from deploy.py_infer.src.infer_args import str2bool  # noqa
 from mindocr.data import build_dataset
 from mindocr.data.transforms import create_transforms, run_transforms
 from mindocr.models import build_model
 from mindocr.postprocess import build_postprocess
 from mindocr.utils.visualize import draw_boxes, show_imgs
 from tools.arg_parser import _merge_options, _parse_options
+from tools.infer.text.utils import get_image_paths
 from tools.modelarts_adapter.modelarts import modelarts_setup

 __dir__ = os.path.dirname(os.path.abspath(__file__))
@@ -155,21 +160,7 @@ def predict_single_step(cfg, save_res=True):
     )

     # 3.Build model
-    amp_level = cfg.system.get("amp_level_infer", "O0")
-    if get_context("device_target") == "GPU" and amp_level == "O3":
-        logger.warning(
-            "Model evaluation does not support amp_level O3 on GPU currently. "
-            "The program has switched to amp_level O2 automatically."
-        )
-        amp_level = "O2"
-    cfg.model.backbone.pretrained = False
-    if cfg.predict.ckpt_load_path is None:
-        logger.warning(
-            f"No ckpt is available for {cfg.model.task}, "
-            "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
-        )
-    network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
-    network.set_train(False)
+    network = build_model_from_config(cfg)

     # 4.Build postprocessor for network output
     postprocessor = build_postprocess(cfg.postprocess)
@@ -230,72 +221,220 @@ def predict_single_step(cfg, save_res=True):
     return preds_list


-def predict_system(args, det_cfg, rec_cfg):
-    """Run predict for both det and rec task"""
-    # merge image_dir option in model config
-    det_cfg.predict.dataset.dataset_root = ""
-    det_cfg.predict.dataset.data_dir = args.image_dir
-    output_save_dir = det_cfg.predict.output_save_dir or "./output"
-
-    # get det result from predict
-    preds_list = predict_single_step(det_cfg, save_res=False)
-
-    # set amp level
-    amp_level = det_cfg.system.get("amp_level_infer", "O0")
+def build_model_from_config(cfg):
+    amp_level = cfg.system.get("amp_level_infer", "O0")
     if get_context("device_target") == "GPU" and amp_level == "O3":
         logger.warning(
             "Model evaluation does not support amp_level O3 on GPU currently. "
             "The program has switched to amp_level O2 automatically."
         )
         amp_level = "O2"
-
-    # create preprocess and postprocess for rec task
-    transforms = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
-    postprocessor = build_postprocess(rec_cfg.postprocess)
-
-    # build rec model from yaml
-    rec_cfg.model.backbone.pretrained = False
-    if rec_cfg.predict.ckpt_load_path is None:
+    cfg.model.backbone.pretrained = False
+    if cfg.predict.ckpt_load_path is None:
         logger.warning(
-            f"No ckpt is available for {rec_cfg.model.type}, "
+            f"No ckpt is available for {cfg.model.task}, "
             "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
         )
-    rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)
-
-    # start rec task
-    logger.info("Start rec")
-    img_list = []  # list of img_path
-    boxes_all = []  # list of boxes of all image
-    text_scores_all = []  # list of text and scores of all image
-    for preds_batch in tqdm(preds_list):
-        # preds_batch is a dictionary of det prediction output, which contains det information of a batch
-        preds_batch["texts"] = []
-        preds_batch["confs"] = []
-        for i, crops in enumerate(preds_batch["crops"]):
-            # A batch may contain multiple images
-            img_path = preds_batch["img_path"][i]
-            img_box = []
-            img_text_scores = []
-            for j, crop in enumerate(crops):
-                # For each image, it may contain several crops
-                data = {"image": crop}
-                data["image_ori"] = crop.copy()
-                data["image_shape"] = crop.shape
-                data = run_transforms(data, transforms[1:])
-                data = rec_network(Tensor(data["image"]).expand_dims(0))
-                out = postprocessor(data)
-                confs = out["confs"][0]
-                if confs > 0.5:
-                    # Keep text with a confidence greater than 0.5
-                    box = preds_batch["polys"][i][j]
-                    text = out["texts"][0]
-                    img_box.append(box)
-                    img_text_scores.append((text, confs))
-            # Each image saves its path, box and texts_scores
-            img_list.append(img_path)
-            boxes_all.append(img_box)
-            text_scores_all.append(img_text_scores)
-    save_res(boxes_all, text_scores_all, img_list, save_path=os.path.join(output_save_dir, "system_results.txt"))
+    network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
+    network.set_train(False)
+    return network
+
+
+def sort_polys(polys):
+    # sort boxes top-to-bottom, then left-to-right, by each polygon's first point
+    return sorted(polys, key=lambda points: (points[0][1], points[0][0]))
+
+
+def concat_crops(crops: list):
+    # rescale all crops to a common height (keeping aspect ratio) and stitch them side by side
+    max_height = max(crop.shape[0] for crop in crops)
+    resized_crops = []
+    for crop in crops:
+        h, w, c = crop.shape
+        new_h = max_height
+        new_w = int((w / h) * new_h)
+
+        resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+        resized_crops.append(resized_img)
+    crops = np.concatenate(resized_crops, axis=1)
+    return crops
+
+
+class Predict_System:
+    def __init__(self, det_cfg, rec_cfg, is_concat=False):
+        # keep the original image around; predict() needs it for cropping
+        for transform in det_cfg.predict.dataset.transform_pipeline:
+            if "DecodeImage" in transform:
+                transform["DecodeImage"].update({"keep_ori": True})
+                break
+        self.det_transforms = create_transforms(det_cfg.predict.dataset.transform_pipeline)
+        self.det_model = build_model_from_config(det_cfg)
+        self.det_postprocess = build_postprocess(det_cfg.postprocess)
+        # store the box type here: det_cfg is not in scope inside predict()
+        self.box_type = det_cfg.postprocess.box_type
+
+        self.rec_batch_size = rec_cfg.predict.loader.batch_size
+        self.rec_preprocess = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
+        self.rec_model = build_model_from_config(rec_cfg)
+        self.rec_postprocess = build_postprocess(rec_cfg.postprocess)
+
+        self.is_concat = is_concat
+
+    def predict_rec(self, crops: list):
+        """
+        Run text recognition serially for the input cropped images.
+
+        Args:
+            crops: list of np.ndarray, each one a cropped text region image
+
+        Return:
+            rec_res: list of tuple, where each tuple is (text, score) - the text recognition result for each
+                input crop, in order, where text is the predicted text string and score is its confidence.
+                e.g. [('apple', 0.9), ('bike', 1.0)]
+        """
+        rec_res = []
+        num_crops = len(crops)
+
+        for idx in range(0, num_crops, self.rec_batch_size):  # batch begin index
+            batch_begin = idx
+            batch_end = min(idx + self.rec_batch_size, num_crops)
+            logger.info(f"Rec img idx range: [{batch_begin}, {batch_end})")
+            # TODO: set max_wh_ratio to the maximum wh ratio of the images in the batch, i.e.
+            # max_wh_ratio = max(max_wh_ratio, img_w / img_h), and use it for resizing, which may improve
+            # recognition accuracy in batch mode, especially for long text images. The short ones should be
+            # scaled with aspect ratio unchanged and padded to the max width in the batch.
+
+            # preprocess
+            # TODO: run in parallel with multiprocessing
+            img_batch = []
+            for j in range(batch_begin, batch_end):  # image index j
+                data = run_transforms({"image": crops[j]}, self.rec_preprocess[1:])
+                img_batch.append(data["image"])
+
+            img_batch = np.stack(img_batch) if len(img_batch) > 1 else np.expand_dims(img_batch[0], axis=0)
+
+            # infer
+            net_pred = self.rec_model(Tensor(img_batch))
+
+            # postprocess
+            batch_res = self.rec_postprocess(net_pred)
+            rec_res.extend(list(zip(batch_res["texts"], batch_res["confs"])))
+
+        return rec_res
+
+    def predict(self, img_path):
+        """
+        Detect and recognize texts in an image.
+
+        Args:
+            img_path (str): path to the input image.
+
+        Return:
+            boxes (list): detected text boxes, in shape [num_boxes, num_points, 2], where the point
+                coordinate (x, y) follows: x - horizontal (image width direction), y - vertical (image height)
+            text_scores (list[tuple]): list of (text, score), where text is the recognized text string
+                for each box and score is its confidence score
+            time_profile (dict): records the time cost of each sub-task
+        """
+
+        time_profile = {}
+        start = time()
+
+        # detect text regions on the image
+        data = {"img_path": img_path}
+        data = run_transforms(data, self.det_transforms)
+        input_np = np.expand_dims(data["image"], axis=0)
+        logits = self.det_model(Tensor(input_np))
+        pred = self.det_postprocess(logits, shape_list=np.expand_dims(data["shape_list"], axis=0))
+        polys = pred["polys"][0]
+        scores = pred["scores"][0]
+        pred = dict(polys=polys, scores=scores)
+        det_res = validate_det_res(pred, data["image_ori"].shape[:2], min_poly_points=3, min_area=3)
+        det_res["img_ori"] = data["image_ori"]
+
+        time_profile["det"] = time() - start
+        polys = det_res["polys"].copy()
+        if len(polys) == 0:
+            logger.warning(f"No text detected in {img_path}")
+            time_profile["rec"] = 0.0
+            time_profile["all"] = time_profile["det"]
+            return [], [], time_profile
+        polys = sort_polys(polys)
+        logger.info(f"Num detected text boxes: {len(polys)}\nDet time: {time_profile['det']}")
+        if self.is_concat:
+            logger.info("After concatenating, 1 cropped image will be recognized.")
+
+        # crop text regions
+        crops = []
+        for i in range(len(polys)):
+            poly = polys[i].astype(np.float32)
+            # box_type comes from __init__ (det_cfg is not in scope inside this method)
+            cropped_img = crop_text_region(data["image_ori"], poly, box_type=self.box_type)
+            crops.append(cropped_img)
+
+        # if self.save_crop_res:
+        #     cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img)
+        # show_imgs(crops, is_bgr_img=False)
+
+        # recognize cropped images
+        rs = time()
+        if self.is_concat:
+            crops = [concat_crops(crops)]
+        rec_res_all_crops = self.predict_rec(crops)
+        time_profile["rec"] = time() - rs
+
+        logger.info(
+            "Recognized texts: \n"
+            + "\n".join([f"{text}\t{score}" for text, score in rec_res_all_crops])
+            + f"\nRec time: {time_profile['rec']}"
+        )
+
+        # filter out low-score texts and merge detection and recognition results
+        boxes, text_scores = [], []
+        for i in range(len(polys)):
+            box = det_res["polys"][i]
+            if self.is_concat:
+                text = rec_res_all_crops[0][0]
+                text_score = rec_res_all_crops[0][1]
+            else:
+                text = rec_res_all_crops[i][0]
+                text_score = rec_res_all_crops[i][1]
+
+            if text_score >= 0.5:
+                boxes.append(box)
+                text_scores.append((text, text_score))
+        time_profile["all"] = time() - start
+        return boxes, text_scores, time_profile
+
+
+def predict_both_step(args, det_cfg, rec_cfg):
+    # set up the logger and the combined det+rec predictor
+    set_logger(name="mindocr")
+    pred_sys = Predict_System(det_cfg=det_cfg, rec_cfg=rec_cfg, is_concat=args.is_concat)
+    output_save_dir = det_cfg.predict.output_save_dir or "./output"
+    img_paths = get_image_paths(args.image_dir)
+
+    set_context(mode=det_cfg.system.mode)
+
+    tot_time = {}  # {'det': 0, 'rec': 0, 'all': 0}
+    boxes_all, text_scores_all = [], []
+    for i, img_path in enumerate(img_paths):
+        logger.info(f"Inferring [{i+1}/{len(img_paths)}]: {img_path}")
+        boxes, text_scores, time_prof = pred_sys.predict(img_path)
+        boxes_all.append(boxes)
+        text_scores_all.append(text_scores)

+        for k in time_prof:
+            if k not in tot_time:
+                tot_time[k] = time_prof[k]
+            else:
+                tot_time[k] += time_prof[k]
+
+    fps = len(img_paths) / tot_time["all"]
+    logger.info(f"Total time: {tot_time['all']}")
+    logger.info(f"Average FPS: {fps}")
+    avg_time = {k: tot_time[k] / len(img_paths) for k in tot_time}
+    logger.info(f"Average time cost: {avg_time}")
+
+    # save result
+    save_res(boxes_all, text_scores_all, img_paths, save_path=os.path.join(output_save_dir, "system_results.txt"))
+    logger.info(f"Done! Results saved in {os.path.join(output_save_dir, 'system_results.txt')}")


 def create_parser():
@@ -314,6 +453,7 @@ def create_parser():
         default="configs/rec/crnn/crnn_resnet34.yaml",
         help='YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")',
     )
+    parser.add_argument("--is_concat", type=str2bool, default=False, help="Concatenate detected crops into one image before recognition (default=False)")
     parser.add_argument(
         "-o",
         "--opt",
@@ -323,7 +463,9 @@ def create_parser():
     )
     # modelarts
    group = parser.add_argument_group("modelarts")
-    group.add_argument("--enable_modelarts", type=bool, default=False, help="Run on modelarts platform (default=False)")
+    group.add_argument(
+        "--enable_modelarts", type=str2bool, default=False, help="Run on modelarts platform (default=False)"
+    )
     group.add_argument(
         "--device_target",
         type=str,
@@ -337,8 +479,6 @@ def create_parser():
     group.add_argument("--pretrain_url", type=str, default="", help="pre_train_model paths in obs")
     group.add_argument("--train_url", type=str, default="", help="model folder to save/load")

-    # args = parser.parse_args()
-
     return parser

@@ -378,4 +518,4 @@ def parse_args_and_config():
     elif args.task_mode == "system":
         rec_cfg = Dict(rec_cfg)
         det_cfg = Dict(det_cfg)
-        predict_system(args, det_cfg, rec_cfg)
+        predict_both_step(args, det_cfg, rec_cfg)

0 commit comments