Skip to content

Commit

Permalink
Merge pull request #31 from kotaro-kinoshita/feature/reading-ordar-al…
Browse files Browse the repository at this point in the history
…gorithm

Feature/reading ordar algorithm
  • Loading branch information
kotaro-kinoshita authored Nov 19, 2024
2 parents 07835f1 + 0678df4 commit 8502d31
Show file tree
Hide file tree
Showing 12 changed files with 406 additions and 104 deletions.
14 changes: 3 additions & 11 deletions src/yomitoku/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,7 @@ def process_single_file(args, analyzer, path, format):
cv2.imwrite(out_path, layout)
logger.info(f"Output file: {out_path}")

# cv2.imwrite(
# os.path.join(args.outdir, f"{dirname}_{filename}_p{page+1}.jpg"),
# img,
# )

out_path = os.path.join(
args.outdir, f"{dirname}_{filename}_p{page+1}.{format}"
)
out_path = os.path.join(args.outdir, f"{dirname}_{filename}_p{page+1}.{format}")

if format == "json":
results.to_json(
Expand All @@ -63,6 +56,7 @@ def process_single_file(args, analyzer, path, format):
results.to_html(
out_path,
ignore_line_break=args.ignore_line_break,
img=img,
)
elif format == "md":
results.to_markdown(
Expand All @@ -75,9 +69,7 @@ def process_single_file(args, analyzer, path, format):

def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"arg1", type=str, help="path of target image file or directory"
)
parser.add_argument("arg1", type=str, help="path of target image file or directory")
parser.add_argument(
"-f",
"--format",
Expand Down
54 changes: 43 additions & 11 deletions src/yomitoku/document_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
from .ocr import OCR, WordPrediction
from .table_structure_recognizer import TableStructureRecognizerSchema
from .utils.misc import is_contained, quad_to_xyxy
from .reading_order import prediction_reading_order

from .utils.visualizer import reading_order_visualizer


class ParagraphSchema(BaseSchema):
box: conlist(int, min_length=4, max_length=4)
contents: Union[str, None]
direction: Union[str, None]
order: Union[int, None]


class DocumentAnalyzerSchema(BaseSchema):
Expand All @@ -39,14 +43,34 @@ def combine_flags(flag1, flag2):
return [f1 or f2 for f1, f2 in zip(flag1, flag2)]


def judge_page_direction(paragraphs):
h_sum_area = 0
v_sum_area = 0

for paragraph in paragraphs:
x1, y1, x2, y2 = paragraph.box
w = x2 - x1
h = y2 - y1

if paragraph.direction == "horizontal":
h_sum_area += w * h
else:
v_sum_area += w * h

if h_sum_area > v_sum_area:
return "horizontal"

return "vertical"


def extract_words_within_element(pred_words, element):
contained_words = []
word_sum_width = 0
word_sum_height = 0
check_list = [False] * len(pred_words)
for i, word in enumerate(pred_words):
word_box = quad_to_xyxy(word.points)
if is_contained(element.box, word_box, threshold=0.6):
if is_contained(element.box, word_box, threshold=0.5):
contained_words.append(word)
word_sum_width += word_box[2] - word_box[0]
word_sum_height += word_box[3] - word_box[1]
Expand All @@ -62,9 +86,7 @@ def extract_words_within_element(pred_words, element):
cnt_horizontal = word_direction.count("horizontal")
cnt_vertical = word_direction.count("vertical")

element_direction = (
"horizontal" if cnt_horizontal > cnt_vertical else "vertical"
)
element_direction = "horizontal" if cnt_horizontal > cnt_vertical else "vertical"
if element_direction == "horizontal":
contained_words = sorted(
contained_words,
Expand All @@ -83,9 +105,7 @@ def extract_words_within_element(pred_words, element):
reverse=True,
)

contained_words = "\n".join(
[content.content for content in contained_words]
)
contained_words = "\n".join([content.content for content in contained_words])
return (contained_words, element_direction, check_list)


Expand Down Expand Up @@ -137,9 +157,8 @@ def __init__(self, configs=None, device="cuda", visualize=False):
)

self.ocr = OCR(configs=default_configs["ocr"])
self.layout = LayoutAnalyzer(
configs=default_configs["layout_analyzer"]
)
self.layout = LayoutAnalyzer(configs=default_configs["layout_analyzer"])
self.visualize = visualize

def aggregate(self, ocr_res, layout_res):
paragraphs = []
Expand Down Expand Up @@ -168,7 +187,9 @@ def aggregate(self, ocr_res, layout_res):
"contents": words,
"box": paragraph.box,
"direction": direction,
"order": 0,
}

check_list = combine_flags(check_list, flags)
paragraph = ParagraphSchema(**paragraph)
paragraphs.append(paragraph)
Expand All @@ -180,11 +201,16 @@ def aggregate(self, ocr_res, layout_res):
"contents": word.content,
"box": quad_to_xyxy(word.points),
"direction": direction,
"order": 0,
}

