Skip to content

Commit eaaf4d3

Browse files
authored
Merge pull request #65 from RapidAI/fix_table_cls_preprocess
fix: fix table cls preprocess
2 parents 881a164 + 2c3939b commit eaaf4d3

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

lineless_table_rec/utils_table_recover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], thehold: float = 0.2) -> List[An
289289
cur[0], next[0], axis="y", threhold=thehold
290290
)
291291
if c_idx:
292-
dis = max(next_box[0] - cur_box[1], 0)
292+
dis = max(next_box[0] - cur_box[0], 0)
293293
blank_str = int(dis / threshold) * " "
294294
cur[1] = cur[1] + blank_str + next[1]
295295
xmin = min(cur_box[0], next_box[0])

table_cls/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from PIL import Image
77

8-
from .utils import InputType, LoadImage, OrtInferSession
8+
from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop
99

1010
cur_dir = Path(__file__).resolve().parent
1111
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
@@ -70,7 +70,7 @@ def __init__(self, model_path):
7070

7171
def preprocess(self, img):
7272
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
73-
img = cv2.resize(img, (640, 640))
73+
img = resize_and_center_crop(img, 640)
7474
img = np.array(img, dtype=np.float32) / 255
7575
img = img.transpose(2, 0, 1) # HWC to CHW
7676
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image

table_cls/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,37 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
178178
def verify_exist(file_path: Union[str, Path]):
179179
if not Path(file_path).exists():
180180
raise LoadImageError(f"{file_path} does not exist.")
181+
182+
183+
def resize_and_center_crop(image, output_size=640):
184+
"""
185+
将图片的最小边缩放到指定大小,并进行中心裁剪。
186+
187+
:param image: 输入的图片数组 (H, W, C)
188+
:param output_size: 缩放和裁剪后的图片大小,默认为 640
189+
:return: 处理后的图片数组 (output_size, output_size, C)
190+
"""
191+
# 获取图片的高度和宽度
192+
height, width = image.shape[:2]
193+
# 计算缩放比例
194+
if width < height:
195+
new_width = output_size
196+
new_height = int(output_size * height / width)
197+
else:
198+
new_width = int(output_size * width / height)
199+
new_height = output_size
200+
201+
# 缩放图片
202+
image_resize = cv2.resize(
203+
image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
204+
)
205+
206+
# 计算中心裁剪的坐标
207+
left = (new_width - output_size) // 2
208+
top = (new_height - output_size) // 2
209+
right = left + output_size
210+
bottom = top + output_size
211+
212+
# # 中心裁剪
213+
image_cropped = image_resize[top:bottom, left:right]
214+
return image_cropped

wired_table_rec/utils_table_recover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[A
383383
cur[0], next[0], axis="y", threhold=threhold
384384
)
385385
if c_idx:
386-
dis = max(next_box[0] - cur_box[1], 0)
386+
dis = max(next_box[0] - cur_box[0], 0)
387387
blank_str = int(dis / threshold) * " "
388388
cur[1] = cur[1] + blank_str + next[1]
389389
xmin = min(cur_box[0], next_box[0])

0 commit comments

Comments
 (0)