Merge pull request #58 from kotaro-kinoshita/feature/convert-onnx
Feature/onnx inference
kotaro-kinoshita authored Dec 15, 2024
2 parents 513a7b3 + 90e9b4d commit 03f482f
Showing 16 changed files with 395 additions and 270 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint-and-test.yml
@@ -37,4 +37,4 @@ jobs:
- name: Run linter
run: tox -e lint
- name: Run tests
run: tox -p -e py39,py310,py311,py312
run: tox -p -e py310,py311,py312
4 changes: 3 additions & 1 deletion .gitignore
@@ -20,4 +20,6 @@ dataset/
weights/
results/

.coverage*
.coverage*

*.onnx
2 changes: 1 addition & 1 deletion .python-version
@@ -1 +1 @@
3.9
3.10
4 changes: 2 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

<img src="static/logo/horizontal.png" width="800px">

![Python](https://img.shields.io/badge/Python-3.9|3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
![Python](https://img.shields.io/badge/Python-3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
![Pytorch](https://img.shields.io/badge/Pytorch-2.5-EE4C2C.svg?logo=Pytorch&style=fla)
![CUDA](https://img.shields.io/badge/CUDA->=11.8-76B900.svg?logo=NVIDIA&style=fla)
![OS](https://img.shields.io/badge/OS-Linux|Mac|Win-1793D1.svg?&style=fla)
@@ -61,7 +61,7 @@ yomitoku ${path_data} -f md -o results -v --figure --lite
- `-f`, `--format` Specifies the output file format. (json, csv, html, and md are supported.)
- `-o`, `--outdir` Specifies the output directory. It is created if it does not exist.
- `-v`, `--vis` Outputs images visualizing the analysis results.
- `-l`, `--lite` Runs inference with lightweight models. Fast inference is possible even on a CPU.
- `-l`, `--lite` Runs inference with lightweight models. Inference is faster than usual, but accuracy may drop slightly.
- `-d`, `--device` Specifies the device on which to run the models. If a GPU is unavailable, inference runs on the CPU. (Default: cuda)
- `--ignore_line_break` Ignores the line-break positions in the image and returns the text within each paragraph as one concatenated string. (Default: line breaks follow the positions in the image.)
- `--figure_letter` Also exports the text contained in detected figures and tables to the output file.
2 changes: 1 addition & 1 deletion README_EN.md
@@ -2,7 +2,7 @@

<img src="static/logo/horizontal.png" width="800px">

![Python](https://img.shields.io/badge/Python-3.9|3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
![Python](https://img.shields.io/badge/Python-3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
![Pytorch](https://img.shields.io/badge/Pytorch-2.5-EE4C2C.svg?logo=Pytorch&style=fla)
![CUDA](https://img.shields.io/badge/CUDA->=11.8-76B900.svg?logo=NVIDIA&style=fla)
![OS](https://img.shields.io/badge/OS-Linux|Mac|Win-1793D1.svg?&style=fla)
2 changes: 1 addition & 1 deletion docs/installation.en.md
@@ -1,7 +1,7 @@
# Installation


This package requires Python 3.9 or later and PyTorch 2.5 or later for execution. PyTorch must be installed according to your CUDA version. A GPU with more than 8GB of VRAM is recommended. While it can run on a CPU, please note that the processing is not currently optimized for CPUs, which may result in longer execution times.
This package requires Python 3.10 or later and PyTorch 2.5 or later for execution. PyTorch must be installed according to your CUDA version. A GPU with more than 8GB of VRAM is recommended. While it can run on a CPU, please note that the processing is not currently optimized for CPUs, which may result in longer execution times.

## from PYPI

2 changes: 1 addition & 1 deletion docs/installation.ja.md
@@ -1,6 +1,6 @@
# Installation

This package requires Python 3.9+ and PyTorch to run. PyTorch must be installed to match your environment. A GPU (> 8 GB VRAM) is recommended as the compute device. It also runs on a CPU, but note that processing is not currently optimized for CPUs and execution takes time.
This package requires Python 3.10+ and PyTorch to run. PyTorch must be installed to match your environment. A GPU (> 8 GB VRAM) is recommended as the compute device. It also runs on a CPU, but note that processing is not currently optimized for CPUs and execution takes time.

## PYPI からインストール

7 changes: 5 additions & 2 deletions pyproject.toml
@@ -12,7 +12,7 @@ authors = [{name = "Kotaro Kinoshita", email = "[email protected]"}]
description = "Yomitoku is an AI-powered document image analysis package designed specifically for the Japanese language."
readme = "README.md"
license = {text = "CC BY-NC-SA 4.0"}
requires-python = ">=3.9,<3.13"
requires-python = ">=3.10,<3.13"
keywords = ["Japanese", "OCR", "Deep Learning"]
dependencies = [
"huggingface-hub>=0.26.1",
@@ -26,6 +26,9 @@ dependencies = [
"torchvision>=0.20.0",
"torch>=2.5.0",
"pypdfium2>=4.30.0",
"onnx>=1.17.0",
"onnxruntime>=1.20.1",
"onnxruntime-gpu>=1.20.1",
]

[tool.uv-dynamic-versioning]
@@ -71,7 +74,7 @@ yomitoku = "yomitoku.cli.main:main"
[tool.tox]
legacy_tox_ini = """
[tox]
envlist = lint, py39, py310, py311, py312, docs
envlist = lint, py310, py311, py312, docs
[testenv]
deps = pytest
6 changes: 6 additions & 0 deletions src/yomitoku/cli/main.py
@@ -205,6 +205,12 @@ def main():

if args.lite:
configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
configs["ocr"]["text_detector"]["infer_onnx"] = True

# Note: for modules other than the Text Detector, PyTorch inference is faster than ONNX inference, so ONNX inference is not used for them
# configs["ocr"]["text_recognizer"]["infer_onnx"] = True
# configs["layout_analyzer"]["table_structure_recognizer"]["infer_onnx"] = True
# configs["layout_analyzer"]["layout_parser"]["infer_onnx"] = True

analyzer = DocumentAnalyzer(
configs=configs,
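The `--lite` branch above only swaps the text recognizer for `parseq-small` and enables ONNX inference for the text detector; the commented-out lines record that the other modules stay on PyTorch because it is faster for them. A minimal sketch of assembling the same configuration programmatically, assuming `DocumentAnalyzer` is importable from the package root and accepts a `device` keyword (both inferred from the call visible in this file, not confirmed elsewhere in the diff):

```python
# Minimal sketch of the configuration the --lite flag assembles above.
# The import path and the device keyword are assumptions based on the
# DocumentAnalyzer(...) call shown in src/yomitoku/cli/main.py.
from yomitoku import DocumentAnalyzer

configs = {
    "ocr": {
        "text_recognizer": {"model_name": "parseq-small"},  # lightweight recognizer
        "text_detector": {"infer_onnx": True},  # ONNX only where it beats PyTorch
    }
}

analyzer = DocumentAnalyzer(configs=configs, device="cpu")
```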
56 changes: 53 additions & 3 deletions src/yomitoku/layout_parser.py
@@ -1,11 +1,16 @@
from typing import List, Union

import cv2
import os
import onnx
import onnxruntime
import torch
import torchvision.transforms as T
from PIL import Image
from pydantic import conlist

from .constants import ROOT_DIR

from .base import BaseModelCatalog, BaseModule, BaseSchema
from .configs import LayoutParserRTDETRv2Config
from .models import RTDETRv2
@@ -91,6 +96,7 @@ def __init__(
device="cuda",
visualize=False,
from_pretrained=True,
infer_onnx=False,
):
super().__init__()
self.load_model(model_name, path_cfg, from_pretrained)
@@ -119,11 +125,44 @@ def __init__(
}

self.role = self._cfg.role
self.infer_onnx = infer_onnx
if infer_onnx:
name = self._cfg.hf_hub_repo.split("/")[-1]
path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
if not os.path.exists(path_onnx):
self.convert_onnx(path_onnx)

model = onnx.load(path_onnx)
if torch.cuda.is_available() and device == "cuda":
self.sess = onnxruntime.InferenceSession(
model.SerializeToString(), providers=["CUDAExecutionProvider"]
)
else:
self.sess = onnxruntime.InferenceSession(model.SerializeToString())

def convert_onnx(self, path_onnx):
dynamic_axes = {
"input": {0: "batch_size"},
"output": {0: "batch_size"},
}

img_size = self._cfg.data.img_size
dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)

torch.onnx.export(
self.model,
dummy_input,
path_onnx,
opset_version=16,
input_names=["input"],
output_names=["pred_logits", "pred_boxes"],
dynamic_axes=dynamic_axes,
)

def preprocess(self, img):
cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(cv_img)
img_tensor = self.transforms(img)[None].to(self.device)
img_tensor = self.transforms(img)[None]
return img_tensor

def postprocess(self, preds, image_size):
@@ -175,8 +214,19 @@ def __call__(self, img):
ori_h, ori_w = img.shape[:2]
img_tensor = self.preprocess(img)

with torch.inference_mode():
preds = self.model(img_tensor)
if self.infer_onnx:
input = img_tensor.numpy()
results = self.sess.run(None, {"input": input})
preds = {
"pred_logits": torch.tensor(results[0]).to(self.device),
"pred_boxes": torch.tensor(results[1]).to(self.device),
}

else:
with torch.inference_mode():
img_tensor = img_tensor.to(self.device)
preds = self.model(img_tensor)

results = self.postprocess(preds, (ori_h, ori_w))

vis = None
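`LayoutParser` here and `TableStructureRecognizer` below add the same machinery: on first use the RT-DETRv2 model is exported to `{ROOT_DIR}/onnx/{name}.onnx`, then an onnxruntime session is built with the CUDA provider when a GPU is requested and available. A distilled, self-contained sketch of that pattern, with a placeholder model standing in for the repository's actual checkpoints:

```python
# Distilled sketch of the export/session pattern this PR adds; `model`
# and the input size are placeholders, not the real RT-DETRv2 setup.
import onnxruntime
import torch


def export_onnx(model: torch.nn.Module, path_onnx: str, img_size=(640, 640)) -> None:
    dummy_input = torch.randn(1, 3, *img_size)
    torch.onnx.export(
        model,
        dummy_input,
        path_onnx,
        opset_version=16,
        input_names=["input"],
        output_names=["pred_logits", "pred_boxes"],
        dynamic_axes={"input": {0: "batch_size"}},  # keep the batch dim dynamic
    )


def make_session(path_onnx: str, device: str = "cuda") -> onnxruntime.InferenceSession:
    # Ask onnxruntime which providers this build actually supports.
    providers = ["CPUExecutionProvider"]
    if device == "cuda" and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
        providers = ["CUDAExecutionProvider"]
    return onnxruntime.InferenceSession(path_onnx, providers=providers)
```

The PR guards the CUDA provider with `torch.cuda.is_available()` instead; the sketch's `get_available_providers()` check is an alternative that also catches a CPU-only onnxruntime build.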
18 changes: 9 additions & 9 deletions src/yomitoku/models/parseq.py
@@ -22,7 +22,6 @@
from timm.models.helpers import named_apply
from torch import Tensor

from ..postprocessor import ParseqTokenizer as Tokenizer
from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding


@@ -123,7 +122,6 @@ def decode(

def forward(
self,
tokenizer: Tokenizer,
images: Tensor,
max_length: Optional[int] = None,
) -> Tensor:
@@ -150,11 +148,11 @@ def forward(
if self.decode_ar:
tgt_in = torch.full(
(bs, num_steps),
tokenizer.pad_id,
self.tokenizer.pad_id,
dtype=torch.long,
device=self._device,
)
tgt_in[:, 0] = tokenizer.bos_id
tgt_in[:, 0] = self.tokenizer.bos_id

logits = []
for i in range(num_steps):
@@ -177,15 +175,15 @@
# greedy decode. add the next token index to the target input
tgt_in[:, j] = p_i.squeeze().argmax(-1)
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
if testing and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all():
break

logits = torch.cat(logits, dim=1)
else:
# No prior context, so input is just <bos>. We query all positions.
tgt_in = torch.full(
(bs, 1),
tokenizer.bos_id,
self.tokenizer.bos_id,
dtype=torch.long,
device=self._device,
)
@@ -200,23 +198,25 @@ def forward(
torch.ones(
num_steps,
num_steps,
dtype=torch.bool,
dtype=torch.int64,
device=self._device,
),
2,
)
] = 0
bos = torch.full(
(bs, 1),
tokenizer.bos_id,
self.tokenizer.bos_id,
dtype=torch.long,
device=self._device,
)
for i in range(self.refine_iters):
# Prior context is the previous output.
tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
# Mask tokens beyond the first EOS token.
tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
-1
) > 0
tgt_out = self.decode(
tgt_in,
memory,
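The PARSeq changes serve the same export path: `forward` drops its `tokenizer` argument in favor of a `self.tokenizer` attribute, leaving a tensor-only signature that ONNX tracing can handle, and the query-mask dtype moves from `bool` to `int64`, presumably also for exporter compatibility. A toy illustration of the signature constraint, not the real recognizer:

```python
# Toy illustration: ONNX export traces forward with tensor arguments
# only, so non-tensor state (tokenizer, pad/bos ids) must live on the
# module instead of being passed per call. All names are hypothetical.
import torch


class TinyRecognizer(torch.nn.Module):
    def __init__(self, bos_id: int = 1):
        super().__init__()
        self.bos_id = bos_id  # previously supplied through a `tokenizer` argument
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.proj(images)  # tensor in, tensor out: traceable


torch.onnx.export(
    TinyRecognizer(),
    torch.randn(1, 8),
    "tiny_recognizer.onnx",
    opset_version=16,
    input_names=["input"],
    output_names=["output"],
)
```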
Empty file added src/yomitoku/onnx/.gitkeep
Empty file.
59 changes: 55 additions & 4 deletions src/yomitoku/table_structure_recognizer.py
@@ -1,11 +1,16 @@
from typing import List, Union

import cv2
import os
import onnx
import onnxruntime
import torch
import torchvision.transforms as T
from PIL import Image
from pydantic import conlist

from .constants import ROOT_DIR

from .base import BaseModelCatalog, BaseModule, BaseSchema
from .configs import TableStructureRecognizerRTDETRv2Config
from .layout_parser import filter_contained_rectangles_within_category
@@ -109,12 +114,13 @@ def __init__(
device="cuda",
visualize=False,
from_pretrained=True,
infer_onnx=False,
):
super().__init__()
self.load_model(
model_name,
path_cfg,
from_pretrained=True,
from_pretrained=from_pretrained,
)
self.device = device
self.visualize = visualize
@@ -142,6 +148,40 @@ def __init__(
id: category for id, category in enumerate(self._cfg.category)
}

self.infer_onnx = infer_onnx
if infer_onnx:
name = self._cfg.hf_hub_repo.split("/")[-1]
path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
if not os.path.exists(path_onnx):
self.convert_onnx(path_onnx)

model = onnx.load(path_onnx)
if torch.cuda.is_available() and device == "cuda":
self.sess = onnxruntime.InferenceSession(
model.SerializeToString(), providers=["CUDAExecutionProvider"]
)
else:
self.sess = onnxruntime.InferenceSession(model.SerializeToString())

def convert_onnx(self, path_onnx):
dynamic_axes = {
"input": {0: "batch_size"},
"output": {0: "batch_size"},
}

img_size = self._cfg.data.img_size
dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)

torch.onnx.export(
self.model,
dummy_input,
path_onnx,
opset_version=16,
input_names=["input"],
output_names=["pred_logits", "pred_boxes"],
dynamic_axes=dynamic_axes,
)

def preprocess(self, img, boxes):
cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

@@ -151,7 +191,7 @@ def preprocess(self, img, boxes):
table_img = cv_img[y1:y2, x1:x2, :]
th, hw = table_img.shape[:2]
table_img = Image.fromarray(table_img)
img_tensor = self.transforms(table_img)[None].to(self.device)
img_tensor = self.transforms(table_img)[None]
table_imgs.append(
{
"tensor": img_tensor,
@@ -228,8 +268,19 @@ def __call__(self, img, table_boxes, vis=None):
img_tensors = self.preprocess(img, table_boxes)
outputs = []
for data in img_tensors:
with torch.inference_mode():
pred = self.model(data["tensor"])
if self.infer_onnx:
input = data["tensor"].numpy()
results = self.sess.run(None, {"input": input})
pred = {
"pred_logits": torch.tensor(results[0]).to(self.device),
"pred_boxes": torch.tensor(results[1]).to(self.device),
}

else:
with torch.inference_mode():
data["tensor"] = data["tensor"].to(self.device)
pred = self.model(data["tensor"])

table = self.postprocess(pred, data)
outputs.append(table)

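At call time both modules branch identically: the ONNX path feeds the still-on-CPU tensor to the session as numpy and re-wraps the outputs for the shared postprocessing, while the PyTorch path moves the tensor to the device first (note the `.to(self.device)` removed from both `preprocess` methods above). A condensed sketch of that handoff, with placeholder names for the attributes set up in `__init__`:

```python
# Condensed sketch of the shared ONNX/PyTorch inference handoff;
# `sess`, `model`, and `device` stand in for instance attributes.
import torch


def predict(img_tensor: torch.Tensor, sess=None, model=None, device: str = "cuda"):
    if sess is not None:  # ONNX path: numpy in, torch tensors out
        results = sess.run(None, {"input": img_tensor.numpy()})
        return {
            "pred_logits": torch.tensor(results[0]).to(device),
            "pred_boxes": torch.tensor(results[1]).to(device),
        }
    with torch.inference_mode():  # PyTorch path
        return model(img_tensor.to(device))
```

Keeping preprocessing on the CPU avoids a wasted round-trip to the GPU when the ONNX session runs with the CPU provider.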