3
3
4
4
import cv2
5
5
import numpy as np
6
- import onnxruntime
7
6
from PIL import Image
8
7
9
- from .utils import InputType , LoadImage
8
+ from .utils import InputType , LoadImage , OrtInferSession , ResizePad
10
9
11
10
# Directory that holds this module; model weights are resolved relative to it.
cur_dir = Path(__file__).resolve().parent
# ONNX weights loaded by the QanythingCls backend.
q_cls_model_path = cur_dir.joinpath("models", "table_cls.onnx")
# ONNX weights loaded by the YoloCls backend.
yolo_cls_model_path = cur_dir.joinpath("models", "yolo_cls.onnx")
13
13
14
14
15
15
class TableCls:
    """Front end that loads an image and classifies it with a chosen backend."""

    def __init__(self, model="yolo"):
        """Create the classification backend and the image loader.

        Args:
            model: ``"yolo"`` selects ``YoloCls``; every other value selects
                ``QanythingCls``.
        """
        # Only the selected backend class is instantiated.
        engine_cls = YoloCls if model == "yolo" else QanythingCls
        self.table_engine = engine_cls()
        self.load_img = LoadImage()

    def __call__(self, content: InputType):
        """Classify one image and time the whole pipeline.

        Args:
            content: Image input accepted by ``LoadImage``.

        Returns:
            A ``(label, elapsed)`` tuple: the value returned by the backend
            and the wall-clock seconds spent on loading, preprocessing and
            inference.
        """
        started = time.perf_counter()
        loaded = self.load_img(content)
        # The backend is called with a one-element list holding the
        # preprocessed input.
        batch = [self.table_engine.preprocess(loaded)]
        label = self.table_engine(batch)
        elapsed = time.perf_counter() - started
        return label, elapsed
30
+
31
+
32
+ class QanythingCls :
33
def __init__(self):
    """Load the ONNX session and set the preprocessing constants."""
    self.table_cls = OrtInferSession(q_cls_model_path)
    # Target size used when resizing the input image.
    side = 224
    self.inp_h = side
    self.inp_w = side
    # Per-channel normalisation statistics
    # (NOTE(review): these look like the usual ImageNet values - confirm).
    self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    # Class index -> label returned to the caller.
    self.cls = dict(enumerate(("wired", "wireless")))
28
- self .load_img = LoadImage ()
29
40
30
- def _preprocess (self , image ):
31
- img = Image .fromarray (np .uint8 (image ))
41
+ def preprocess (self , img ):
42
+ img = cv2 .cvtColor (img .copy (), cv2 .COLOR_BGR2RGB )
43
+ img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
44
+ img = np .stack ((img ,) * 3 , axis = - 1 )
45
+ img = Image .fromarray (np .uint8 (img ))
32
46
img = img .resize ((self .inp_h , self .inp_w ))
33
47
img = np .array (img , dtype = np .float32 ) / 255.0
34
48
img -= self .mean
@@ -37,15 +51,27 @@ def _preprocess(self, image):
37
51
img = np .expand_dims (img , axis = 0 ) # Add batch dimension, only one image
38
52
return img
39
53
40
- def __call__ (self , content : InputType ):
41
- ss = time .perf_counter ()
42
- img = self .load_img (content )
43
- gray_img = cv2 .cvtColor (img , cv2 .COLOR_BGR2GRAY )
44
- gray_img = np .stack ((gray_img ,) * 3 , axis = - 1 )
45
- gray_img = self ._preprocess (gray_img )
46
- output = self .table_cls .run (None , {"input" : gray_img })
54
+ def __call__ (self , img ):
55
+ output = self .table_cls (img )
47
56
predict = np .exp (output [0 ] - np .max (output [0 ], axis = 1 , keepdims = True ))
48
57
predict /= np .sum (predict , axis = 1 , keepdims = True )
49
58
predict_cla = np .argmax (predict , axis = 1 )[0 ]
50
- table_elapse = time .perf_counter () - ss
51
- return self .cls [predict_cla ], table_elapse
59
+ return self .cls [predict_cla ]
60
+
61
+
62
class YoloCls:
    """Table classifier backed by the ONNX model at ``yolo_cls_model_path``."""

    def __init__(self):
        """Load the ONNX session and define the index-to-label mapping."""
        self.table_cls = OrtInferSession(yolo_cls_model_path)
        # Class index -> label returned to the caller.
        self.cls = dict(enumerate(("wireless", "wired")))

    def preprocess(self, img):
        """Turn a loaded image into the model input array.

        Args:
            img: Image array in height-width-channel layout.

        Returns:
            A float32 array with the channel axis moved to the front and a
            leading batch axis of length one.
        """
        target_size = 640
        # Only the first value returned by ResizePad is needed here.
        padded, *_ = ResizePad(img, target_size)
        scaled = np.array(padded, dtype=np.float32) / 255.0
        # HWC -> CHW, then add the batch dimension (a single image).
        return np.transpose(scaled, (2, 0, 1))[np.newaxis, ...]

    def __call__(self, img):
        """Run the classifier and return the label of the best-scoring class.

        Args:
            img: Model input, forwarded unchanged to the ONNX session wrapper.

        Returns:
            The entry of ``self.cls`` for the winning class of the first
            sample in the batch.
        """
        outputs = self.table_cls(img)
        scores = outputs[0]
        best = np.argmax(scores, axis=1)[0]
        return self.cls[best]
0 commit comments