13
13
14
14
15
15
class TableCls :
16
- def __init__ (self , model = "yolo" ):
17
- if model == "yolo" :
18
- self .table_engine = YoloCls ()
16
+ def __init__ (self , model_type = "yolo" , model_path = yolo_cls_model_path ):
17
+ if model_type == "yolo" :
18
+ self .table_engine = YoloCls (model_path )
19
19
else :
20
- self .table_engine = QanythingCls ()
20
+ model_path = q_cls_model_path
21
+ self .table_engine = QanythingCls (model_path )
21
22
self .load_img = LoadImage ()
22
23
23
24
def __call__ (self , content : InputType ):
@@ -30,8 +31,8 @@ def __call__(self, content: InputType):
30
31
31
32
32
33
class QanythingCls :
33
- def __init__ (self ):
34
- self .table_cls = OrtInferSession (q_cls_model_path )
34
+ def __init__ (self , model_path ):
35
+ self .table_cls = OrtInferSession (model_path )
35
36
self .inp_h = 224
36
37
self .inp_w = 224
37
38
self .mean = np .array ([0.485 , 0.456 , 0.406 ], dtype = np .float32 )
@@ -60,8 +61,8 @@ def __call__(self, img):
60
61
61
62
62
63
class YoloCls :
63
- def __init__ (self ):
64
- self .table_cls = OrtInferSession (yolo_cls_model_path )
64
+ def __init__ (self , model_path ):
65
+ self .table_cls = OrtInferSession (model_path )
65
66
self .cls = {0 : "wireless" , 1 : "wired" }
66
67
67
68
def preprocess (self , img ):
0 commit comments