9
9
import logging
10
10
import os
11
11
import sys
12
+ from time import time
12
13
14
+ import cv2
15
+ import numpy as np
13
16
import yaml
14
17
from addict import Dict
15
18
from PIL import Image
22
25
from mindspore import Tensor , get_context , set_auto_parallel_context , set_context
23
26
from mindspore .communication import get_group_size , get_rank , init
24
27
28
+ from deploy .py_infer .src .infer_args import str2bool # noqa
25
29
from mindocr .data import build_dataset
26
30
from mindocr .data .transforms import create_transforms , run_transforms
27
31
from mindocr .models import build_model
28
32
from mindocr .postprocess import build_postprocess
29
33
from mindocr .utils .visualize import draw_boxes , show_imgs
30
34
from tools .arg_parser import _merge_options , _parse_options
35
+ from tools .infer .text .utils import get_image_paths
31
36
from tools .modelarts_adapter .modelarts import modelarts_setup
32
37
33
38
__dir__ = os .path .dirname (os .path .abspath (__file__ ))
@@ -155,21 +160,7 @@ def predict_single_step(cfg, save_res=True):
155
160
)
156
161
157
162
# 3.Build model
158
- amp_level = cfg .system .get ("amp_level_infer" , "O0" )
159
- if get_context ("device_target" ) == "GPU" and amp_level == "O3" :
160
- logger .warning (
161
- "Model evaluation does not support amp_level O3 on GPU currently. "
162
- "The program has switched to amp_level O2 automatically."
163
- )
164
- amp_level = "O2"
165
- cfg .model .backbone .pretrained = False
166
- if cfg .predict .ckpt_load_path is None :
167
- logger .warning (
168
- f"No ckpt is available for { cfg .model .task } , "
169
- "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
170
- )
171
- network = build_model (cfg .model , ckpt_load_path = cfg .predict .ckpt_load_path , amp_level = amp_level )
172
- network .set_train (False )
163
+ network = build_model_from_config (cfg )
173
164
174
165
# 4.Build postprocessor for network output
175
166
postprocessor = build_postprocess (cfg .postprocess )
@@ -230,72 +221,220 @@ def predict_single_step(cfg, save_res=True):
230
221
return preds_list
231
222
232
223
233
- def predict_system (args , det_cfg , rec_cfg ):
234
- """Run predict for both det and rec task"""
235
- # merge image_dir option in model config
236
- det_cfg .predict .dataset .dataset_root = ""
237
- det_cfg .predict .dataset .data_dir = args .image_dir
238
- output_save_dir = det_cfg .predict .output_save_dir or "./output"
239
-
240
- # get det result from predict
241
- preds_list = predict_single_step (det_cfg , save_res = False )
242
-
243
- # set amp level
244
- amp_level = det_cfg .system .get ("amp_level_infer" , "O0" )
224
def build_model_from_config(cfg):
    """Build a network for inference from a parsed yaml config.

    The AMP level is read from ``cfg.system`` under the key ``"amp_level_infer"``
    (default ``"O0"``). On a GPU device target, ``"O3"`` is replaced by ``"O2"``
    and a warning is logged.

    Args:
        cfg: config object exposing ``system``, ``model`` and ``predict`` sections.
            ``cfg.model.backbone.pretrained`` is overwritten with ``False`` in place.

    Returns:
        The network returned by ``build_model``, after ``set_train(False)``.
    """
    precision = cfg.system.get("amp_level_infer", "O0")
    running_on_gpu = get_context("device_target") == "GPU"
    if running_on_gpu and precision == "O3":
        logger.warning(
            "Model evaluation does not support amp_level O3 on GPU currently. "
            "The program has switched to amp_level O2 automatically."
        )
        precision = "O2"

    cfg.model.backbone.pretrained = False

    ckpt_path = cfg.predict.ckpt_load_path
    if ckpt_path is None:
        # Only a warning: the model is still built, just without a checkpoint path.
        logger.warning(
            f"No ckpt is available for {cfg.model.task}, "
            "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
        )

    net = build_model(cfg.model, ckpt_load_path=ckpt_path, amp_level=precision)
    net.set_train(False)
    return net
