Skip to content

Commit fe60246

Browse files
authored
Merge pull request #70 from RapidAI/add_param_for_ocr
Add param for ocr
2 parents eaaf4d3 + 1dbfec3 commit fe60246

File tree

8 files changed

+205
-35
lines changed

8 files changed

+205
-35
lines changed

README.md

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
</div>
1414

1515
### 最近更新
16-
- **2024.10.13**
17-
- 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库)
1816
- **2024.10.22**
1917
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
2018
- **2024.10.29**
2119
- 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评
20+
- **2024.11.12**
21+
- 抽离模型识别和处理过程核心阈值,方便大家进行微调适配自己的场景[微调入参参考](#核心参数)
2222

2323
### 简介
24-
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。
25-
24+
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\
25+
[快速开始](#安装) [模型评测](#指标结果) [使用建议](#使用建议) [表格旋转及透视修正](#表格旋转及透视修正) [微调入参参考](#核心参数) [常见问题](#FAQ) [更新计划](#更新计划)
2626
#### 特点
2727

2828
- 采用ONNXRuntime作为推理引擎,cpu下单图推理1-7s
@@ -68,6 +68,7 @@
6868
wired_table_rec_v2(有线表格精度最高): 通用场景有线表格(论文,杂志,期刊, 收据,单据,账单)
6969

7070
paddlex-SLANet-plus(综合精度最高): 文档场景表格(论文,杂志,期刊中的表格)
71+
[微调入参参考](#核心参数)
7172

7273
### 安装
7374

@@ -158,7 +159,30 @@ for i, res in enumerate(result):
158159
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
159160
```
160161

161-
## FAQ (Frequently Asked Questions)
162+
### 核心参数
163+
```python
164+
wired_table_rec = WiredTableRecognition()
165+
html, elasp, polygons, logic_points, ocr_res = wired_table_rec(
166+
img_path,
167+
version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1
168+
morph_close=True, # 是否进行形态学操作,辅助找到更多线框,默认为True
169+
more_h_lines=True, # 是否基于线框检测结果进行更多水平线检查,辅助找到更小线框, 默认为True
170+
h_lines_threshold = 100, # 必须开启more_h_lines, 连接横线检测像素阈值,小于该值会生成新横线,默认为100
171+
more_v_lines=True, # 是否基于线框检测结果进行更多垂直线检查,辅助找到更小线框, 默认为True
172+
v_lines_threshold = 15, # 必须开启more_v_lines, 连接竖线检测像素阈值,小于该值会生成新竖线,默认为15
173+
extend_line=True, # 是否基于线框检测结果进行线段延长,辅助找到更多线框, 默认为True
174+
need_ocr=True, # 是否进行OCR识别, 默认为True
175+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
176+
)
177+
lineless_table_rec = LinelessTableRecognition()
178+
html, elasp, polygons, logic_points, ocr_res = lineless_table_rec(
179+
need_ocr=True, # 是否进行OCR识别, 默认为True
180+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
181+
)
182+
```
183+
184+
185+
## FAQ
162186
1. **问:识别框丢失了内部文字信息**
163187
- 答:默认使用的rapidocr小模型,如果需要更高精度的效果,可以从 [模型列表](https://rapidai.github.io/RapidOCRDocs/model_list/#_1)
164188
下载更高精度的ocr模型,在执行时传入ocr_result即可
@@ -168,7 +192,7 @@ for i, res in enumerate(result):
168192
主要耗时在ocr阶段,可以参考 [rapidocr_paddle](https://rapidai.github.io/RapidOCRDocs/install_usage/rapidocr_paddle/usage/#_3)
169193
加速ocr识别过程
170194

171-
### TODO List
195+
### 更新计划
172196

173197
- [x] 图片小角度偏移修正方法补充
174198
- [x] 增加数据集数量,增加更多评测对比

lineless_table_rec/main.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,37 @@ def __call__(
5151
self,
5252
content: InputType,
5353
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
54+
**kwargs
5455
):
5556
ss = time.perf_counter()
57+
rec_again = True
58+
need_ocr = True
59+
if kwargs:
60+
rec_again = kwargs.get("rec_again", True)
61+
need_ocr = kwargs.get("need_ocr", True)
5662
img = self.load_img(content)
57-
if self.ocr is None and ocr_result is None:
58-
raise ValueError(
59-
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
60-
)
61-
if ocr_result is None:
62-
ocr_result, _ = self.ocr(img)
6363
input_info = self.preprocess(img)
6464
try:
6565
polygons, slct_logi = self.infer(input_info)
6666
logi_points = self.filter_logi_points(slct_logi)
67+
if not need_ocr:
68+
sorted_polygons, idx_list = sorted_ocr_boxes(
69+
[box_4_2_poly_to_box_4_1(box) for box in polygons]
70+
)
71+
return (
72+
"",
73+
time.perf_counter() - ss,
74+
sorted_polygons,
75+
logi_points[idx_list],
76+
[],
77+
)
78+
79+
if ocr_result is None and need_ocr:
80+
ocr_result, _ = self.ocr(img)
6781
# ocr 结果匹配
6882
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons)
6983
# 如果有识别框没有ocr结果,直接进行rec补充
70-
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
84+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
7185
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
7286
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
7387
# 拆分包含和重叠的识别框
@@ -81,7 +95,6 @@ def __call__(
8195
]
8296
# 生成行列对应的二维表格, 合并同行同列识别框中的的ocr识别框
8397
t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list)
84-
# todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中
8598
# 将同一个识别框中的ocr结果排序并同行合并
8699
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
87100
# 渲染为html
@@ -192,11 +205,11 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
192205
def sort_and_gather_ocr_res(self, res):
193206
for i, dict_res in enumerate(res):
194207
_, sorted_idx = sorted_ocr_boxes(
195-
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5
208+
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
196209
)
197210
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
198211
dict_res["t_ocr_res"] = gather_ocr_list_by_row(
199-
dict_res["t_ocr_res"], thehold=0.5
212+
dict_res["t_ocr_res"], thehold=0.3
200213
)
201214
return res
202215

@@ -263,12 +276,17 @@ def re_rec(
263276
img: np.ndarray,
264277
sorted_polygons: np.ndarray,
265278
cell_box_map: Dict[int, List[str]],
279+
rec_again=True,
266280
) -> Dict[int, List[any]]:
267281
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
268282
#
269283
for i in range(sorted_polygons.shape[0]):
270284
if cell_box_map.get(i):
271285
continue
286+
if not rec_again:
287+
box = sorted_polygons[i]
288+
cell_box_map[i] = [[box, "", 1]]
289+
continue
272290
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
273291
pad_img = cv2.copyMakeBorder(
274292
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
175 KB
Loading

tests/test_lineless_table_rec.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,39 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
244244
assert (
245245
html_output == expected_html
246246
), f"Expected HTML does not match. Got: {html_output}"
247+
248+
249+
@pytest.mark.parametrize(
250+
"img_path, table_str_len, td_nums",
251+
[
252+
("table.jpg", 2870, 160),
253+
],
254+
)
255+
def test_no_rec_again(img_path, table_str_len, td_nums):
256+
img_path = test_file_dir / img_path
257+
img = cv2.imread(str(img_path))
258+
259+
table_str, *_ = table_recog(img, rec_again=False)
260+
261+
assert len(table_str) >= table_str_len
262+
assert table_str.count("td") == td_nums
263+
264+
265+
@pytest.mark.parametrize(
266+
"img_path, html_output, points_len",
267+
[
268+
("table.jpg", "", 77),
269+
("lineless_table_recognition.jpg", "", 51),
270+
],
271+
)
272+
def test_no_ocr(img_path, html_output, points_len):
273+
img_path = test_file_dir / img_path
274+
275+
html, elasp, polygons, logic_points, ocr_res = table_recog(
276+
str(img_path), need_ocr=False
277+
)
278+
assert len(ocr_res) == 0
279+
assert len(polygons) > points_len
280+
assert len(logic_points) > points_len
281+
assert len(polygons) == len(logic_points)
282+
assert html == html_output

tests/test_wired_table_rec.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ def test_input_normal(img_path, gt_td_nums, gt2):
6565
assert td_nums >= gt_td_nums
6666

6767

68+
@pytest.mark.parametrize(
69+
"img_path, gt_td_nums",
70+
[
71+
("wired_big_box.png", 70),
72+
],
73+
)
74+
def test_input_normal(img_path, gt_td_nums):
75+
img_path = test_file_dir / img_path
76+
77+
ocr_result, _ = ocr_engine(img_path)
78+
table_str, *_ = table_recog(str(img_path), ocr_result)
79+
td_nums = get_td_nums(table_str)
80+
81+
assert td_nums >= gt_td_nums
82+
83+
6884
@pytest.mark.parametrize(
6985
"box1, box2, threshold, expected",
7086
[
@@ -264,3 +280,40 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
264280
assert (
265281
html_output == expected_html
266282
), f"Expected HTML does not match. Got: {html_output}"
283+
284+
285+
@pytest.mark.parametrize(
286+
"img_path, gt_td_nums, gt2",
287+
[
288+
("table_recognition.jpg", 35, "d colsp"),
289+
],
290+
)
291+
def test_no_rec_again(img_path, gt_td_nums, gt2):
292+
img_path = test_file_dir / img_path
293+
294+
ocr_result, _ = ocr_engine(img_path)
295+
table_str, *_ = table_recog(str(img_path), ocr_result, rec_again=False)
296+
td_nums = get_td_nums(table_str)
297+
298+
assert td_nums >= gt_td_nums
299+
300+
301+
@pytest.mark.parametrize(
302+
"img_path, html_output, points_len",
303+
[
304+
("table2.jpg", "", 20),
305+
("row_span.png", "", 14),
306+
],
307+
)
308+
def test_no_ocr(img_path, html_output, points_len):
309+
img_path = test_file_dir / img_path
310+
311+
ocr_result, _ = ocr_engine(img_path)
312+
html, elasp, polygons, logic_points, ocr_res = table_recog(
313+
str(img_path), ocr_result, need_ocr=False
314+
)
315+
assert len(ocr_res) == 0
316+
assert len(polygons) > points_len
317+
assert len(logic_points) > points_len
318+
assert len(polygons) == len(logic_points)
319+
assert html == html_output

wired_table_rec/main.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,21 @@ def __call__(
5050
self,
5151
img: InputType,
5252
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
53+
**kwargs,
5354
) -> Tuple[str, float, Any, Any, Any]:
5455
if self.ocr is None and ocr_result is None:
5556
raise ValueError(
5657
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
5758
)
5859

5960
s = time.perf_counter()
60-
61+
rec_again = True
62+
need_ocr = True
63+
if kwargs:
64+
rec_again = kwargs.get("rec_again", True)
65+
need_ocr = kwargs.get("need_ocr", True)
6166
img = self.load_img(img)
62-
polygons = self.table_line_rec(img)
67+
polygons = self.table_line_rec(img, **kwargs)
6368
if polygons is None:
6469
logging.warning("polygons is None.")
6570
return "", 0.0, None, None, None
@@ -71,12 +76,22 @@ def __call__(
7176
polygons[:, 3, :].copy(),
7277
polygons[:, 1, :].copy(),
7378
)
74-
if ocr_result is None:
79+
if not need_ocr:
80+
sorted_polygons, idx_list = sorted_ocr_boxes(
81+
[box_4_2_poly_to_box_4_1(box) for box in polygons]
82+
)
83+
return (
84+
"",
85+
time.perf_counter() - s,
86+
sorted_polygons,
87+
logi_points[idx_list],
88+
[],
89+
)
90+
if ocr_result is None and need_ocr:
7591
ocr_result, _ = self.ocr(img)
7692
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
7793
# 如果有识别框没有ocr结果,直接进行rec补充
78-
# cell_box_det_map = self.re_rec_high_precise(img, polygons, cell_box_det_map)
79-
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
94+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
8095
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
8196
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
8297
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
@@ -139,11 +154,11 @@ def transform_res(
139154
def sort_and_gather_ocr_res(self, res):
140155
for i, dict_res in enumerate(res):
141156
_, sorted_idx = sorted_ocr_boxes(
142-
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5
157+
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
143158
)
144159
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
145160
dict_res["t_ocr_res"] = gather_ocr_list_by_row(
146-
dict_res["t_ocr_res"], threhold=0.5
161+
dict_res["t_ocr_res"], threhold=0.3
147162
)
148163
return res
149164

@@ -152,12 +167,16 @@ def re_rec(
152167
img: np.ndarray,
153168
sorted_polygons: np.ndarray,
154169
cell_box_map: Dict[int, List[str]],
170+
rec_again=True,
155171
) -> Dict[int, List[Any]]:
156172
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
157-
#
158173
for i in range(sorted_polygons.shape[0]):
159174
if cell_box_map.get(i):
160175
continue
176+
if not rec_again:
177+
box = sorted_polygons[i]
178+
cell_box_map[i] = [[box, "", 1]]
179+
continue
161180
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
162181
pad_img = cv2.copyMakeBorder(
163182
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)

wired_table_rec/table_line_rec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model_path: Optional[str] = None):
3636

3737
self.session = OrtInferSession(model_path)
3838

39-
def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
39+
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
4040
img_info = self.preprocess(img)
4141
pred = self.infer(img_info)
4242
polygons = self.postprocess(pred)

0 commit comments

Comments
 (0)