diff --git a/src/yomitoku/cli/main.py b/src/yomitoku/cli/main.py
index 2abee99..3a5b33f 100644
--- a/src/yomitoku/cli/main.py
+++ b/src/yomitoku/cli/main.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import torch
 from pathlib import Path
 
 import cv2
@@ -235,7 +236,7 @@ def main():
     if args.lite:
         configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
 
-        if args.device == "cpu":
+        if args.device == "cpu" or not torch.cuda.is_available():
             configs["ocr"]["text_detector"]["infer_onnx"] = True
 
         # Note: Text Detector以外はONNX推論よりもPyTorch推論の方が速いため、ONNX推論は行わない
diff --git a/src/yomitoku/document_analyzer.py b/src/yomitoku/document_analyzer.py
index 2a24d79..f39a319 100644
--- a/src/yomitoku/document_analyzer.py
+++ b/src/yomitoku/document_analyzer.py
@@ -127,7 +127,6 @@ def extract_words_within_element(pred_words, element):
     if len(contained_words) == 0:
         return None, None, check_list
 
-    element_direction = "horizontal"
     word_direction = [word.direction for word in contained_words]
     cnt_horizontal = word_direction.count("horizontal")
     cnt_vertical = word_direction.count("vertical")
diff --git a/src/yomitoku/export/export_csv.py b/src/yomitoku/export/export_csv.py
index 48cd7dd..2247db6 100644
--- a/src/yomitoku/export/export_csv.py
+++ b/src/yomitoku/export/export_csv.py
@@ -41,6 +41,8 @@ def save_figure(
     out_path,
     figure_dir="figures",
 ):
+    assert img is not None, "img is required for saving figures"
+
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
         figure_img = img[y1:y2, x1:x2, :]
diff --git a/src/yomitoku/export/export_html.py b/src/yomitoku/export/export_html.py
index 3670f96..180b975 100644
--- a/src/yomitoku/export/export_html.py
+++ b/src/yomitoku/export/export_html.py
@@ -110,6 +110,8 @@ def figure_to_html(
     figure_dir="figures",
     width=200,
 ):
+    assert img is not None, "img is required for saving figures"
+
     elements = []
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
diff --git a/src/yomitoku/export/export_json.py b/src/yomitoku/export/export_json.py
index 819ea05..3b41c2a 100644
--- a/src/yomitoku/export/export_json.py
+++ b/src/yomitoku/export/export_json.py
@@ -21,6 +21,8 @@ def save_figure(
     out_path,
     figure_dir="figures",
 ):
+    assert img is not None, "img is required for saving figures"
+
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
         figure_img = img[y1:y2, x1:x2, :]
diff --git a/src/yomitoku/export/export_markdown.py b/src/yomitoku/export/export_markdown.py
index da9c54e..ebf5811 100644
--- a/src/yomitoku/export/export_markdown.py
+++ b/src/yomitoku/export/export_markdown.py
@@ -75,6 +75,8 @@ def figure_to_md(
     width=200,
     figure_dir="figures",
 ):
+    assert img is not None, "img is required for saving figures"
+
     elements = []
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9be86b9..2a49795 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -210,6 +210,7 @@ def test_validate_encoding():
         validate_encoding("utf-9")
 
     assert validate_encoding("utf-8")
+    assert validate_encoding("utf-8-sig")
     assert validate_encoding("shift-jis")
     assert validate_encoding("euc-jp")
     assert validate_encoding("cp932")
diff --git a/tests/test_document_analyzer.py b/tests/test_document_analyzer.py
index 2365c3b..dbb4b5b 100644
--- a/tests/test_document_analyzer.py
+++ b/tests/test_document_analyzer.py
@@ -3,6 +3,36 @@
 from omegaconf import OmegaConf
 
 from yomitoku import DocumentAnalyzer
+from yomitoku.document_analyzer import (
+    ParagraphSchema,
+    FigureSchema,
+    DocumentAnalyzerSchema,
+    extract_paragraph_within_figure,
+    combine_flags,
+    judge_page_direction,
+    extract_words_within_element,
+    is_vertical,
+    is_noise,
+    recursive_update,
+    _extract_words_within_table,
+    _calc_overlap_words_on_lines,
+    _correct_vertical_word_boxes,
+    _correct_horizontal_word_boxes,
+    _split_text_across_cells,
+)
+
+
+from yomitoku.text_detector import TextDetectorSchema
+
+from yomitoku.table_structure_recognizer import (
+    TableStructureRecognizerSchema,
+    TableLineSchema,
+    TableCellSchema,
+)
+
+from yomitoku.ocr import (
+    WordPrediction,
+)
 
 
 def test_initialize():
