Skip to content

Commit ad7bf8e

Browse files
committed
feat: add yolo cls model
1 parent bc43533 commit ad7bf8e

File tree

2 files changed

+137
-23
lines changed

2 files changed

+137
-23
lines changed

table_cls/main.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,46 @@
33

44
import cv2
55
import numpy as np
6-
import onnxruntime
76
from PIL import Image
87

9-
from .utils import InputType, LoadImage
8+
from .utils import InputType, LoadImage, OrtInferSession, ResizePad
109

1110
cur_dir = Path(__file__).resolve().parent
12-
table_cls_model_path = cur_dir / "models" / "table_cls.onnx"
11+
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12+
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
1313

1414

1515
class TableCls:
    """Front-end that classifies a table image with a selectable engine.

    Passing ``model="yolo"`` (the default) builds a ``YoloCls`` engine; any
    other value builds a ``QanythingCls`` engine.
    """

    def __init__(self, model="yolo"):
        engine_cls = YoloCls if model == "yolo" else QanythingCls
        self.table_engine = engine_cls()
        self.load_img = LoadImage()

    def __call__(self, content: InputType):
        """Classify ``content`` and report how long it took.

        Returns:
            A ``(prediction, elapsed)`` tuple, where ``prediction`` is
            whatever the selected engine returns and ``elapsed`` is the
            wall-clock time in seconds spent on loading, preprocessing and
            inference.
        """
        started = time.perf_counter()
        loaded = self.load_img(content)
        model_input = self.table_engine.preprocess(loaded)
        # The engine receives a list holding the single prepared input.
        prediction = self.table_engine([model_input])
        elapsed = time.perf_counter() - started
        return prediction, elapsed
30+
31+
32+
class QanythingCls:
33+
def __init__(self):
34+
self.table_cls = OrtInferSession(q_cls_model_path)
2335
self.inp_h = 224
2436
self.inp_w = 224
2537
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
2638
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
2739
self.cls = {0: "wired", 1: "wireless"}
28-
self.load_img = LoadImage()
2940

30-
def _preprocess(self, image):
31-
img = Image.fromarray(np.uint8(image))
41+
def preprocess(self, img):
42+
img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB)
43+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
44+
img = np.stack((img,) * 3, axis=-1)
45+
img = Image.fromarray(np.uint8(img))
3246
img = img.resize((self.inp_h, self.inp_w))
3347
img = np.array(img, dtype=np.float32) / 255.0
3448
img -= self.mean
@@ -37,15 +51,27 @@ def _preprocess(self, image):
3751
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
3852
return img
3953

40-
def __call__(self, content: InputType):
41-
ss = time.perf_counter()
42-
img = self.load_img(content)
43-
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
44-
gray_img = np.stack((gray_img,) * 3, axis=-1)
45-
gray_img = self._preprocess(gray_img)
46-
output = self.table_cls.run(None, {"input": gray_img})
54+
def __call__(self, img):
55+
output = self.table_cls(img)
4756
predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True))
4857
predict /= np.sum(predict, axis=1, keepdims=True)
4958
predict_cla = np.argmax(predict, axis=1)[0]
50-
table_elapse = time.perf_counter() - ss
51-
return self.cls[predict_cla], table_elapse
59+
return self.cls[predict_cla]
60+
61+
62+
class YoloCls:
    """Table classifier backed by the ``yolo_cls.onnx`` model."""

    def __init__(self):
        self.table_cls = OrtInferSession(yolo_cls_model_path)
        # Maps the arg-max output index to its label.
        self.cls = {0: "wireless", 1: "wired"}

    def preprocess(self, img):
        """Turn an image array into the batched float input for the model.

        The image is resized and padded by ``ResizePad`` with a target size
        of 640, scaled to ``[0, 1]``, moved from HWC to CHW layout, and given
        a leading batch axis of size one.
        """
        padded = ResizePad(img, 640)[0]
        scaled = np.array(padded, dtype=np.float32) / 255.0
        chw = scaled.transpose(2, 0, 1)  # HWC to CHW
        return chw[np.newaxis, ...]  # Add batch dimension, only one image

    def __call__(self, img):
        """Run inference and return the label of the highest-scoring class."""
        scores = self.table_cls(img)[0]
        best_index = np.argmax(scores, axis=1)[0]
        return self.cls[best_index]

