Skip to content

Commit 9ea4a74

Browse files
committed
feat: optim param use for table cls
1 parent 490f328 commit 9ea4a74

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

table_cls/main.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313

1414

1515
class TableCls:
16-
def __init__(self, model="yolo"):
17-
if model == "yolo":
18-
self.table_engine = YoloCls()
16+
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
17+
if model_type == "yolo":
18+
self.table_engine = YoloCls(model_path)
1919
else:
20-
self.table_engine = QanythingCls()
20+
model_path = q_cls_model_path
21+
self.table_engine = QanythingCls(model_path)
2122
self.load_img = LoadImage()
2223

2324
def __call__(self, content: InputType):
@@ -30,8 +31,8 @@ def __call__(self, content: InputType):
3031

3132

3233
class QanythingCls:
33-
def __init__(self):
34-
self.table_cls = OrtInferSession(q_cls_model_path)
34+
def __init__(self, model_path):
35+
self.table_cls = OrtInferSession(model_path)
3536
self.inp_h = 224
3637
self.inp_w = 224
3738
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
@@ -60,8 +61,8 @@ def __call__(self, img):
6061

6162

6263
class YoloCls:
63-
def __init__(self):
64-
self.table_cls = OrtInferSession(yolo_cls_model_path)
64+
def __init__(self, model_path):
65+
self.table_cls = OrtInferSession(model_path)
6566
self.cls = {0: "wireless", 1: "wired"}
6667

6768
def preprocess(self, img):

tests/test_table_cls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@pytest.mark.parametrize(
1717
"img_path, expected",
18-
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")],
18+
[("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")],
1919
)
2020
def test_input_normal(img_path, expected):
2121
img_path = test_file_dir / img_path

0 commit comments

Comments
 (0)