@@ -86,3 +86,481 @@ def test_invalid_config():
         DocumentAnalyzer(
             configs="invalid",
         )
+
+
+def test_extract_paragraph_within_figure():
+    paragraphs = [
+        {
+            "box": [0, 0, 2, 1],
+            "contents": "This is a test.",
+            "direction": "horizontal",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [0, 0, 1, 2],
+            "contents": "This is a test.",
+            "direction": "vertical",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [10, 10, 1, 2],
+            "contents": "This is a test.",
+            "direction": "horizontal",
+            "order": 1,
+            "role": None,
+        },
+    ]
+
+    figures = [
+        {
+            "box": [0, 0, 2, 2],
+            "order": 1,
+            "paragraphs": [],
+            "direction": None,
+        }
+    ]
+
+    paragraphs = [ParagraphSchema(**paragraph) for paragraph in paragraphs]
+    figures = [FigureSchema(**figure) for figure in figures]
+
+    figures, checklist = extract_paragraph_within_figure(paragraphs, figures)
+
+    assert checklist == [True, True, False]
+    assert len(figures[0].paragraphs) == 2
+
+
+def test_combine_flags():
+    flags1 = [True, False, True]
+    flags2 = [False, False, True]
+
+    assert combine_flags(flags1, flags2) == [True, False, True]
+
+
+def test_judge_page_direction():
+    paragraphs = [
+        {
+            "box": [0, 0, 2, 1],
+            "contents": "This is a test.",
+            "direction": "horizontal",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [0, 0, 1, 2],
+            "contents": "This is a test.",
+            "direction": "vertical",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [10, 10, 1, 2],
+            "contents": "This is a test.",
+            "direction": "horizontal",
+            "order": 1,
+            "role": None,
+        },
+    ]
+
+    paragraphs = [ParagraphSchema(**paragraph) for paragraph in paragraphs]
+    assert judge_page_direction(paragraphs) == "horizontal"
+
+    paragraphs = [
+        {
+            "box": [0, 0, 2, 1],
+            "contents": "This is a test.",
+            "direction": "horizontal",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [0, 0, 1, 2],
+            "contents": "This is a test.",
+            "direction": "vertical",
+            "order": 1,
+            "role": None,
+        },
+        {
+            "box": [10, 10, 2, 1],
+            "contents": "This is a test.",
+            "direction": "vertical",
+            "order": 1,
+            "role": None,
+        },
+    ]
+
+    paragraphs = [ParagraphSchema(**paragraph) for paragraph in paragraphs]
+    assert judge_page_direction(paragraphs) == "vertical"
+
+
+def test_extract_words_within_element():
+    paragraph = {
+        "box": [0, 0, 1, 1],
+        "contents": "This is a test.",
+        "direction": "horizontal",
+        "order": 1,
+        "role": None,
+    }
+
+    element = ParagraphSchema(**paragraph)
+
+    words = [
+        {
+            "points": [[10, 10], [11, 10], [11, 11], [10, 11]],
+            "content": "This",
+            "direction": "horizontal",
+            "rec_score": 0.9,
+            "det_score": 0.9,
+        }
+    ]
+
+    words = [WordPrediction(**word) for word in words]
+
+    words, direction, checklist = extract_words_within_element(words, element)
+
+    assert words is None
+    assert direction is None
+    assert checklist == [False]
+
+    paragraph = {
+        "box": [0, 0, 5, 5],
+        "contents": "This is a test.",
+        "direction": "horizontal",
+        "order": 1,
+        "role": None,
+    }
+
+    element = ParagraphSchema(**paragraph)
+
+    words = [
+        {
+            "points": [[0, 0], [1, 0], [1, 1], [0, 1]],
+            "content": "Hello",
+            "direction": "horizontal",
+            "rec_score": 0.9,
+            "det_score": 0.9,
+        },
+        {
+            "points": [[0, 1], [1, 1], [1, 2], [0, 2]],
+            "content": "World",
+            "direction": "horizontal",
+            "rec_score": 0.9,
+            "det_score": 0.9,
+        },
+    ]
+
+    words = [WordPrediction(**word) for word in words]
+
+    words, direction, checklist = extract_words_within_element(words, element)
+
+    assert words == "Hello\nWorld"
+    assert direction == "horizontal"
+    assert checklist == [True, True]
+
+    paragraph = {
+        "box": [0, 0, 5, 5],
+        "contents": "This is a test.",
+        "direction": "horizontal",
+        "order": 1,
+        "role": None,
+    }
+
+    element = ParagraphSchema(**paragraph)
+
+    words = [
+        {
+            "points": [[2, 0], [3, 0], [3, 1], [2, 1]],
+            "content": "Hello",
+            "direction": "vertical",
+            "rec_score": 0.9,
+            "det_score": 0.9,
+        },
+        {
+            "points": [[0, 1], [1, 1], [1, 2], [0, 2]],
+            "content": "World",
+            "direction": "vertical",
+            "rec_score": 0.9,
+            "det_score": 0.9,
+        },
+    ]
+
+    words = [WordPrediction(**word) for word in words]
+
+    words, direction, checklist = extract_words_within_element(words, element)
+
+    assert words == "Hello\nWorld"
+    assert direction == "vertical"
+    assert checklist == [True, True]
+
+
+def test_is_vertical():
+    quad = [[0, 0], [1, 0], [1, 1], [0, 1]]
+    assert not is_vertical(quad)
+    quad = [[0, 0], [1, 0], [1, 3], [0, 3]]
+    assert is_vertical(quad)
+
+
+def test_is_noise():
+    quad = [[0, 0], [1, 0], [1, 1], [0, 1]]
+    assert is_noise(quad)
+
+    quad = [[0, 0], [20, 0], [20, 20], [0, 20]]
+    assert not is_noise(quad)
+
+
+def test_recursive_update():
+    original = {"a": {"b": {"c": 1, "d": 2}}}
+    update = {"a": {"b": {"d": 3, "e": 4}}}
+
+    updated = recursive_update(original, update)
+
+    assert updated == {"a": {"b": {"c": 1, "d": 3, "e": 4}}}
+
+
+def test_extract_words_within_table():
+    points = [
+        [[0, 0], [3, 0], [3, 1], [0, 1]],
+        [[3, 0], [5, 0], [5, 1], [3, 1]],
+        [[0, 1], [1, 1], [1, 4], [0, 4]],
+        [[3, 1], [3, 1], [4, 4], [4, 4]],
+    ]
+
+    scores = [0.9, 0.9, 0.9, 0.9]
+
+    words = TextDetectorSchema(points=points, scores=scores)
+
+    table = {
+        "box": [0, 0, 3, 3],
+        "n_row": 2,
+        "n_col": 2,
+        "rows": [],
+        "cols": [],
+        "cells": [],
+        "order": 0,
+    }
+
+    table = TableStructureRecognizerSchema(**table)
+    checklist = [False, False, False, False]
+    h_words, v_words, checklist = _extract_words_within_table(words, table, checklist)
+
+    assert len(h_words) == 1
+    assert len(v_words) == 1
+    assert checklist == [True, False, True, False]
+
+
+def test_calc_overlap_words_on_lines():
+    lines = [
+        {
+            "box": [0, 0, 2, 1],
+            "score": 0.9,
+        },
+        {
+            "box": [0, 1, 1, 1],
+            "score": 0.9,
+        },
+    ]
+
+    lines = [TableLineSchema(**line) for line in lines]
+
+    words = [
+        {
+            "points": [[0, 0], [1, 0], [1, 1], [0, 1]],
+        },
+        {
+            "points": [[1, 0], [3, 0], [3, 1], [1, 1]],
+        },
+    ]
+
+    overlap_ratios = _calc_overlap_words_on_lines(lines, words)
+
+    assert overlap_ratios == [[1.0, 0.0], [0.5, 0.0]]
+
+
+def test_correct_vertical_word_boxes():
+    words = [
+        {
+            "points": [[0, 0], [20, 0], [20, 100], [0, 100]],
+            "score": 0.9,
+        },
+    ]
+
+    cols = [TableLineSchema(box=[0, 0, 20, 100], score=0.9)]
+    rows = [
+        TableLineSchema(box=[0, 0, 20, 50], score=0.9),
+        TableLineSchema(box=[0, 50, 20, 100], score=0.9),
+    ]
+
+    cells = [
+        {
+            "col": 1,
+            "row": 1,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [0, 0, 20, 50],
+            "contents": None,
+        },
+        {
+            "col": 1,
+            "row": 2,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [0, 50, 20, 100],
+            "contents": None,
+        },
+    ]
+
+    cells = [TableCellSchema(**cell) for cell in cells]
+
+    table = {
+        "box": [0, 0, 100, 20],
+        "n_row": 2,
+        "n_col": 1,
+        "rows": rows,
+        "cols": cols,
+        "cells": cells,
+        "order": 0,
+    }
+
+    table = TableStructureRecognizerSchema(**table)
+
+    overlap_ratios = _calc_overlap_words_on_lines(cols, words)
+
+    points, scores = _correct_vertical_word_boxes(
+        overlap_ratios,
+        table,
+        words,
+    )
+
+    assert len(points) == 2
+    assert len(scores) == 2
+    assert points[0] == [[0, 0], [20, 0], [20, 50], [0, 50]]
+    assert points[1] == [[0, 50], [20, 50], [20, 100], [0, 100]]
+
+
+def test_correct_horizontal_word_boxes():
+    words = [
+        {
+            "points": [[0, 0], [100, 0], [100, 20], [0, 20]],
+            "score": 0.9,
+        },
+    ]
+
+    cols = [
+        TableLineSchema(box=[0, 0, 50, 20], score=0.9),
+        TableLineSchema(box=[50, 0, 100, 20], score=0.9),
+    ]
+    rows = [
+        TableLineSchema(box=[0, 0, 100, 20], score=0.9),
+    ]
+
+    cells = [
+        {
+            "col": 1,
+            "row": 1,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [0, 0, 50, 20],
+            "contents": None,
+        },
+        {
+            "col": 2,
+            "row": 1,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [50, 0, 100, 20],
+            "contents": None,
+        },
+    ]
+
+    cells = [TableCellSchema(**cell) for cell in cells]
+
+    table = {
+        "box": [0, 0, 20, 100],
+        "n_row": 2,
+        "n_col": 1,
+        "rows": rows,
+        "cols": cols,
+        "cells": cells,
+        "order": 0,
+    }
+
+    table = TableStructureRecognizerSchema(**table)
+
+    overlap_ratios = _calc_overlap_words_on_lines(cols, words)
+
+    points, scores = _correct_horizontal_word_boxes(
+        overlap_ratios,
+        table,
+        words,
+    )
+
+    assert len(points) == 2
+    assert len(scores) == 2
+    assert points[0] == [[0, 0], [50, 0], [50, 20], [0, 20]]
+    assert points[1] == [[50, 0], [100, 0], [100, 20], [50, 20]]
+
+
+def test_split_text_across_cells():
+    points = [
+        [[0, 0], [100, 0], [100, 20], [0, 20]],
+    ]
+
+    scores = [0.9]
+
+    words = TextDetectorSchema(points=points, scores=scores)
+
+    cols = [
+        TableLineSchema(box=[0, 0, 50, 20], score=0.9),
+        TableLineSchema(box=[50, 0, 100, 20], score=0.9),
+    ]
+    rows = [
+        TableLineSchema(box=[0, 0, 100, 20], score=0.9),
+    ]
+
+    cells = [
+        {
+            "col": 1,
+            "row": 1,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [0, 0, 50, 20],
+            "contents": None,
+        },
+        {
+            "col": 2,
+            "row": 1,
+            "col_span": 1,
+            "row_span": 1,
+            "box": [50, 0, 100, 20],
+            "contents": None,
+        },
+    ]
+
+    cells = [TableCellSchema(**cell) for cell in cells]
+
+    table = {
+        "box": [0, 0, 100, 20],
+        "n_row": 2,
+        "n_col": 1,
+        "rows": rows,
+        "cols": cols,
+        "cells": cells,
+        "order": 0,
+    }
+
+    table = TableStructureRecognizerSchema(**table)
+
+    Layout = DocumentAnalyzerSchema(
+        paragraphs=[],
+        figures=[],
+        tables=[table],
+        words=[],
+    )
+
+    results = _split_text_across_cells(words, Layout)
+
+    assert len(results.points) == 2
+    assert len(results.scores) == 2
+    assert results.points[0] == [[0, 0], [50, 0], [50, 20], [0, 20]]
+    assert results.points[1] == [[50, 0], [100, 0], [100, 20], [50, 20]]