paragraph = ParagraphSchema(**paragraph)
paragraphs.append(paragraph)

page_direction = judge_page_direction(paragraphs)
elements = paragraphs + layout_res.tables
prediction_reading_order(elements, page_direction)

outputs = {
"paragraphs": paragraphs,
"tables": layout_res.tables,
Expand Down Expand Up @@ -212,4 +238,10 @@ async def run(self, img):
return results, ocr, layout

def __call__(self, img):
return asyncio.run(self.run(img))
self.img = img
resutls, ocr, layout = asyncio.run(self.run(img))

if self.visualize:
layout = reading_order_visualizer(layout, resutls)

return resutls, ocr, layout
7 changes: 3 additions & 4 deletions src/yomitoku/export/export_csv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import csv

from .utils import sort_elements


def table_to_csv(table, ignore_line_break):
num_rows = table.n_row
Expand Down Expand Up @@ -45,6 +43,7 @@ def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
"type": "table",
"box": table.box,
"element": table_csv,
"order": table.order,
}
)

Expand All @@ -55,11 +54,11 @@ def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
"type": "paragraph",
"box": paraghraph.box,
"element": contents,
"order": paraghraph.order,
}
)

directions = [paraghraph.direction for paraghraph in inputs.paragraphs]
elements = sort_elements(elements, directions)
elements = sorted(elements, key=lambda x: x["order"])

with open(out_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
Expand Down
16 changes: 7 additions & 9 deletions src/yomitoku/export/export_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from lxml import etree, html

from .utils import sort_elements


def convert_text_to_html(text):
"""
Expand Down Expand Up @@ -72,6 +70,7 @@ def table_to_html(table, ignore_line_break):

return {
"box": table.box,
"order": table.order,
"html": table_html,
}

Expand All @@ -87,6 +86,7 @@ def paragraph_to_html(paragraph, ignore_line_break):

return {
"box": paragraph.box,
"order": paragraph.order,
"html": add_p_tag(contents),
}

Expand All @@ -95,25 +95,23 @@ def export_html(
inputs,
out_path: str,
ignore_line_break: bool = False,
img=None,
):
html_string = ""
elements = []
for table in inputs.tables:
elements.append(table_to_html(table, ignore_line_break))

for paraghraph in inputs.paragraphs:
elements.append(paragraph_to_html(paraghraph, ignore_line_break))
for paragraph in inputs.paragraphs:
elements.append(paragraph_to_html(paragraph, ignore_line_break))

directions = [paraghraph.direction for paraghraph in inputs.paragraphs]
elements = sort_elements(elements, directions)
elements = sorted(elements, key=lambda x: x["order"])

html_string = "".join([element["html"] for element in elements])
html_string = add_html_tag(html_string)

parsed_html = html.fromstring(html_string)
formatted_html = etree.tostring(
parsed_html, pretty_print=True, encoding="unicode"
)
formatted_html = etree.tostring(parsed_html, pretty_print=True, encoding="unicode")

with open(out_path, "w", encoding="utf-8") as f:
f.write(formatted_html)
16 changes: 7 additions & 9 deletions src/yomitoku/export/export_markdown.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import re

from .utils import sort_elements


def escape_markdown_special_chars(text):
special_chars = r"([`*_{}[\]()#+.!|-])"
Expand All @@ -17,6 +15,7 @@ def paragraph_to_md(paragraph, ignore_line_break):
contents = contents.replace("\n", "<br>")

return {
"order": paragraph.order,
"box": paragraph.box,
"md": contents + "\n",
}
Expand Down Expand Up @@ -56,6 +55,7 @@ def table_to_md(table, ignore_line_break):
table_md += f"|{header}|\n"

return {
"order": table.order,
"box": table.box,
"md": table_md,
}
Expand All @@ -66,13 +66,11 @@ def export_markdown(inputs, out_path: str, ignore_line_break: bool = False):
for table in inputs.tables:
elements.append(table_to_md(table, ignore_line_break))

for paraghraph in inputs.paragraphs:
elements.append(paragraph_to_md(paraghraph, ignore_line_break))

directions = [paraghraph.direction for paraghraph in inputs.paragraphs]
sort_elements(elements, directions)
for paragraph in inputs.paragraphs:
elements.append(paragraph_to_md(paragraph, ignore_line_break))

markdonw = "\n".join([element["md"] for element in elements])
elements = sorted(elements, key=lambda x: x["order"])
markdown = "\n".join([element["md"] for element in elements])

with open(out_path, "w", encoding="utf-8") as f:
f.write(markdonw)
f.write(markdown)
4 changes: 0 additions & 4 deletions src/yomitoku/export/utils.py

This file was deleted.

Loading

0 comments on commit 8502d31

Please sign in to comment.