241
+
242
+
243
def sort_polys(polys):
    """Return the polygons as a new list ordered by their first point.

    Polygons are ordered by the first point's second coordinate, then by its
    first coordinate. The sort is stable, so polygons with equal keys keep
    their input order.
    """

    def _first_point_key(poly):
        anchor = poly[0]
        return anchor[1], anchor[0]

    ordered = list(polys)
    ordered.sort(key=_first_point_key)
    return ordered
245
+
246
+
247
def concat_crops(crops: list):
    """Join crops side by side into one image after scaling them to a common height.

    Every crop is scaled to the height of the tallest crop, with its width scaled
    by the same factor, and the results are concatenated along axis 1.

    Args:
        crops: non-empty list of image arrays whose first two axes are (height, width).

    Returns:
        A single array holding all crops concatenated along the width axis.

    Raises:
        ValueError: if ``crops`` is empty.
    """
    if not crops:
        raise ValueError("concat_crops() requires at least one crop")

    max_height = max(crop.shape[0] for crop in crops)
    resized_crops = []
    for crop in crops:
        h, w = crop.shape[:2]
        if h == max_height:
            # Already at the target height: resizing would be a no-op at best, and
            # float rounding in the width computation could shave off a pixel.
            resized_crops.append(crop)
            continue
        # Clamp to 1 so a very narrow crop never yields a zero-width target size.
        new_w = max(1, int((w / h) * max_height))
        resized_crops.append(cv2.resize(crop, (new_w, max_height), interpolation=cv2.INTER_LINEAR))
    return np.concatenate(resized_crops, axis=1)
