Skip to content

Commit 510dd16

Browse files
authored
Merge pull request #59 from RapidAI/optim_cls_model
Optim cls model
2 parents 87034f1 + 9ea4a74 commit 510dd16

File tree

4 files changed

+153
-32
lines changed

4 files changed

+153
-32
lines changed

table_cls/main.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,47 @@
33

44
import cv2
55
import numpy as np
6-
import onnxruntime
76
from PIL import Image
87

9-
from .utils import InputType, LoadImage
8+
from .utils import InputType, LoadImage, OrtInferSession, ResizePad
109

1110
cur_dir = Path(__file__).resolve().parent
12-
table_cls_model_path = cur_dir / "models" / "table_cls.onnx"
11+
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12+
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
1313

1414

1515
class TableCls:
    """Classify a table image with one of the two bundled ONNX classifiers.

    Args:
        model_type: ``"yolo"`` selects :class:`YoloCls`; any other value
            selects :class:`QanythingCls`.
        model_path: Path of the ONNX weights to load. When left at its
            default (or passed as ``None``) the weights bundled for the
            selected ``model_type`` are used.
    """

    def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
        if model_type == "yolo":
            engine_cls, bundled_path = YoloCls, yolo_cls_model_path
        else:
            engine_cls, bundled_path = QanythingCls, q_cls_model_path

        # The signature default points at the YOLO weights.  Only swap in the
        # bundled weights of the selected engine when the caller did not
        # supply a path of their own; previously an explicit path was
        # silently discarded for every non-"yolo" model type.
        if model_path is None or Path(model_path) == yolo_cls_model_path:
            model_path = bundled_path

        self.table_engine = engine_cls(model_path)
        self.load_img = LoadImage()

    def __call__(self, content: InputType):
        """Classify ``content`` and time the whole pipeline.

        Args:
            content: The image input, in any form accepted by ``LoadImage``.

        Returns:
            A ``(label, elapsed)`` tuple: the value returned by the selected
            engine and the ``time.perf_counter`` difference (seconds) covering
            loading, preprocessing and inference.
        """
        ss = time.perf_counter()
        img = self.load_img(content)
        img = self.table_engine.preprocess(img)
        # The engine forwards a list of input arrays to the ONNX session.
        predict_cla = self.table_engine([img])
        table_elapse = time.perf_counter() - ss
        return predict_cla, table_elapse
31+
32+
33+
class QanythingCls:
34+
def __init__(self, model_path):
35+
self.table_cls = OrtInferSession(model_path)
2336
self.inp_h = 224
2437
self.inp_w = 224
2538
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
2639
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
2740
self.cls = {0: "wired", 1: "wireless"}
28-
self.load_img = LoadImage()
2941

30-
def _preprocess(self, image):
31-
img = Image.fromarray(np.uint8(image))
42+
def preprocess(self, img):
43+
img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB)
44+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
45+
img = np.stack((img,) * 3, axis=-1)
46+
img = Image.fromarray(np.uint8(img))
3247
img = img.resize((self.inp_h, self.inp_w))
3348
img = np.array(img, dtype=np.float32) / 255.0
3449
img -= self.mean
@@ -37,15 +52,27 @@ def _preprocess(self, image):
3752
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
3853
return img
3954

40-
def __call__(self, content: InputType):
41-
ss = time.perf_counter()
42-
img = self.load_img(content)
43-
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
44-
gray_img = np.stack((gray_img,) * 3, axis=-1)
45-
gray_img = self._preprocess(gray_img)
46-
output = self.table_cls.run(None, {"input": gray_img})
55+
def __call__(self, img):
56+
output = self.table_cls(img)
4757
predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True))
4858
predict /= np.sum(predict, axis=1, keepdims=True)
4959
predict_cla = np.argmax(predict, axis=1)[0]
50-
table_elapse = time.perf_counter() - ss
51-
return self.cls[predict_cla], table_elapse
60+
return self.cls[predict_cla]
61+
62+
63+
class YoloCls:
    """Table classifier that runs the ONNX model found at ``model_path``.

    Attributes are kept public because the surrounding code reads them:
    ``table_cls`` is the inference session and ``cls`` maps the predicted
    class index to its label.
    """

    def __init__(self, model_path):
        self.cls = {0: "wireless", 1: "wired"}
        self.table_cls = OrtInferSession(model_path)

    def preprocess(self, img):
        """Turn an image into a ``float32`` batch of one, channels first.

        The image is resized and padded by ``ResizePad`` with a target size
        of 640, scaled by 1/255, transposed from HWC to CHW and given a
        leading batch axis.
        """
        padded, *_ = ResizePad(img, 640)
        scaled = np.array(padded, dtype=np.float32) / 255.0
        chw = scaled.transpose(2, 0, 1)  # HWC to CHW
        return np.expand_dims(chw, axis=0)  # batch of exactly one image

    def __call__(self, img):
        """Run the session on ``img`` and return the label of the top class.

        ``img`` is handed to the session unchanged; only the first output and
        the first row of that output are inspected.
        """
        first_output = self.table_cls(img)[0]
        best_idx = np.argmax(first_output, axis=1)[0]
        return self.cls[best_idx]

table_cls/utils.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,86 @@
1+
import traceback
12
from io import BytesIO
23
from pathlib import Path
3-
from typing import Union
4+
from typing import Union, List
45

