Skip to content

Commit 5120a2a

Browse files
authored
1. Bug fixes for online inference (#648)
2. Introduce online inference based on YAML.
1 parent c877866 commit 5120a2a

16 files changed

+289
-113
lines changed

configs/det/dbnet/db_r50_icdar15.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ eval:
158158

159159
predict:
160160
ckpt_load_path: tmp_det/best.ckpt
161+
output_save_dir: ./output
161162
dataset_sink_mode: False
162163
dataset:
163164
type: PredictDataset
@@ -169,24 +170,24 @@ predict:
169170
- DecodeImage:
170171
img_mode: RGB
171172
to_float32: False
172-
# - DetLabelEncode:
173-
- DetResize: # GridResize 32
174-
target_size: [ 736, 1280 ]
175-
keep_ratio: False
176-
limit_type: none
177-
divisor: 32
173+
keep_ori: True
174+
- DetResize:
175+
keep_ratio: True
176+
padding: False
177+
limit_side_len: 960
178+
limit_type: max
178179
- NormalizeImage:
179180
bgr_to_rgb: False
180181
is_hwc: True
181182
mean: imagenet
182183
std: imagenet
183184
- ToCHWImage:
184185
# the order of the dataloader list, matching the network input and the labels for evaluation
185-
output_columns: [ 'img_path', 'image', 'raw_img_shape' ] # shape in h, w order
186-
# num_keys_of_labels: 2 # num labels
186+
output_columns: ["image", "img_path", "shape_list", "image_ori"]
187+
net_input_column_index: [ 0 ] # input indices for network forward func in output_columns
187188

188189
loader:
189190
shuffle: False
190-
batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
191+
batch_size: 1
191192
drop_remainder: False
192193
num_workers: 2

configs/rec/crnn/crnn_resnet34.yaml

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,30 +162,23 @@ predict:
162162
shuffle: False
163163
transform_pipeline:
164164
- DecodeImage:
165-
img_mode: BGR
165+
img_mode: RGB
166166
to_float32: False
167-
# - RecCTCLabelEncode:
168-
# max_text_len: *max_text_len
169-
# character_dict_path: *character_dict_path
170-
# use_space_char: *use_space_char
171-
# lower: True
172-
- RecResizeImg: # different from paddle (paddle converts the image from HWC to CHW and rescales it to [-1, 1] after resize).
173-
image_shape: [32, 100] # H, W
174-
infer_mode: *infer_mode
175-
character_dict_path: *character_dict_path
176-
padding: False # aspect ratio will be preserved if true.
177-
- NormalizeImage: # different from paddle (paddle wrongly normalizes the BGR image with RGB mean/std from ImageNet for det, and simply rescales to [-1, 1] in rec).
178-
bgr_to_rgb: True
179-
is_hwc: True
180-
mean : [127.0, 127.0, 127.0]
181-
std : [127.0, 127.0, 127.0]
167+
- RecResizeNormForInfer:
168+
target_height: 32
169+
target_width: 100
170+
keep_ratio: False
171+
padding: False
172+
norm_before_pad: False
182173
- ToCHWImage:
183174
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
184-
output_columns: [ 'img_path', 'image', 'raw_img_shape' ]
175+
output_columns: ['image', 'img_path']
176+
net_input_column_index: [0] # input indices for network forward func in output_columns
177+
# label_column_index: [1, 2] # input indices marked as label
185178

186179
loader:
187180
shuffle: False # TODO: tbc
188-
batch_size: 1
181+
batch_size: 2
189182
drop_remainder: True
190183
max_rowsize: 12
191184
num_workers: 8

mindocr/data/predict_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
raise ValueError("No transform pipeline is specified!")
4444

4545
# prefetch the data keys, to fit GeneratorDataset
46-
_data = self.data_list[0]
46+
_data = self.data_list[0].copy()
4747
_data = run_transforms(_data, transforms=self.transforms)
4848
_available_keys = list(_data.keys())
4949
if output_columns is None:
@@ -60,7 +60,7 @@ def __init__(
6060
)
6161

6262
def __getitem__(self, index):
63-
data = self.data_list[index]
63+
data = self.data_list[index].copy()
6464

6565
# perform transformation on data
6666
data = run_transforms(data, transforms=self.transforms)

mindocr/models/cls_mv3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def __init__(self, config):
2727

2828

2929
@register_model
30-
def cls_mobilenet_v3_small_100_model(pretrained=False, **kwargs):
31-
pretrained_backbone = not pretrained
30+
def cls_mobilenet_v3_small_100_model(pretrained=False, pretrained_backbone=True, **kwargs):
3231
model_config = {
3332
"backbone": {
3433
'name': 'cls_mobilenet_v3_small_100',

mindocr/models/det_dbnet.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(self, config):
3333

3434

3535
@register_model
36-
def dbnet_mobilenetv3(pretrained=False, **kwargs):
37-
pretrained_backbone = 'https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3' \
36+
def dbnet_mobilenetv3(pretrained=False, pretrained_backbone=True, **kwargs):
37+
backbone_ckpt_url = 'https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3' \
3838
'/mobilenet_v3_large_050_no_scale_se_v2_expand-3c4047ac.ckpt'
3939
model_config = {
4040
"backbone": {
@@ -43,7 +43,7 @@ def dbnet_mobilenetv3(pretrained=False, **kwargs):
4343
'alpha': 0.5,
4444
'out_stages': [5, 8, 14, 20],
4545
'bottleneck_params': {'se_version': 'SqueezeExciteV2', 'always_expand': True},
46-
'pretrained': pretrained_backbone if not pretrained else False
46+
'pretrained': backbone_ckpt_url if pretrained_backbone else False
4747
},
4848
"neck": {
4949
"name": 'DBFPN',
@@ -68,8 +68,7 @@ def dbnet_mobilenetv3(pretrained=False, **kwargs):
6868

6969

7070
@register_model
71-
def dbnet_resnet18(pretrained=False, **kwargs):
72-
pretrained_backbone = not pretrained
71+
def dbnet_resnet18(pretrained=False, pretrained_backbone=True, **kwargs):
7372
model_config = {
7473
"backbone": {
7574
'name': 'det_resnet18',
@@ -98,8 +97,7 @@ def dbnet_resnet18(pretrained=False, **kwargs):
9897

9998

10099
@register_model
101-
def dbnet_resnet50(pretrained=False, **kwargs):
102-
pretrained_backbone = not pretrained
100+
def dbnet_resnet50(pretrained=False, pretrained_backbone=True, **kwargs):
103101
model_config = {
104102
"backbone": {
105103
'name': 'det_resnet50',
@@ -128,8 +126,7 @@ def dbnet_resnet50(pretrained=False, **kwargs):
128126

129127

130128
@register_model
131-
def dbnetpp_resnet50(pretrained=False, **kwargs):
132-
pretrained_backbone = not pretrained
129+
def dbnetpp_resnet50(pretrained=False, pretrained_backbone=True, **kwargs):
133130
model_config = {
134131
"backbone": {
135132
'name': 'det_resnet50',

mindocr/models/det_psenet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def __init__(self, config):
2929

3030

3131
@register_model
32-
def psenet_resnet152(pretrained=False, **kwargs):
33-
pretrained_backbone = not pretrained
32+
def psenet_resnet152(pretrained=False, pretrained_backbone=True, **kwargs):
3433
model_config = {
3534
"backbone": {
3635
'name': 'det_resnet152',

mindocr/models/kie_layoutxlm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,18 @@ def construct(self, x):
4040

4141

4242
@register_model
43-
def layoutxlm_ser(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs):
43+
def layoutxlm_ser(
44+
pretrained: bool = True,
45+
pretrained_backbone=False,
46+
use_visual_backbone: bool = True,
47+
use_float16: bool = False,
48+
**kwargs
49+
):
4450
model_config = {
4551
"type": "kie",
4652
"backbone": {
4753
"name": "layoutxlm",
48-
"pretrained": pretrained, # backbone pretrained
54+
"pretrained": pretrained_backbone, # backbone pretrained
4955
"use_visual_backbone": use_visual_backbone,
5056
"use_float16": use_float16,
5157
},

tests/ut/test_mindir_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_mindir_infer(model_name):
3737
outputs_mindir = model(x)
3838

3939
# get original ckpt outputs
40-
net = build_model(model_name, pretrained=True)
40+
net = build_model(model_name, pretrained=True, pretrained_backbone=False)
4141
outputs_ckpt = net(x)
4242

4343
for i, o in enumerate(outputs_mindir):

tests/ut/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
@pytest.mark.parametrize("pretrained", [True, False])
3232
def test_model_by_name(model_name, pretrained):
3333
print(model_name)
34-
build_model(model_name, pretrained=pretrained)
34+
pretrained_backbone = not pretrained
35+
build_model(model_name, pretrained=pretrained, pretrained_backbone=pretrained_backbone)
3536
print("model created")
3637

3738

tools/export.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ def export(model_name_or_config, data_shape, local_ckpt_path, save_dir, is_dynam
9191
amp_level = "O0"
9292

9393
if local_ckpt_path:
94-
net = build_model(model_cfg, pretrained=False, ckpt_load_path=local_ckpt_path, amp_level=amp_level)
94+
net = build_model(
95+
model_cfg, pretrained=False, pretrained_backbone=False, ckpt_load_path=local_ckpt_path, amp_level=amp_level
96+
)
9597
else:
96-
net = build_model(model_cfg, pretrained=True, amp_level=amp_level)
98+
net = build_model(model_cfg, pretrained=True, pretrained_backbone=False, amp_level=amp_level)
9799

98100
logger.info(f"Set the AMP level of the model to be `{amp_level}`.")
99101

tools/infer/text/config.py

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
"""
66
import argparse
77

8-
import yaml
9-
108

119
def str2bool(v):
1210
if isinstance(v, bool):
@@ -20,24 +18,9 @@ def str2bool(v):
2018

2119

2220
def create_parser():
23-
parser_config = argparse.ArgumentParser(description="Inference Config File", add_help=False)
24-
parser_config.add_argument(
25-
"-c", "--config", type=str, default="", help='YAML config file specifying default arguments (default="")'
26-
)
27-
2821
parser = argparse.ArgumentParser(description="Inference Config Args")
2922
# params for prediction engine
3023
parser.add_argument("--mode", type=int, default=0, help="0 for graph mode, 1 for pynative mode ") # added
31-
# parser.add_argument("--use_gpu", type=str2bool, default=True)
32-
# parser.add_argument("--use_npu", type=str2bool, default=False)
33-
# parser.add_argument("--ir_optim", type=str2bool, default=True)
34-
# parser.add_argument("--min_subgraph_size", type=int, default=15)
35-
# parser.add_argument("--precision", type=str, default="fp32")
36-
# parser.add_argument("--gpu_mem", type=int, default=500)
37-
# parser.add_argument("--gpu_id", type=int, default=0)
38-
39-
parser.add_argument("--det_model_config", type=str, help="path to det model yaml config") # added
40-
parser.add_argument("--rec_model_config", type=str, help="path to rec model yaml config") # added
4124

4225
# params for text detector
4326
parser.add_argument("--image_dir", type=str, help="image path or image directory")
@@ -165,21 +148,6 @@ def create_parser():
165148
help="Whether to visualize results and save the visualized image.",
166149
)
167150

168-
# multi-process
169-
"""
170-
parser.add_argument("--use_mp", type=str2bool, default=False)
171-
parser.add_argument("--total_process_num", type=int, default=1)
172-
parser.add_argument("--process_id", type=int, default=0)
173-
174-
parser.add_argument("--benchmark", type=str2bool, default=False)
175-
parser.add_argument("--save_log_path", type=str, default="./log_output/")
176-
177-
parser.add_argument("--show_log", type=str2bool, default=True)
178-
parser.add_argument("--use_onnx", type=str2bool, default=False)
179-
180-
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
181-
parser.add_argument("--cpu_threads", type=int, default=10)
182-
"""
183151
parser.add_argument("--warmup", type=str2bool, default=False)
184152
parser.add_argument("--ocr_result_dir", type=str, default=None, help="path or directory of ocr results")
185153
parser.add_argument(
@@ -203,29 +171,10 @@ def create_parser():
203171
)
204172
parser.add_argument("--kie_batch_num", type=int, default=8)
205173

206-
return parser_config, parser
207-
208-
209-
def _check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser):
210-
actions_dest = [action.dest for action in parser._actions]
211-
defaults_key = parser._defaults.keys()
212-
for k in cfgs.keys():
213-
if k not in actions_dest and k not in defaults_key:
214-
raise KeyError(f"{k} does not exist in ArgumentParser!")
215-
174+
return parser
216175

217-
def parse_args(args=None):
218-
parser_config, parser = create_parser()
219-
# Do we have a config file to parse?
220-
args_config, remaining = parser_config.parse_known_args(args)
221-
if args_config.config:
222-
with open(args_config.config, "r") as f:
223-
cfg = yaml.safe_load(f)
224-
_check_cfgs_in_parser(cfg, parser)
225-
parser.set_defaults(**cfg)
226-
parser.set_defaults(config=args_config.config)
227176

228-
# The main arg parser parses the rest of the args, the usual
229-
# defaults will have been overridden if config file specified.
230-
args = parser.parse_args(remaining)
177+
def parse_args():
178+
parser = create_parser()
179+
args = parser.parse_args()
231180
return args

tools/infer/text/postprocess.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class Postprocessor(object):
13-
def __init__(self, task="det", algo="DB", **kwargs):
13+
def __init__(self, task="det", algo="DB", rec_char_dict_path=None, **kwargs):
1414
# algo = algo.lower()
1515
if task == "det":
1616
if algo.startswith("DB"):
@@ -46,27 +46,33 @@ def __init__(self, task="det", algo="DB", **kwargs):
4646
self.rescale_internally = True
4747
self.round = True
4848
elif task == "rec":
49+
rec_char_dict_path = (
50+
rec_char_dict_path or "mindocr/utils/dict/ch_dict.txt"
51+
if algo in ["CRNN_CH", "SVTR_PPOCRv3_CH"]
52+
else rec_char_dict_path
53+
)
4954
# TODO: update character_dict_path and use_space_char after CRNN trained using en_dict.txt released
5055
if algo.startswith("CRNN") or algo.startswith("SVTR"):
5156
# TODO: allow users to input char dict path
52-
dict_path = "mindocr/utils/dict/ch_dict.txt" if algo in ["CRNN_CH", "SVTR_PPOCRv3_CH"] else None
5357
if algo == "SVTR_PPOCRv3_CH":
5458
postproc_cfg = dict(
5559
name="CTCLabelDecode",
56-
character_dict_path=dict_path,
60+
character_dict_path=rec_char_dict_path,
5761
use_space_char=True,
5862
)
5963
else:
6064
postproc_cfg = dict(
6165
name="RecCTCLabelDecode",
62-
character_dict_path=dict_path,
66+
character_dict_path=rec_char_dict_path,
6367
use_space_char=False,
6468
)
6569
elif algo.startswith("RARE"):
66-
dict_path = "mindocr/utils/dict/ch_dict.txt" if algo == "RARE_CH" else None
70+
rec_char_dict_path = (
71+
rec_char_dict_path or "mindocr/utils/dict/ch_dict.txt" if algo == "RARE_CH" else rec_char_dict_path
72+
)
6773
postproc_cfg = dict(
6874
name="RecAttnLabelDecode",
69-
character_dict_path=dict_path,
75+
character_dict_path=rec_char_dict_path,
7076
use_space_char=False,
7177
)
7278

tools/infer/text/predict_det.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,13 @@ def __init__(self, args):
6262
"The program has switched to amp_level O2 automatically."
6363
)
6464
amp_level = "O2"
65-
self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=amp_level)
65+
self.model = build_model(
66+
model_name,
67+
pretrained=pretrained,
68+
pretrained_backbone=False,
69+
ckpt_load_path=ckpt_load_path,
70+
amp_level=amp_level,
71+
)
6672
self.model.set_train(False)
6773
logger.info(
6874
"Init detection model: {} --> {}. Model weights loaded from {}".format(

0 commit comments

Comments
 (0)