259
+
260
+
261
class Predict_System:
    """Text detection followed by text recognition, both built from yaml configs."""

    def __init__(self, det_cfg, rec_cfg, is_concat=False):
        """Build the transforms, networks and postprocessors for both stages.

        Args:
            det_cfg: detection config; its ``DecodeImage`` transform entry (if any)
                is updated in place with ``keep_ori=True``.
            rec_cfg: recognition config.
            is_concat: if True, all crops of an image are concatenated into one
                image before recognition.
        """
        # predict() reads data["image_ori"]; presumably ``keep_ori`` is what makes
        # DecodeImage emit it - confirm against the transform implementation.
        for transform in det_cfg.predict.dataset.transform_pipeline:
            if "DecodeImage" in transform:
                transform["DecodeImage"].update({"keep_ori": True})
                break
        self.det_transforms = create_transforms(det_cfg.predict.dataset.transform_pipeline)
        self.det_model = build_model_from_config(det_cfg)
        self.det_postprocess = build_postprocess(det_cfg.postprocess)
        # Captured here so predict() does not depend on a module-level ``det_cfg``.
        self.box_type = det_cfg.postprocess.box_type

        self.rec_batch_size = rec_cfg.predict.loader.batch_size
        self.rec_preprocess = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
        self.rec_model = build_model_from_config(rec_cfg)
        self.rec_postprocess = build_postprocess(rec_cfg.postprocess)

        self.is_concat = is_concat

    def predict_rec(self, crops: list):
        """Run text recognition on cropped images, ``rec_batch_size`` crops at a time.

        Args:
            crops: list of cropped images as numpy arrays.

        Return:
            rec_res: list of (text, score) tuples, one per crop and in the same order,
                e.g. [('apple', 0.9), ('bike', 1.0)]. Empty if ``crops`` is empty.
        """
        rec_res = []
        num_crops = len(crops)

        for batch_begin in range(0, num_crops, self.rec_batch_size):
            batch_end = min(batch_begin + self.rec_batch_size, num_crops)
            logger.info(f"Rec img idx range: [{batch_begin}, {batch_end})")
            # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch. and update it for resize,
            #  which may improve recognition accuracy in batch-mode
            #  especially for long text image. max_wh_ratio=max(max_wh_ratio, img_w / img_h).
            #  The short ones should be scaled with a.r. unchanged and padded to max width in batch.

            # preprocess
            # TODO: run in parallel with multiprocessing
            # The first transform of the rec pipeline is skipped; presumably it is the
            # image-decoding step, which in-memory crops do not need - confirm.
            img_batch = [
                run_transforms({"image": crop}, self.rec_preprocess[1:])["image"]
                for crop in crops[batch_begin:batch_end]
            ]

            # infer (np.stack adds the batch axis, also for a single image)
            net_pred = self.rec_model(Tensor(np.stack(img_batch)))

            # postprocess
            batch_res = self.rec_postprocess(net_pred)
            rec_res.extend(zip(batch_res["texts"], batch_res["confs"]))

        return rec_res

    def predict(self, img_path):
        """Detect and recognize texts in an image.

        Args:
            img_path: value stored under ``"img_path"`` in the data dict passed to
                the detection transforms.

        Return:
            boxes (list): detected text boxes whose recognized text scored >= 0.5,
                ordered as returned by ``sort_polys``.
            texts (list[tuple]): list of (text, score), aligned with ``boxes``.
            time_profile (dict): time cost in seconds under the keys "det", "rec" and "all".
        """
        # NOTE(review): validate_det_res and crop_text_region are not among the imports
        # visible in this part of the file - confirm they are imported at module level.
        time_profile = {}
        start = time()

        # detect text regions on an image
        data = run_transforms({"img_path": img_path}, self.det_transforms)
        input_np = np.expand_dims(data["image"], axis=0)
        logits = self.det_model(Tensor(input_np))
        pred = self.det_postprocess(logits, shape_list=np.expand_dims(data["shape_list"], axis=0))
        # Batch size is 1, so keep only the first (and only) image's results.
        pred = dict(polys=pred["polys"][0], scores=pred["scores"][0])
        det_res = validate_det_res(pred, data["image_ori"].shape[:2], min_poly_points=3, min_area=3)

        time_profile["det"] = time() - start
        polys = det_res["polys"].copy()
        # len() rather than truthiness: polys may be a numpy array.
        if len(polys) == 0:
            logger.warning(f"No text detected in {img_path}")
            time_profile["rec"] = 0.0
            time_profile["all"] = time_profile["det"]
            return [], [], time_profile
        polys = sort_polys(polys)
        logger.info(f"Num detected text boxes: {len(polys)}\nDet time: {time_profile['det']}")
        if self.is_concat:
            logger.info("After concatenating, 1 croped image will be recognized.")

        # crop text regions, in sorted order
        crops = [
            crop_text_region(data["image_ori"], poly.astype(np.float32), box_type=self.box_type)
            for poly in polys
        ]

        # recognize cropped images
        rs = time()
        if self.is_concat:
            crops = [concat_crops(crops)]
        rec_res_all_crops = self.predict_rec(crops)
        time_profile["rec"] = time() - rs

        logger.info(
            "Recognized texts: \n"
            + "\n".join([f"{text}\t{score}" for text, score in rec_res_all_crops])
            + f"\nRec time: {time_profile['rec']}"
        )

        # In concat mode there is a single recognition result shared by every box.
        if self.is_concat:
            rec_per_box = [rec_res_all_crops[0]] * len(polys)
        else:
            rec_per_box = rec_res_all_crops

        # Filter out low-score texts and merge detection and recognition results.
        # Boxes come from the sorted ``polys`` because the crops (and therefore the
        # recognition results) were produced in that order.
        boxes, text_scores = [], []
        for box, (text, text_score) in zip(polys, rec_per_box):
            if text_score >= 0.5:
                boxes.append(box)
                text_scores.append((text, text_score))
        time_profile["all"] = time() - start
        return boxes, text_scores, time_profile