56
import cv2
67
import numpy as np
78
from PIL import Image, UnidentifiedImageError
9+
from onnxruntime import InferenceSession
10+
from onnxruntime.capi.onnxruntime_pybind11_state import (
11+
SessionOptions,
12+
GraphOptimizationLevel,
13+
)
814

915
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
1016

1117

18+
class OrtInferSession:
    """Small wrapper around an ONNX Runtime ``InferenceSession`` on the CPU provider.

    Args:
        model_path: Location of the ONNX model file; it must exist and be a
            regular file.
        num_threads: Value for ``intra_op_num_threads``. ``-1`` (the default)
            leaves the session option untouched.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        FileExistsError: ``model_path`` exists but is not a file.
    """

    def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
        self.verify_exist(model_path)

        self.num_threads = num_threads
        self._init_sess_opt()

        providers = [
            ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})
        ]
        model_file = str(model_path)
        try:
            self.session = InferenceSession(
                model_file, sess_options=self.sess_opt, providers=providers
            )
        except TypeError:
            # Kept for compatibility with onnxruntime 1.5.2: retry without
            # the ``providers`` argument.
            self.session = InferenceSession(model_file, sess_options=self.sess_opt)

    def _init_sess_opt(self):
        """Create ``self.sess_opt`` with the session options used by this wrapper."""
        options = SessionOptions()
        self.sess_opt = options

        options.log_severity_level = 4
        options.enable_cpu_mem_arena = False

        # -1 means "do not override the thread count".
        if self.num_threads != -1:
            options.intra_op_num_threads = self.num_threads

        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
        """Run the model and return the result of ``session.run`` unchanged.

        ``input_content`` is paired positionally with the model's input names.

        Raises:
            ONNXRuntimeError: Any exception raised by ``session.run``; the
                message carries the formatted traceback.
        """
        feeds = {
            name: tensor
            for name, tensor in zip(self.get_input_names(), input_content)
        }
        try:
            return self.session.run(None, feeds)
        except Exception as e:
            raise ONNXRuntimeError(traceback.format_exc()) from e

    def get_input_names(self):
        """Return the names of all model inputs, in session order."""
        return [node.name for node in self.session.get_inputs()]

    def get_output_name(self, output_idx=0):
        """Return the name of the model output at position ``output_idx``."""
        outputs = self.session.get_outputs()
        return outputs[output_idx].name

    def get_metadata(self):
        """Return the model's custom metadata map."""
        return self.session.get_modelmeta().custom_metadata_map

    @staticmethod
    def verify_exist(model_path: Union[Path, str]):
        """Raise unless ``model_path`` points at an existing regular file."""
        candidate = model_path if isinstance(model_path, Path) else Path(model_path)

        if not candidate.exists():
            raise FileNotFoundError(f"{candidate} does not exist!")

        if not candidate.is_file():
            raise FileExistsError(f"{candidate} must be a file")
78+
79+
80+
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
82+
83+
1284
class LoadImageError(Exception):
    """Raised when an input image cannot be loaded, e.g. its path does not exist."""
1486

@@ -106,3 +178,19 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
106178
def verify_exist(file_path: Union[str, Path]):
107179
if not Path(file_path).exists():
108180
raise LoadImageError(f"{file_path} does not exist.")
181+
182+
183+
def ResizePad(img, target_size):
184+
h, w = img.shape[:2]
185+
m = max(h, w)
186+
ratio = target_size / m
187+
new_w, new_h = int(ratio * w), int(ratio * h)
188+
img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR)
189+
top = (target_size - new_h) // 2
190+
bottom = (target_size - new_h) - top
191+
left = (target_size - new_w) // 2
192+
right = (target_size - new_w) - left
193+
img1 = cv2.copyMakeBorder(
194+
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
195+
)
196+
return img1, new_w, new_h, left, top

tests/test_table_cls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@pytest.mark.parametrize(
1717
"img_path, expected",
18-
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")],
18+
[("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")],
1919
)
2020
def test_input_normal(img_path, expected):
2121
img_path = test_file_dir / img_path

wired_table_rec/table_recover.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,20 @@ def get_benchmark_cols(
9393
sorted(range_res.items(), key=lambda x: x[0], reverse=True)
9494
)
9595
for k, v in sorted_res.items():
96-
if not all(v):
97-
continue
98-
99-
longest_x = np.insert(longest_x, v[1], cur_row[k])
100-
longest_col_points = np.insert(
101-
longest_col_points, v[1], polygons[row_value[k]], axis=0
102-
)
103-
96+
# bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
97+
# 最长列不包含第一列和最后一列的场景需要兼容
98+
if all(v) or v[1] == 0:
99+
longest_x = np.insert(longest_x, v[1], cur_row[k])
100+
longest_col_points = np.insert(
101+
longest_col_points, v[1], polygons[row_value[k]], axis=0
102+
)
103+
elif v[0] and v[0] + 1 == len(longest_x):
104+
longest_x = np.append(longest_x, cur_row[k])
105+
longest_col_points = np.append(
106+
longest_col_points,
107+
polygons[row_value[k]][np.newaxis, :, :],
108+
axis=0,
109+
)
104110
# 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
105111
rightmost_idxs = [v[-1] for v in rows.values()]
106112
rightmost_boxes = polygons[rightmost_idxs]

0 commit comments

Comments
 (0)