table_cls/utils.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,86 @@
1+
import traceback
12
from io import BytesIO
23
from pathlib import Path
3-
from typing import Union
4+
from typing import Union, List
45

56
import cv2
67
import numpy as np
78
from PIL import Image, UnidentifiedImageError
9+
from onnxruntime import InferenceSession
10+
from onnxruntime.capi.onnxruntime_pybind11_state import (
11+
SessionOptions,
12+
GraphOptimizationLevel,
13+
)
814

915
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
1016

1117

18+
class OrtInferSession:
    """Thin wrapper around an ONNX Runtime ``InferenceSession`` on CPU.

    Args:
        model_path: Path to the ONNX model file; it must exist and be a file.
        num_threads: Value for ``intra_op_num_threads``. ``-1`` (the default)
            leaves the ONNX Runtime setting untouched.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        FileExistsError: If ``model_path`` exists but is not a regular file.
    """

    def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
        self.verify_exist(model_path)

        self.num_threads = num_threads
        self._init_sess_opt()

        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(cpu_ep, cpu_provider_options)]
        try:
            self.session = InferenceSession(
                str(model_path), sess_options=self.sess_opt, providers=EP_list
            )
        except TypeError:
            # Compatibility fallback for onnxruntime 1.5.2: retry without
            # the ``providers`` argument.
            self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

    def _init_sess_opt(self):
        """Build the ``SessionOptions`` used to create the session."""
        self.sess_opt = SessionOptions()
        self.sess_opt.log_severity_level = 4
        self.sess_opt.enable_cpu_mem_arena = False

        if self.num_threads != -1:
            self.sess_opt.intra_op_num_threads = self.num_threads

        self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model on ``input_content``.

        Args:
            input_content: Input arrays, paired in order with the model's
                input names.

        Returns:
            The value returned by ``session.run(None, ...)``: one entry per
            model output (callers index it, e.g. ``output[0]``).

        Raises:
            ONNXRuntimeError: If the underlying ``run`` call raises; the
                message carries the formatted traceback.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self):
        """Return the names of the model inputs, in session order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_name(self, output_idx=0):
        """Return the name of the model output at ``output_idx``."""
        return self.session.get_outputs()[output_idx].name

    def get_metadata(self):
        """Return the model's custom metadata map."""
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict

    @staticmethod
    def verify_exist(model_path: Union[Path, str]):
        """Check that ``model_path`` points at an existing regular file.

        Raises:
            FileNotFoundError: If the path does not exist.
            FileExistsError: If the path exists but is not a file.
        """
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist!")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} must be a file")
79+
80+
class ONNXRuntimeError(Exception):
    """Raised when running an ONNX Runtime inference session fails."""
82+
83+
1284
class LoadImageError(Exception):
    """Raised by the image-loading helpers, e.g. when a file path does not exist."""
1486

@@ -106,3 +178,19 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
106178
def verify_exist(file_path: Union[str, Path]):
107179
if not Path(file_path).exists():
108180
raise LoadImageError(f"{file_path} does not exist.")
181+
182+
183+
def ResizePad(img, target_size):
    """Resize ``img`` to fit a square of side ``target_size`` and pad it.

    The longer side is scaled to ``target_size`` with the aspect ratio kept,
    then the image is centred on a ``target_size`` x ``target_size`` canvas
    filled with the constant colour ``(114, 114, 114)``.

    Args:
        img: Image array whose first two dimensions are height and width.
        target_size: Side length of the square output, in pixels.

    Returns:
        A tuple ``(padded_img, new_w, new_h, left, top)``: the padded image,
        the resized width and height before padding, and the left/top
        padding offsets.
    """
    h, w = img.shape[:2]
    ratio = target_size / max(h, w)
    # Clamp to one pixel: for extreme aspect ratios the shorter side would
    # otherwise truncate to 0, which cv2.resize rejects.
    new_w = max(1, int(ratio * w))
    new_h = max(1, int(ratio * h))
    # The interpolation flag must be passed by keyword; the third positional
    # parameter of cv2.resize is ``dst``, not ``interpolation``.
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_h = target_size - new_h
    pad_w = target_size - new_w
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left
    img1 = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )
    return img1, new_w, new_h, left, top

0 commit comments

Comments
 (0)