404
+
405
+
406
def predict_both_step(args, det_cfg, rec_cfg):
    """Run detection + recognition on every image under ``args.image_dir`` and save the results.

    Per-image results are written to ``system_results.txt`` inside
    ``det_cfg.predict.output_save_dir`` (``"./output"`` if that is unset/empty).
    Total time, average FPS and average per-stage time are logged.

    Args:
        args: parsed command-line arguments; reads ``is_concat`` and ``image_dir``.
        det_cfg: detection config.
        rec_cfg: recognition config.
    """
    set_logger(name="mindocr")
    pred_sys = Predict_System(det_cfg=det_cfg, rec_cfg=rec_cfg, is_concat=args.is_concat)
    output_save_dir = det_cfg.predict.output_save_dir or "./output"
    img_paths = get_image_paths(args.image_dir)

    # NOTE(review): the execution mode is set after the networks are built - confirm
    # this ordering is intended.
    set_context(mode=det_cfg.system.mode)

    num_imgs = len(img_paths)
    if num_imgs == 0:
        # Without this guard the timing summary below fails on the missing "all" key.
        logger.warning(f"No images found in {args.image_dir}, nothing to predict.")
        return

    tot_time = {}  # accumulates the per-image 'det', 'rec' and 'all' timings
    boxes_all, text_scores_all = [], []
    for i, img_path in enumerate(img_paths):
        logger.info(f"Infering [{i + 1}/{num_imgs}]: {img_path}")
        boxes, text_scores, time_prof = pred_sys.predict(img_path)
        boxes_all.append(boxes)
        text_scores_all.append(text_scores)

        for k, cost in time_prof.items():
            tot_time[k] = tot_time.get(k, 0.0) + cost

    total = tot_time["all"]
    # Guard against a zero elapsed time on coarse clocks.
    fps = num_imgs / total if total > 0 else float("inf")
    logger.info(f"Total time:{total}")
    logger.info(f"Average FPS: {fps}")
    avg_time = {k: cost / num_imgs for k, cost in tot_time.items()}
    logger.info(f"Averge time cost: {avg_time}")

    # save result
    save_path = os.path.join(output_save_dir, "system_results.txt")
    save_res(boxes_all, text_scores_all, img_paths, save_path=save_path)
    logger.info(f"Done! Results saved in {save_path}")
299
438
300
439
301
440
def create_parser ():
@@ -314,6 +453,7 @@ def create_parser():
314
453
default = "configs/rec/crnn/crnn_resnet34.yaml" ,
315
454
help = 'YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")' ,
316
455
)
456
+ parser .add_argument ("--is_concat" , type = str2bool , default = False , help = "image path or image directory" )
317
457
parser .add_argument (
318
458
"-o" ,
319
459
"--opt" ,
@@ -323,7 +463,9 @@ def create_parser():
323
463
)
324
464
# modelarts
325
465
group = parser .add_argument_group ("modelarts" )
326
- group .add_argument ("--enable_modelarts" , type = bool , default = False , help = "Run on modelarts platform (default=False)" )
466
+ group .add_argument (
467
+ "--enable_modelarts" , type = str2bool , default = False , help = "Run on modelarts platform (default=False)"
468
+ )
327
469
group .add_argument (
328
470
"--device_target" ,
329
471
type = str ,
@@ -337,8 +479,6 @@ def create_parser():
337
479
group .add_argument ("--pretrain_url" , type = str , default = "" , help = "pre_train_model paths in obs" )
338
480
group .add_argument ("--train_url" , type = str , default = "" , help = "model folder to save/load" )
339
481
340
- # args = parser.parse_args()
341
-
342
482
return parser
343
483
344
484
@@ -378,4 +518,4 @@ def parse_args_and_config():
378
518
elif args .task_mode == "system" :
379
519
rec_cfg = Dict (rec_cfg )
380
520
det_cfg = Dict (det_cfg )
381
- predict_system (args , det_cfg , rec_cfg )
521
+ predict_both_step (args , det_cfg , rec_cfg )
0 commit comments