diff --git a/README_zh-CN.md b/README_zh-CN.md index c38839637..626669ab1 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -232,10 +232,10 @@ MMOCR 是一款由来自不同高校和企业的研发人员共同参与贡献 ## 欢迎加入 OpenMMLab 社区 -扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://r.vansin.top/?r=join-qq),或通过添加微信“Open小喵Lab”加入官方交流微信群。 +扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,扫描下方微信二维码添加喵喵好友,进入 MMOCR 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
- +
我们会在 OpenMMLab 社区为大家 diff --git a/mmocr/models/textdet/detectors/mmdet_wrapper.py b/mmocr/models/textdet/detectors/mmdet_wrapper.py index 1d6be8caa..49de8482c 100644 --- a/mmocr/models/textdet/detectors/mmdet_wrapper.py +++ b/mmocr/models/textdet/detectors/mmdet_wrapper.py @@ -138,6 +138,7 @@ def adapt_predictions(self, data: MMDET_SampleList, # convert by text_repr_type if self.text_repr_type == 'quad': for j, poly in enumerate(filterd_polygons): + poly = poly.reshape(-1, 2) rect = cv2.minAreaRect(poly) vertices = cv2.boxPoints(rect) poly = vertices.flatten() diff --git a/projects/Donut/.gitignore b/projects/Donut/.gitignore new file mode 100644 index 000000000..62e2e8f44 --- /dev/null +++ b/projects/Donut/.gitignore @@ -0,0 +1,2 @@ +/datasets +/data diff --git a/projects/Donut/README.md b/projects/Donut/README.md new file mode 100644 index 000000000..a96c6ea2d --- /dev/null +++ b/projects/Donut/README.md @@ -0,0 +1,134 @@ +# Donut + +## Description + +This is a reimplementation of the official Donut repo: https://github.com/clovaai/donut. + +## Usage + +### Prerequisites + +- Python 3.7 +- PyTorch 1.6 or higher +- [MIM](https://github.com/open-mmlab/mim) +- [MMOCR](https://github.com/open-mmlab/mmocr) +- transformers 4.25.1 + +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `Donut/` root directory, run the following line to add the current directory to `PYTHONPATH`: + +```shell +# Linux +export PYTHONPATH=`pwd`:$PYTHONPATH +# Windows PowerShell +$env:PYTHONPATH=Get-Location +``` + +### Training commands + +In the current directory, run the following command to train the model: + +```bash +mim train mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/ +``` + +To train on multiple GPUs, e.g. 8 GPUs, run the following command: + +```bash +mim train mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/ --launcher pytorch --gpus 8 +``` + +### Testing commands + +Before testing, you need to change `tokenizer_cfg` in the config. The checkpoint should be the model save directory, e.g. `work_dirs/donut_cord_30e/`. +In the current directory, run the following command to test the model: + +```bash +mim test mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/ --checkpoint ${CHECKPOINT_PATH} +``` + +## Results + +> List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmocr/blob/1.x/configs/textdet/dbnet/README.md#results-and-models) +> +> You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project.
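For the `tokenizer_cfg` change mentioned under "Testing commands" in the Donut README above, the following is a minimal sketch of what such an override could look like. It assumes the keys already present in `configs/donut_cord_30e.py` and that `TokenCheckpointHook` has saved the tokenizer vocabulary (plus `added_tokens.json`) into the training work directory; the override file name and the exact directory are illustrative, not taken from the project.

```python
# Hypothetical test-time override, e.g. configs/donut_cord_30e_test.py.
# It reuses the training config and only redirects the decoder's tokenizer
# from the Hugging Face hub checkpoint to the local training output, where
# TokenCheckpointHook stored the vocabulary and the added special tokens.
_base_ = ['donut_cord_30e.py']

model = dict(
    decoder=dict(
        tokenizer_cfg=dict(
            type='XLMRobertaTokenizer',
            checkpoint='work_dirs/donut_cord_30e/')))  # assumed save dir
```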
+ +| Method | Pretrained Model | Training set | Test set | #epoch | Test size | TED Acc | F1 | Download | +| :-------------------------------------: | :-----------------------: | :-----------: | :----------: | :----: | :-------: | :-----: | :----: | :----------------------: | +| [Donut_CORD](configs/donut_cord_30e.py) | naver-clova-ix/donut-base | cord-v2 Train | cord-v2 Test | 30 | 736 | 0.8977 | 0.8279 | [model](<>) \| [log](<>) | + +## Citation + + + +```bibtex +@article{Kim_Hong_Yim_Nam_Park_Yim_Hwang_Yun_Han_Park_2021, +title={OCR-free Document Understanding Transformer}, +DOI={10.48550/arxiv.2111.15664}, +author={Kim, Geewook and Hong, Teakgyu and Yim, Moonbin and Nam, Jeongyeon and Park, Jinyoung and Yim, Jinyeong and Hwang, Wonseok and Yun, Sangdoo and Han, Dongyoon and Park, Seunghyun}, +year={2021}, +month={Nov}, +language={en-US} +} +``` + + + +## Checklist + +Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. + +> The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR. +> +> OpenMMLab's maintainer will review the code to ensure the project's quality. Reaching the first milestone means that this project suffices the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone. +> +> Note that keeping this section up-to-date is crucial not only for this project's developers but the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate time and effort on further code polishing, if needed. +> +> A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR. + +- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + > The code's design shall follow existing interfaces and convention. For example, each model component should be registered into `mmocr.registry.MODELS` and configurable via a config file. + + - [x] Basic docstrings & proper citation + + > Each major object should contain a docstring, describing its functionality and arguments. If you have adapted the code from other open-source projects, don't forget to cite the source project in docstring and make sure your behavior is not against its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) + + - [ ] Test-time correctness + + > If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights usually could be obtained by simply renaming the keys in the official pre-trained weights. This test could be skipped though, if you are able to prove the training-time correctness and check the second milestone. + + - [x] A full README + + > As this template does. + +- [x] Milestone 2: Indicates a successful model implementation. 
+ + - [x] Training-time correctness + + > If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + > Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96) + + - [ ] Unit tests + + > Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/tests/test_utils/test_polygon_utils.py#L97-L106) + + - [ ] Code polishing + + > Refactor your code according to reviewer's comment. + + - [ ] Metafile.yml + + > It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmocr/blob/1.x/configs/textdet/dbnet/metafile.yml) + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + > In particular, you may have to refactor this README into a standard one. [Example](/configs/textdet/dbnet/README.md) + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/Donut/configs/_base_/default_runtime.py b/projects/Donut/configs/_base_/default_runtime.py new file mode 100644 index 000000000..c73c4e79d --- /dev/null +++ b/projects/Donut/configs/_base_/default_runtime.py @@ -0,0 +1,40 @@ +default_scope = 'mmocr' +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +randomness = dict(seed=None) + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='TokenCheckpointHook', interval=1), + sampler_seed=dict(type='DistSamplerSeedHook'), + sync_buffer=dict(type='SyncBuffersHook'), + visualization=dict( + type='VisualizationHook', + interval=1, + enable=False, + show=False, + draw_gt=False, + draw_pred=False), +) + +# Logging +log_level = 'INFO' +log_processor = dict(type='LogProcessor', window_size=10, by_epoch=True) + +load_from = None +resume = False + +vis_backends = [ + dict(type='LocalVisBackend'), + dict(type='TensorboardVisBackend') +] +visualizer = dict( + type='KIELocalVisualizer', + name='visualizer', + vis_backends=vis_backends, + is_openset=False) diff --git a/projects/Donut/configs/_base_/schedules/schedule_adam_fp16.py b/projects/Donut/configs/_base_/schedules/schedule_adam_fp16.py new file mode 100644 index 000000000..27a45405d --- /dev/null +++ b/projects/Donut/configs/_base_/schedules/schedule_adam_fp16.py @@ -0,0 +1,22 @@ +# optimizer +optim_wrapper = dict( + type='AmpOptimWrapper', + dtype='float16', + optimizer=dict(type='Adam', lr=3e-5, weight_decay=0.0001)) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=30, val_interval=1) +val_cfg = dict(type='ValLoop', fp16=True) +test_cfg = dict(type='TestLoop', fp16=True) +# learning rate +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=3e-5, + by_epoch=True, + begin=0, + end=3, + # update by iter + convert_to_iter_based=True), + # 
main learning rate scheduler + dict(type='CosineAnnealingLR', by_epoch=True, begin=3, end=30) +] diff --git a/projects/Donut/configs/donut_cord_30e.py b/projects/Donut/configs/donut_cord_30e.py new file mode 100644 index 000000000..00a811e2c --- /dev/null +++ b/projects/Donut/configs/donut_cord_30e.py @@ -0,0 +1,108 @@ +_base_ = [ + '_base_/default_runtime.py', + '_base_/schedules/schedule_adam_fp16.py', +] + +data_root = 'datasets/cord-v2' +task_name = 'cord-v2' + +custom_imports = dict(imports=['donut'], allow_failed_imports=False) + +# dictionary = dict( +# type='Dictionary', +# dict_file='{{ fileDirname }}/../../../dicts/english_digits_symbols.txt', +# with_padding=True, +# with_unknown=True, +# same_start_end=True, +# with_start=True, +# with_end=True) + +model = dict( + type='Donut', + data_preprocessor=dict( + type='DonutDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375]), + encoder=dict( + type='SwinEncoder', + input_size=[1280, 960], + align_long_axis=False, + window_size=10, + encoder_layer=[2, 2, 14, 2], + init_cfg=dict( + type='Pretrained', checkpoint='data/donut_base_encoder.pth')), + decoder=dict( + type='BARTDecoder', + max_position_embeddings=None, + task_start_token=f'', + prompt_end_token=f'', + decoder_layer=4, + tokenizer_cfg=dict( + type='XLMRobertaTokenizer', + checkpoint='naver-clova-ix/donut-base'), + init_cfg=dict( + type='Pretrained', checkpoint='data/donut_base_decoder.pth')), + sort_json_key=False, +) + +train_pipeline = [ + dict(type='LoadImageFromFile', ignore_empty=True, min_size=2), + dict(type='LoadJsonAnnotations', with_bbox=False, with_label=False), + dict(type='TorchVisionWrapper', op='Resize', size=960, max_size=1280), + dict(type='RandomPad', input_size=[1280, 960], random_padding=True), + dict( + type='PackKIEInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'parses_json')) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TorchVisionWrapper', op='Resize', size=960, max_size=1280), + dict(type='RandomPad', input_size=[1280, 960], random_padding=False), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadJsonAnnotations', with_bbox=False, with_label=False), + dict( + type='PackKIEInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'parses_json')) +] + +# dataset settings +train_dataset = dict( + type='CORDDataset', + data_root=data_root, + split_name='train', + pipeline=train_pipeline) +val_dataset = dict( + type='CORDDataset', + data_root=data_root, + split_name='validation', + pipeline=test_pipeline) +test_dataset = dict( + type='CORDDataset', + data_root=data_root, + split_name='test', + pipeline=test_pipeline) + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) + +test_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=test_dataset) + +val_dataloader = test_dataloader + +val_evaluator = dict(type='DonutValEvaluator', key='parses') +test_evaluator = dict(type='JSONParseEvaluator', key='parses_json') + +randomness = dict(seed=2022) +find_unused_parameters = True diff --git a/projects/Donut/donut/__init__.py b/projects/Donut/donut/__init__.py new file mode 100644 index 000000000..ab304cbed --- /dev/null +++ b/projects/Donut/donut/__init__.py @@ -0,0 +1,4 @@ +from .datasets import * # 
NOQA +from .engine import * # NOQA +from .evaluation import * # NOQA +from .model import * # NOQA diff --git a/projects/Donut/donut/datasets/__init__.py b/projects/Donut/donut/datasets/__init__.py new file mode 100644 index 000000000..69fca0433 --- /dev/null +++ b/projects/Donut/donut/datasets/__init__.py @@ -0,0 +1,4 @@ +from .cord_dataset import CORDDataset +from .transforms import * # NOQA + +__all__ = ['CORDDataset'] diff --git a/projects/Donut/donut/datasets/cord_dataset.py b/projects/Donut/donut/datasets/cord_dataset.py new file mode 100644 index 000000000..da1ecb50e --- /dev/null +++ b/projects/Donut/donut/datasets/cord_dataset.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Any, Callable, List, Optional, Sequence, Union + +from mmengine.dataset import BaseDataset + +from mmocr.registry import DATASETS, TASK_UTILS + +SPECIAL_TOKENS = [] + + +@DATASETS.register_module() +class CORDDataset(BaseDataset): + r"""CORDDataset for KIE. + + The annotation format can be jsonl. It should be a list of dicts. + + The annotation formats are shown as follows. + - jsonl format + .. code-block:: none + + ``{"filename": "test_img1.jpg", "ground_truth": {"OpenMMLab"}}`` + ``{"filename": "test_img2.jpg", "ground_truth": {"MMOCR"}}`` + + Args: + ann_file (str): Annotation file path. Defaults to ''. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (dict): Prefix for training data. Defaults to + ``dict(img_path='')``. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few data + in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``RecogTextDataset`` can skip load + annotations to save time by set ``lazy_init=False``. Defaults to + False. + max_refetch (int, optional): If ``RecogTextDataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. 
+ """ + + def __init__(self, + split_name: str = '', + backend_args=None, + parser_cfg: Optional[dict] = dict( + type='LineJsonParser', keys=['file_name', + 'ground_truth']), + metainfo: Optional[dict] = None, + data_root: Optional[str] = '', + data_prefix: dict = dict(img_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000) -> None: + + self.parser = TASK_UTILS.build(parser_cfg) + self.backend_args = backend_args + self.split_name = split_name + + super().__init__( + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ + data_list = [] + # dataset = load_dataset(self.data_root, split=self.split_name) + metadata_path = osp.join(self.data_root, self.split_name, + 'metadata.jsonl') + assert osp.exists(metadata_path), metadata_path + with open(metadata_path) as f: + metadata = f.read().strip().split('\n') + for sample_data in metadata: + sample = json.loads(sample_data) + img_path = osp.join(self.data_root, self.split_name, + sample['file_name']) + gt = json.loads(sample['ground_truth']) + + if 'gt_parse' in gt: + gt_jsons = gt.pop('gt_parse') + gt['parses_json'] = gt_jsons + else: + gt['parses_json'] = gt.pop('gt_parses') + + if self.split_name == 'train': + global SPECIAL_TOKENS + SPECIAL_TOKENS += self.search_special_tokens(gt['parses_json']) + + if isinstance(gt, list): + instances = gt + else: + instances = [gt] + data_list.append({'img_path': img_path, 'instances': instances}) + return data_list + + def search_special_tokens(self, obj: Any, sort_json_key: bool = True): + """Convert an ordered JSON object into a token sequence.""" + special_tokens = [] + if type(obj) == dict: + if len(obj) == 1 and 'text_sequence' in obj: + pass + else: + if sort_json_key: + keys = sorted(obj.keys(), reverse=True) + else: + keys = obj.keys() + for k in keys: + special_tokens += [fr'', fr''] + special_tokens += self.search_special_tokens(obj[k]) + elif type(obj) == list: + for item in obj: + special_tokens += self.search_special_tokens( + item, sort_json_key) + return special_tokens diff --git a/projects/Donut/donut/datasets/transforms/__init__.py b/projects/Donut/donut/datasets/transforms/__init__.py new file mode 100644 index 000000000..69eeaefce --- /dev/null +++ b/projects/Donut/donut/datasets/transforms/__init__.py @@ -0,0 +1,4 @@ +from .loading import LoadJsonAnnotations +from .random_transform import RandomPad + +__all__ = ['LoadJsonAnnotations', 'RandomPad'] diff --git a/projects/Donut/donut/datasets/transforms/loading.py b/projects/Donut/donut/datasets/transforms/loading.py new file mode 100644 index 000000000..007c06f4c --- /dev/null +++ b/projects/Donut/donut/datasets/transforms/loading.py @@ -0,0 +1,163 @@ +import copy +from typing import Optional + +import numpy as np +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations + +from mmocr.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class LoadJsonAnnotations(MMCV_LoadAnnotations): + """Load and process the ``instances`` annotation provided by 
dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # A nested list of 4 numbers representing the bounding box of the + # instance, in (x1, y1, x2, y2) order. + 'bbox': np.array([[x1, y1, x2, y2], [x1, y1, x2, y2], ...], + dtype=np.int32), + + # Labels of boxes. Shape is (N,). + 'bbox_labels': np.array([0, 2, ...], dtype=np.int32), + + # Labels of edges. Shape (N, N). + 'edge_labels': np.array([0, 2, ...], dtype=np.int32), + + # List of texts. + "texts": ['text1', 'text2', ...], + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # In (x1, y1, x2, y2) order, float type. N is the number of bboxes + # in np.float32 + 'gt_bboxes': np.ndarray(N, 4), + # In np.int64 type. + 'gt_bboxes_labels': np.ndarray(N, ), + # In np.int32 type. + 'gt_edges_labels': np.ndarray(N, N), + # In list[str] + 'gt_texts': list[str], + # tuple(int) + 'ori_shape': (H, W) + } + + Required Keys: + + - bboxes + - bbox_labels + - edge_labels + - texts + + Added Keys: + + - gt_bboxes (np.float32) + - gt_bboxes_labels (np.int64) + - gt_edges_labels (np.int64) + - gt_texts (list[str]) + - ori_shape (tuple[int]) + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Defaults to True. + with_label (bool): Whether to parse and load the label annotation. + Defaults to True. + with_text (bool): Whether to parse and load the text annotation. + Defaults to True. + directed (bool): Whether build edges as a directed graph. + Defaults to False. + key_node_idx (int, optional): Key node label, used to mask out edges + that are not connected from key nodes to value nodes. It has to be + specified together with ``value_node_idx``. Defaults to None. + value_node_idx (int, optional): Value node label, used to mask out + edges that are not connected from key nodes to value nodes. It has + to be specified together with ``key_node_idx``. Defaults to None. + """ + + def __init__(self, + with_bbox: bool = True, + with_label: bool = True, + with_text: bool = True, + directed: bool = False, + key_node_idx: Optional[int] = None, + value_node_idx: Optional[int] = None, + **kwargs) -> None: + super().__init__(with_bbox=with_bbox, with_label=with_label, **kwargs) + self.with_text = with_text + self.directed = directed + if key_node_idx is not None or value_node_idx is not None: + assert key_node_idx is not None and value_node_idx is not None + self.key_node_idx = key_node_idx + self.value_node_idx = value_node_idx + + def _load_parse(self, results: dict) -> None: + """Private function to load text annotations. + + Args: + results (dict): Result dict from :obj:``OCRDataset``. + """ + gt_parse = [] + for instance in results['instances']: + gt_parse.append(instance['parses_json']) + results['parses_json'] = gt_parse + + def _load_labels(self, results: dict) -> None: + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:``WildReceiptDataset``. 
+ """ + bbox_labels = [] + edge_labels = [] + for instance in results['instances']: + bbox_labels.append(instance['bbox_label']) + edge_labels.append(instance['edge_label']) + + bbox_labels = np.array(bbox_labels, np.int32) + edge_labels = np.array(edge_labels) + edge_labels = (edge_labels[:, None] == edge_labels[None, :]).astype( + np.int32) + + if self.directed: + edge_labels = (edge_labels & bbox_labels == 1).astype(np.int32) + + if hasattr(self, 'key_node_idx'): + key_nodes_mask = bbox_labels == self.key_node_idx + value_nodes_mask = bbox_labels == self.value_node_idx + key2value_mask = key_nodes_mask[:, + None] * value_nodes_mask[None, :] + edge_labels[~key2value_mask] = -1 + + np.fill_diagonal(edge_labels, -1) + + results['gt_edges_labels'] = edge_labels.astype(np.int64) + results['gt_bboxes_labels'] = bbox_labels.astype(np.int64) + + def transform(self, results: dict) -> dict: + """Function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:``OCRDataset``. + + Returns: + dict: The dict contains loaded bounding box, label polygon and + text annotations. + """ + if 'ori_shape' not in results: + results['ori_shape'] = copy.deepcopy(results['img_shape']) + results = super().transform(results) + self._load_parse(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label})' + return repr_str diff --git a/projects/Donut/donut/datasets/transforms/random_transform.py b/projects/Donut/donut/datasets/transforms/random_transform.py new file mode 100644 index 000000000..38c810b1a --- /dev/null +++ b/projects/Donut/donut/datasets/transforms/random_transform.py @@ -0,0 +1,92 @@ +from typing import Dict, List + +import mmcv +import numpy as np +from mmcv.transforms.base import BaseTransform + +from mmocr.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class RandomPad(BaseTransform): + """Only pad the image's width. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + Added Keys: + + - pad_shape + - pad_fixed_size + - pad_size_divisor + - valid_ratio + + Args: + width (int): Target width of padded image. Defaults to None. + pad_cfg (dict): Config to construct the Resize transform. Refer to + ``Pad`` for detail. Defaults to ``dict(type='Pad')``. + """ + + def __init__(self, + input_size: List[int], + random_padding: bool = True, + fill=0, + pad_cfg: dict = dict(type='Pad')) -> None: + super().__init__() + height, width = input_size + assert isinstance(width, int) + assert isinstance(height, int) + self.width = width + self.height = height + self.random_padding = random_padding + self.fill = fill + self.pad_cfg = pad_cfg + _pad_cfg = self.pad_cfg.copy() + _pad_cfg.update(dict(size=0)) + self.pad = TRANSFORMS.build(_pad_cfg) + + def transform(self, results: Dict) -> Dict: + """Call function to pad images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. 
+ """ + ori_height, ori_width = results['img'].shape[:2] + delta_width = self.width - ori_width + delta_height = self.height - ori_height + if self.random_padding: + pad_width = np.random.randint(low=0, high=delta_width + 1) + pad_height = np.random.randint(low=0, high=delta_height + 1) + else: + pad_width = delta_width // 2 + pad_height = delta_height // 2 + padding = ( + pad_width, + pad_height, + delta_width - pad_width, + delta_height - pad_height, + ) + + results['img'] = mmcv.impad( + results['img'], + padding=padding, + pad_val=self.fill, + padding_mode='constant') + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(width={self.width}, ' + repr_str += f'(height={self.height}, ' + repr_str += f'(random_padding={self.random_padding}, ' + repr_str += f'pad_cfg={self.pad_cfg})' + return repr_str diff --git a/projects/Donut/donut/engine/__init__.py b/projects/Donut/donut/engine/__init__.py new file mode 100644 index 000000000..d09239685 --- /dev/null +++ b/projects/Donut/donut/engine/__init__.py @@ -0,0 +1 @@ +from .hooks import * # NOQA diff --git a/projects/Donut/donut/engine/hooks/__init__.py b/projects/Donut/donut/engine/hooks/__init__.py new file mode 100644 index 000000000..c34854905 --- /dev/null +++ b/projects/Donut/donut/engine/hooks/__init__.py @@ -0,0 +1,3 @@ +from .token_checkpoint_hook import TokenCheckpointHook + +__all__ = ['TokenCheckpointHook'] diff --git a/projects/Donut/donut/engine/hooks/token_checkpoint_hook.py b/projects/Donut/donut/engine/hooks/token_checkpoint_hook.py new file mode 100644 index 000000000..ab6718047 --- /dev/null +++ b/projects/Donut/donut/engine/hooks/token_checkpoint_hook.py @@ -0,0 +1,30 @@ +import json +import os + +from mmengine.hooks import CheckpointHook +from mmengine.model import MMDistributedDataParallel + +from mmocr.registry import HOOKS + + +@HOOKS.register_module() +class TokenCheckpointHook(CheckpointHook): + """""" + + def before_train(self, runner): + """save tokenizer.""" + super().before_train(runner=runner) + if isinstance(runner.model, MMDistributedDataParallel): + tokenizer = runner.model.module.decoder.tokenizer + tokenizer.save_vocabulary(self.out_dir) + added_vocab = tokenizer.get_added_vocab() + with open(os.path.join(self.out_dir, 'added_tokens.json'), + 'w') as f: + json.dump(added_vocab, f) + else: + tokenizer = runner.model.decoder.tokenizer + tokenizer.save_vocabulary(self.out_dir) + added_vocab = tokenizer.get_added_vocab() + with open(os.path.join(self.out_dir, 'added_tokens.json'), + 'w') as f: + json.dump(added_vocab, f) diff --git a/projects/Donut/donut/evaluation/__init__.py b/projects/Donut/donut/evaluation/__init__.py new file mode 100644 index 000000000..e9f2df5e3 --- /dev/null +++ b/projects/Donut/donut/evaluation/__init__.py @@ -0,0 +1 @@ +from .metrics import * # NOQA diff --git a/projects/Donut/donut/evaluation/metrics/__init__.py b/projects/Donut/donut/evaluation/metrics/__init__.py new file mode 100644 index 000000000..303f3b6a8 --- /dev/null +++ b/projects/Donut/donut/evaluation/metrics/__init__.py @@ -0,0 +1,4 @@ +from .ted_metric import JSONParseEvaluator +from .val_metric import DonutValEvaluator + +__all__ = ['JSONParseEvaluator', 'DonutValEvaluator'] diff --git a/projects/Donut/donut/evaluation/metrics/ted_metric.py b/projects/Donut/donut/evaluation/metrics/ted_metric.py new file mode 100644 index 000000000..90373b239 --- /dev/null +++ b/projects/Donut/donut/evaluation/metrics/ted_metric.py @@ -0,0 +1,282 @@ +from typing import Any, Dict, List, 
Optional, Sequence, Union + +import numpy as np +import zss +from mmengine.evaluator import BaseMetric +from nltk import edit_distance +from zss import Node + +from mmocr.registry import METRICS + + +@METRICS.register_module() +class JSONParseEvaluator(BaseMetric): + """Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 + accuracy score.""" + default_prefix: Optional[str] = 'kie' + + def __init__(self, + key: str = 'labels', + mode: Union[str, Sequence[str]] = 'micro', + cared_classes: Sequence[int] = [], + ignored_classes: Sequence[int] = [], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + assert isinstance(cared_classes, (list, tuple)) + assert isinstance(ignored_classes, (list, tuple)) + assert isinstance(mode, (list, str)) + assert not (len(cared_classes) > 0 and len(ignored_classes) > 0), \ + 'cared_classes and ignored_classes cannot be both non-empty' + + if isinstance(mode, str): + mode = [mode] + assert set(mode).issubset({'micro', 'macro'}) + self.mode = mode + self.key = key + + def process(self, data_batch: Sequence[Dict], + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be + stored in ``self.results``, which will be used to compute the metrics + when all batches have been processed. + + Args: + data_batch (Sequence[Dict]): A batch of gts. + data_samples (Sequence[Dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_labels = data_sample.get('pred_instances').get(self.key) + gt_labels = data_sample.get('gt_instances').get(self.key) + print(pred_labels) + print(gt_labels) + result = dict(pred_labels=pred_labels, gt_labels=gt_labels) + self.results.append(result) + + def compute_metrics(self, results: Sequence[Dict]) -> Dict: + """Compute the metrics from processed results. + + Args: + results (list[Dict]): The processed results of each batch. + + Returns: + dict[str, float]: The f1 scores. The keys are the names of the + metrics, and the values are corresponding results. Possible + keys are 'micro_f1' and 'macro_f1'. + """ + + preds = [] + gts = [] + scores = [] + for result in results: + pred = result['pred_labels'] + gt = result['gt_labels'] + preds.append(pred) + gts.append(gt) + scores.append(self.cal_acc(pred, gt)) + + result = {} + f1_score = self.cal_f1(preds, gts) + result['ted_accuracy'] = np.mean(scores) + result['f1_accuracy'] = f1_score + return result + + @staticmethod + def flatten(data: dict): + """ + Convert Dictionary into Non-nested Dictionary + Example: + input(dict) + { + "menu": [ + {"name" : ["cake"], "count" : ["2"]}, + {"name" : ["juice"], "count" : ["1"]}, + ] + } + output(list) + [ + ("menu.name", "cake"), + ("menu.count", "2"), + ("menu.name", "juice"), + ("menu.count", "1"), + ] + """ + flatten_data = list() + + def _flatten(value, key=''): + if type(value) is dict: + for child_key, child_value in value.items(): + _flatten(child_value, + f'{key}.{child_key}' if key else child_key) + elif type(value) is list: + for value_item in value: + _flatten(value_item, key) + else: + flatten_data.append((key, value)) + + _flatten(data) + return flatten_data + + @staticmethod + def update_cost(node1: Node, node2: Node): + """Update cost for tree edit distance. + + If both are leaf node, calculate string edit distance between two + labels (special token '' will be ignored). If one of them is leaf + node, cost is length of string in leaf node + 1. 
If neither are leaf + node, cost is 0 if label1 is same with label2 otherwise 1 + """ + label1 = node1.label + label2 = node2.label + label1_leaf = '' in label1 + label2_leaf = '' in label2 + if label1_leaf and label2_leaf: + return edit_distance( + label1.replace('', ''), label2.replace('', '')) + elif (not label1_leaf) and label2_leaf: + return 1 + len(label2.replace('', '')) + elif label1_leaf and (not label2_leaf): + return 1 + len(label1.replace('', '')) + else: + return int(label1 != label2) + + @staticmethod + def insert_and_remove_cost(node: Node): + """Insert and remove cost for tree edit distance. + + If leaf node, cost is length of label name. Otherwise, 1 + """ + label = node.label + if '' in label: + return len(label.replace('', '')) + else: + return 1 + + def normalize_dict(self, data: Union[Dict, List, Any]): + """Sort by value, while iterate over element if data is list.""" + if not data: + return {} + + if isinstance(data, dict): + new_data = dict() + for key in sorted(data.keys(), key=lambda k: (len(k), k)): + value = self.normalize_dict(data[key]) + if value: + if not isinstance(value, list): + value = [value] + new_data[key] = value + + elif isinstance(data, list): + if all(isinstance(item, dict) for item in data): + new_data = [] + for item in data: + item = self.normalize_dict(item) + if item: + new_data.append(item) + else: + new_data = [ + str(item).strip() for item in data + if type(item) in {str, int, float} and str(item).strip() + ] + else: + new_data = [str(data).strip()] + + return new_data + + def cal_f1(self, preds: List[dict], answers: List[dict]): + """Calculate global F1 accuracy score (field-level, micro-averaged) by + counting all true positives, false negatives and false positives.""" + total_tp, total_fn_or_fp = 0, 0 + for pred, answer in zip(preds, answers): + pred, answer = self.flatten( + self.normalize_dict(pred)), self.flatten( + self.normalize_dict(answer)) + for field in pred: + if field in answer: + total_tp += 1 + answer.remove(field) + else: + total_fn_or_fp += 1 + total_fn_or_fp += len(answer) + return total_tp / (total_tp + total_fn_or_fp / 2) + + def construct_tree_from_dict(self, + data: Union[Dict, List], + node_name: str = None): + """Convert Dictionary into Tree. + + Example: + input(dict) + + { + "menu": [ + {"name" : ["cake"], "count" : ["2"]}, + {"name" : ["juice"], "count" : ["1"]}, + ] + } + + output(tree) + + | + menu + / \ + + / | | \ + name count name count + / | | \ + cake 2 juice 1 + """ + if node_name is None: + node_name = '' + + node = Node(node_name) + + if isinstance(data, dict): + for key, value in data.items(): + kid_node = self.construct_tree_from_dict(value, key) + node.addkid(kid_node) + elif isinstance(data, list): + if all(isinstance(item, dict) for item in data): + for item in data: + kid_node = self.construct_tree_from_dict( + item, + '', + ) + node.addkid(kid_node) + else: + for item in data: + node.addkid(Node(f'{item}')) + else: + raise Exception(data, node_name) + return node + + def cal_acc(self, pred: dict, answer: dict): + """Calculate normalized tree edit distance(nTED) based accuracy. 1) + Construct tree from dict, 2) Get tree distance with + insert/remove/update cost, 3) Divide distance with GT tree size (i.e., + nTED), + + 4) Calculate nTED based accuracy. (= max(1 - nTED, 0 ). 
+ """ + pred = self.construct_tree_from_dict(self.normalize_dict(pred)) + answer = self.construct_tree_from_dict(self.normalize_dict(answer)) + return max( + 0, + 1 - (zss.distance( + pred, + answer, + get_children=zss.Node.get_children, + insert_cost=self.insert_and_remove_cost, + remove_cost=self.insert_and_remove_cost, + update_cost=self.update_cost, + return_operations=False, + ) / zss.distance( + self.construct_tree_from_dict(self.normalize_dict({})), + answer, + get_children=zss.Node.get_children, + insert_cost=self.insert_and_remove_cost, + remove_cost=self.insert_and_remove_cost, + update_cost=self.update_cost, + return_operations=False, + )), + ) diff --git a/projects/Donut/donut/evaluation/metrics/val_metric.py b/projects/Donut/donut/evaluation/metrics/val_metric.py new file mode 100644 index 000000000..7deef4c24 --- /dev/null +++ b/projects/Donut/donut/evaluation/metrics/val_metric.py @@ -0,0 +1,62 @@ +import re +from typing import Dict, Optional, Sequence + +import numpy as np +from mmengine.evaluator import BaseMetric +from nltk import edit_distance + +from mmocr.registry import METRICS + + +@METRICS.register_module() +class DonutValEvaluator(BaseMetric): + default_prefix: Optional[str] = '' + + def __init__(self, + key: str = 'parses', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + self.key = key + + def process(self, data_batch: Sequence[Dict], + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be + stored in ``self.results``, which will be used to compute the metrics + when all batches have been processed. + + Args: + data_batch (Sequence[Dict]): A batch of gts. + data_samples (Sequence[Dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_parses = data_sample.get('pred_instances').get(self.key)[0] + gt_parses = data_sample.get('gt_instances').get(self.key)[0] + + result = dict(pred_labels=pred_parses, gt_labels=gt_parses) + self.results.append(result) + + def compute_metrics(self, results: Sequence[Dict]) -> Dict: + """Compute the metrics from processed results. + + Args: + results (list[Dict]): The processed results of each batch. + + Returns: + dict[str, float]: The f1 scores. The keys are the names of the + metrics, and the values are corresponding results. Possible + keys are 'micro_f1' and 'macro_f1'. 
+ """ + + scores = [] + for result in results: + pred = result['pred_labels'] + pred = re.sub(r'(?:(?<=>) | (?= is used for representing a list in a JSON + self.add_special_tokens(['']) + pad_token_id = self.tokenizer.pad_token_id + self.model.model.decoder.embed_tokens.padding_idx = pad_token_id + prepare_inputs = self.prepare_inputs_for_inference + self.model.prepare_inputs_for_generation = prepare_inputs + + @property + def tokenizer(self): + return self._tokenizer + + @property + def prompt_end_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.prompt_end_token) + + def init_weights(self): + super().init_weights() + + # weight init with asian-bart + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + bart_state_dict = MBartForCausalLM.from_pretrained( + 'hyunwoongko/asian-bart-ecjk').state_dict() + else: + bart_state_dict = OrderedDict() + model_state_dict = torch.load(self.init_cfg['checkpoint']) + for k, v in model_state_dict.items(): + if k.startswith('model.'): + bart_state_dict[k[len('model.'):]] = v + + new_bart_state_dict = self.model.state_dict() + for x in new_bart_state_dict: + if x.endswith('embed_positions.weight' + ) and self.max_position_embeddings != 1024: + new_bart_state_dict[x] = torch.nn.Parameter( + self.resize_bart_abs_pos_emb( + bart_state_dict[x], + self.max_position_embeddings + 2, + )) + elif x.endswith('embed_tokens.weight') or x.endswith( + 'lm_head.weight'): + new_bart_state_dict[x] = bart_state_dict[x][:len(self.tokenizer + ), :] + else: + new_bart_state_dict[x] = bart_state_dict[x] + self.model.load_state_dict(new_bart_state_dict) + + def add_special_tokens(self, list_of_tokens: List[str]): + """Add special tokens to tokenizer and resize the token embeddings.""" + newly_added_num = self.tokenizer.add_special_tokens( + {'additional_special_tokens': sorted(set(list_of_tokens))}) + if newly_added_num > 0: + self.model.resize_token_embeddings(len(self.tokenizer)) + + def prepare_inputs_for_inference(self, + input_ids: torch.Tensor, + encoder_outputs: torch.Tensor, + past_key_values=None, + past=None, + use_cache: bool = None, + attention_mask: torch.Tensor = None): + """ + Args: + input_ids: (batch_size, sequence_lenth) + Returns: + input_ids: (batch_size, sequence_length) + attention_mask: (batch_size, sequence_length) + encoder_hidden_states: (batch_size, sequence_length, embedding_dim) + """ + # for compatibility with transformers==4.11.x + if past is not None: + past_key_values = past + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() + if past_key_values is not None: + input_ids = input_ids[:, -1:] + output = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'encoder_hidden_states': encoder_outputs.last_hidden_state, + } + return output + + def extract_feat(self, + input_ids, + encoder_hidden_states: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = None, + output_attentions: Optional[torch.Tensor] = None, + output_hidden_states: Optional[torch.Tensor] = None, + return_dict: bool = True, + data_samples=None): + """""" + if output_attentions is None: + output_attentions = self.model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.model.config.output_hidden_states) + if return_dict is None: + return_dict = self.model.config.use_return_dict + + 
outputs = self.model.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.model.lm_head(outputs[0]) + outputs['logits'] = logits + return outputs + + def loss(self, + input_ids, + encoder_hidden_states: Optional[torch.Tensor], + labels, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = None, + output_attentions: Optional[torch.Tensor] = None, + output_hidden_states: Optional[torch.Tensor] = None, + data_samples=None): + """A forward function to get cross attentions and utilize `generate` + function. + + Source: + https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810 + + Args: + input_ids: (batch_size, sequence_length) + attention_mask: (batch_size, sequence_length) + encoder_hidden_states: (batch_size, sequence_length, hidden_size) + + Returns: + loss: (1, ) + logits: (batch_size, sequence_length, hidden_dim) + hidden_states: (batch_size, sequence_length, hidden_size) + decoder_attentions: (batch_size, num_heads, sequence_length, + sequence_length) + cross_attentions: (batch_size, num_heads, sequence_length, + sequence_length) + """ + outputs = self.extract_feat( + input_ids, + encoder_hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + output_hidden_states, + data_samples=data_samples) + logits = outputs['logits'] + + loss = None + loss_fct = nn.CrossEntropyLoss(ignore_index=-100) + loss = loss_fct( + logits.view(-1, self.model.config.vocab_size), labels.view(-1)) + + return {'loss': loss} + + def forward(self, + input_ids, + encoder_hidden_states: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = None, + output_attentions: Optional[torch.Tensor] = None, + output_hidden_states: Optional[torch.Tensor] = None, + return_dict=None, + data_samples=None): + """A forward function to get cross attentions and utilize `generate` + function. 
+ + Source: + https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810 + + Args: + input_ids: (batch_size, sequence_length) + attention_mask: (batch_size, sequence_length) + encoder_hidden_states: (batch_size, sequence_length, hidden_size) + + Returns: + loss: (1, ) + logits: (batch_size, sequence_length, hidden_dim) + hidden_states: (batch_size, sequence_length, hidden_size) + decoder_attentions: (batch_size, num_heads, + sequence_length, sequence_length) + cross_attentions: (batch_size, num_heads, + sequence_length, sequence_length) + """ + outputs = self.extract_feat( + input_ids, + encoder_hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + output_hidden_states, + data_samples=data_samples) + + return ModelOutput( + loss=None, + logits=outputs.logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + decoder_attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def resize_bart_abs_pos_emb(weight: torch.Tensor, + max_length: int) -> torch.Tensor: + """Resize position embeddings Truncate if sequence length of Bart + backbone is greater than given max_length, else interpolate to + max_length.""" + if weight.shape[0] > max_length: + weight = weight[:max_length, ...] + else: + weight = ( + F.interpolate( + weight.permute(1, 0).unsqueeze(0), + size=max_length, + mode='linear', + align_corners=False, + ).squeeze(0).permute(1, 0)) + return weight diff --git a/projects/Donut/donut/model/donut.py b/projects/Donut/donut/model/donut.py new file mode 100644 index 000000000..9e7144319 --- /dev/null +++ b/projects/Donut/donut/model/donut.py @@ -0,0 +1,462 @@ +import random +import re +from typing import Any, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import InstanceData +from torch.nn.utils.rnn import pad_sequence +from transformers.file_utils import ModelOutput + +from mmocr.registry import MODELS +from ..datasets.cord_dataset import SPECIAL_TOKENS + + +@MODELS.register_module() +class Donut(BaseModel): + + def __init__(self, + data_preprocessor=None, + encoder=dict( + type='SwinEncoder', + input_size=[1280, 960], + align_long_axis=False, + window_size=10, + encoder_layer=[2, 2, 14, 2], + name_or_path=''), + decoder=dict( + type='BARTDecoder', + max_position_embeddings=None, + task_start_token='', + prompt_end_token=None, + decoder_layer=4, + name_or_path=''), + max_length=768, + ignore_mismatched_sizes=True, + sort_json_key: bool = True, + ignore_id: int = -100, + init_cfg=dict()): + super().__init__(data_preprocessor, init_cfg) + + self.max_length = max_length + self.ignore_mismatched_sizes = ignore_mismatched_sizes + self.sort_json_key = sort_json_key + self.ignore_id = ignore_id + + self.encoder = MODELS.build(encoder) + + decoder['max_position_embeddings'] = max_length if decoder[ + 'max_position_embeddings'] is None else decoder[ + 'max_position_embeddings'] + self.decoder = MODELS.build(decoder) + + def init_weights(self): + super().init_weights() + self.decoder.add_special_tokens(SPECIAL_TOKENS) + self.decoder.add_special_tokens( + [self.decoder.task_start_token, self.decoder.prompt_end_token]) + return + + def get_input_ids_val(self, data_samples): + # input_ids + decoder_input_ids = list() + batch_prompt_end_index = list() + batch_processed_parse = list() + + for sample in data_samples: + if hasattr(sample, 'parses_json'): + assert isinstance(sample.parses_json, list) + 
gt_jsons = sample.parses_json + else: + print(sample.keys()) + raise KeyError + + # load json from list of json + gt_token_sequences = [] + for gt_json in gt_jsons: + gt_token = self.json2token( + gt_json, + update_special_tokens_for_json_key=False, + sort_json_key=self.sort_json_key) + gt_token_sequences.append(self.decoder.task_start_token + + gt_token + + self.decoder.tokenizer.eos_token) + # can be more than one, e.g., DocVQA Task 1 + token_index = random.randint(0, len(gt_token_sequences) - 1) + processed_parse = gt_token_sequences[token_index] + + input_ids = self.decoder.tokenizer( + processed_parse, + add_special_tokens=False, + max_length=self.max_length, + padding='max_length', + truncation=True, + return_tensors='pt', + )['input_ids'].squeeze(0) + + # return prompt end index instead of target output labels + prompt_end_index = torch.nonzero( + input_ids == self.decoder.prompt_end_token_id).sum() + batch_prompt_end_index.append(prompt_end_index) + batch_processed_parse.append(processed_parse) + + decoder_input_ids.append(input_ids[:-1]) + + sample.gt_instances['parses_json'] = [gt_jsons[token_index] + ] # [] for len check + sample.gt_instances['parses'] = [processed_parse] + + decoder_input_ids = torch.stack(decoder_input_ids, dim=0) + return decoder_input_ids, batch_prompt_end_index, batch_processed_parse + + def get_input_ids_train(self, data_samples): + # input_ids + decoder_input_ids = list() + decoder_labels = list() + + for sample in data_samples: + assert isinstance(sample.parses_json, list) + gt_jsons = sample.parses_json + + # load json from list of json + gt_token_sequences = [] + for gt_json in gt_jsons: + gt_token = self.json2token( + gt_json, + update_special_tokens_for_json_key=False, + sort_json_key=self.sort_json_key) + gt_token_sequences.append(self.decoder.task_start_token + + gt_token + + self.decoder.tokenizer.eos_token) + # can be more than one, e.g., DocVQA Task 1 + token_index = random.randint(0, len(gt_token_sequences) - 1) + processed_parse = gt_token_sequences[token_index] + + input_ids = self.decoder.tokenizer( + processed_parse, + add_special_tokens=False, + max_length=self.max_length, + padding='max_length', + truncation=True, + return_tensors='pt', + )['input_ids'].squeeze(0) + + labels = input_ids.clone() + # model doesn't need to predict pad token + labels[labels == + self.decoder.tokenizer.pad_token_id] = self.ignore_id + # model doesn't need to predict prompt (for VQA) + labels[:torch.nonzero( + labels == self.decoder.prompt_end_token_id).sum() + + 1] = self.ignore_id + decoder_labels.append(labels[1:]) + + decoder_input_ids.append(input_ids[:-1]) + sample.gt_instances['parses_json'] = [gt_jsons[token_index]] + sample.gt_instances['parses'] = [processed_parse] + + decoder_input_ids = torch.stack(decoder_input_ids, dim=0) + decoder_labels = torch.stack(decoder_labels, dim=0) + return decoder_input_ids, decoder_labels + + def test_step(self, data: Union[dict, tuple, list]) -> list: + """``BaseModel`` implements ``test_step`` the same as ``val_step``. + + Args: + data (dict or tuple or list): Data sampled from dataset. + + Returns: + list: The predictions of given data. + """ + data = self.data_preprocessor(data, False) + return self._run_forward(data, mode='test') # type: ignore + + def forward(self, + inputs: torch.Tensor, + data_samples=None, + mode: str = 'tensor', + **kwargs): + """The unified entry for a forward process in both training and test. 
+ + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DetDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (list[:obj:`DetDataSample`], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(inputs, data_samples, **kwargs) + elif mode == 'predict': + return self.predict(inputs, data_samples, **kwargs) + elif mode == 'test': + return self.test(inputs, data_samples, **kwargs) + elif mode == 'tensor': + return self._forward(inputs, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') + + def loss(self, inputs, data_samples=None): + """Calculate a loss given an input image and a desired token sequence, + the model will be trained in a teacher-forcing manner. + + Args: + inputs: (batch_size, num_channels, height, width) + decoder_input_ids: (batch_size, sequence_length, embedding_dim) + decode_labels: (batch_size, sequence_length) + """ + encoder_outputs = self.encoder(inputs) + + input_ids, labels = self.get_input_ids_train(data_samples) + input_ids = input_ids.to(encoder_outputs.device) + labels = labels.to(encoder_outputs.device) + + decoder_outputs = self.decoder.loss( + input_ids=input_ids, + encoder_hidden_states=encoder_outputs, + labels=labels, + data_samples=data_samples) + return decoder_outputs + + def predict(self, inputs, data_samples=None): + """Calculate a loss given an input image and a desired token sequence, + the model will be trained in a teacher-forcing manner. 
+ + Args: + inputs: (batch_size, num_channels, height, width) + decoder_input_ids: (batch_size, sequence_length, embedding_dim) + decode_labels: (batch_size, sequence_length) + """ + encoder_outputs = self.encoder(inputs) + + decoder_input_ids, prompt_end_idxs, answers = self.get_input_ids_val( + data_samples) + + prompt_tensors = pad_sequence( + [ + input_id[:end_idx + 1] for input_id, end_idx in zip( + decoder_input_ids, prompt_end_idxs) + ], + batch_first=True, + ) + decoder_input_ids = decoder_input_ids.to(encoder_outputs.device) + + if len(encoder_outputs.size()) == 1: + encoder_outputs = encoder_outputs.unsqueeze(0) + + return_attentions = False + + if len(prompt_tensors.size()) == 1: + prompt_tensors = prompt_tensors.unsqueeze(0) + prompt_tensors = prompt_tensors.to(encoder_outputs.device) + + # get decoder output + encoder_outputs = ModelOutput( + last_hidden_state=encoder_outputs, attentions=None) + + decoder_output = self.decoder.model.generate( + decoder_input_ids=prompt_tensors, + encoder_outputs=encoder_outputs, + max_length=self.max_length, + early_stopping=True, + pad_token_id=self.decoder.tokenizer.pad_token_id, + eos_token_id=self.decoder.tokenizer.eos_token_id, + use_cache=True, + num_beams=1, + bad_words_ids=[[self.decoder.tokenizer.unk_token_id]], + return_dict_in_generate=True, + output_attentions=return_attentions, + ) + + output = {'predictions': list(), 'predictions_json': list()} + for i, seq in enumerate( + self.decoder.tokenizer.batch_decode(decoder_output.sequences)): + seq = seq.replace(self.decoder.tokenizer.eos_token, '') + seq = seq.replace(self.decoder.tokenizer.pad_token, '') + # remove first task start token + seq = re.sub(r'<.*?>', '', seq, count=1).strip() + output['predictions'].append(seq) + output['predictions_json'].append(self.token2json(seq)) + + answer = answers[i] + answer = re.sub(r'<.*?>', '', answer, count=1) + answer = answer.replace(self.decoder.tokenizer.eos_token, '') + + data_samples[i].pred_instances = InstanceData( + parses=[seq], parses_json=[self.token2json(seq)]) + data_samples[i].gt_instances['parses'] = [answer] + + if return_attentions: + output['attentions'] = { + 'self_attentions': decoder_output.decoder_attentions, + 'cross_attentions': decoder_output.cross_attentions, + } + + return data_samples + + def test(self, inputs, data_samples=None): + """Calculate a loss given an input image and a desired token sequence, + the model will be trained in a teacher-forcing manner. 
+
+        Args:
+            inputs: (batch_size, num_channels, height, width)
+            data_samples (list): Data samples carrying the ground-truth
+                ``parses_json`` used for evaluation.
+        """
+        encoder_outputs = self.encoder(inputs)
+        if len(encoder_outputs.size()) == 1:
+            encoder_outputs = encoder_outputs.unsqueeze(0)
+        encoder_outputs = ModelOutput(
+            last_hidden_state=encoder_outputs, attentions=None)
+
+        # start generation from the task start token only
+        prompt_tensors = self.decoder.tokenizer(
+            self.decoder.task_start_token,
+            add_special_tokens=False,
+            return_tensors='pt')['input_ids']
+        if len(prompt_tensors.size()) == 1:
+            prompt_tensors = prompt_tensors.unsqueeze(0)
+        prompt_tensors = prompt_tensors.to(inputs.device)
+
+        return_attentions = False
+        decoder_output = self.decoder.model.generate(
+            decoder_input_ids=prompt_tensors,
+            encoder_outputs=encoder_outputs,
+            max_length=self.max_length,
+            early_stopping=True,
+            pad_token_id=self.decoder.tokenizer.pad_token_id,
+            eos_token_id=self.decoder.tokenizer.eos_token_id,
+            use_cache=True,
+            num_beams=1,
+            bad_words_ids=[[self.decoder.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+            output_attentions=return_attentions,
+        )
+
+        output = {'predictions': list(), 'predictions_json': list()}
+        for i, seq in enumerate(
+                self.decoder.tokenizer.batch_decode(decoder_output.sequences)):
+            seq = seq.replace(self.decoder.tokenizer.eos_token, '')
+            seq = seq.replace(self.decoder.tokenizer.pad_token, '')
+            # remove first task start token
+            seq = re.sub(r'<.*?>', '', seq, count=1).strip()
+            output['predictions'].append(seq)
+            output['predictions_json'].append(self.token2json(seq))
+
+            answer = data_samples[i].parses_json
+            data_samples[i].pred_instances = InstanceData(
+                parses=[seq], parses_json=[self.token2json(seq)])
+            data_samples[i].gt_instances['parses_json'] = answer
+
+        if return_attentions:
+            output['attentions'] = {
+                'self_attentions': decoder_output.decoder_attentions,
+                'cross_attentions': decoder_output.cross_attentions,
+            }
+
+        return data_samples
+
+    def json2token(self,
+                   obj: Any,
+                   update_special_tokens_for_json_key: bool = True,
+                   sort_json_key: bool = True):
+        """Convert an ordered JSON object into a token sequence."""
+        if type(obj) == dict:
+            if len(obj) == 1 and 'text_sequence' in obj:
+                return obj['text_sequence']
+            else:
+                output = ''
+                if sort_json_key:
+                    keys = sorted(obj.keys(), reverse=True)
+                else:
+                    keys = obj.keys()
+                for k in keys:
+                    if update_special_tokens_for_json_key:
+                        # each JSON key k is delimited by <s_k> ... </s_k>
+                        self.decoder.add_special_tokens(
+                            [fr'<s_{k}>', fr'</s_{k}>'])
+                    output += (fr'<s_{k}>' + self.json2token(
+                        obj[k], update_special_tokens_for_json_key,
+                        sort_json_key) + fr'</s_{k}>')
+                return output
+        elif type(obj) == list:
+            return r'<sep/>'.join([
+                self.json2token(item, update_special_tokens_for_json_key,
+                                sort_json_key) for item in obj
+            ])
+        else:
+            obj = str(obj)
+            if f'<{obj}/>' in self.decoder.tokenizer.all_special_tokens:
+                obj = f'<{obj}/>'  # for categorical special tokens
+            return obj
+
+    def token2json(self, tokens, is_inner_value=False):
+        """Convert a (generated) token sequence into an ordered JSON format."""
+        output = dict()
+
+        while tokens:
+            start_token = re.search(r'<s_(.*?)>', tokens, re.IGNORECASE)
+            if start_token is None:
+                break
+            key = start_token.group(1)
+            end_token = re.search(fr'</s_{key}>', tokens, re.IGNORECASE)
+            start_token = start_token.group()
+            if end_token is None:
+                tokens = tokens.replace(start_token, '')
+            else:
+                end_token = end_token.group()
+                start_token_escaped = re.escape(start_token)
+                end_token_escaped = re.escape(end_token)
+                content = re.search(
+                    f'{start_token_escaped}(.*?){end_token_escaped}', tokens,
+                    re.IGNORECASE)
+                if content is not None:
+                    content = content.group(1).strip()
+                    if r'<s_' in content and r'</s_' in content:
+                        # non-leaf node
+                        value = self.token2json(content, is_inner_value=True)
+                        if value:
+                            if len(value) == 1:
+                                value = value[0]
+                            output[key] = value
+                    else:  # leaf nodes
+                        output[key] = []
+                        for leaf in content.split(r'<sep/>'):
+                            leaf = leaf.strip()
+                            if (leaf in
+                                    self.decoder.tokenizer.get_added_vocab()
+                                    and leaf[0] == '<' and leaf[-2:] == '/>'):
+                                leaf = leaf[
+                                    1:-2]  # for categorical special tokens
+                            output[key].append(leaf)
+                        if len(output[key]) == 1:
+                            output[key] = output[key][0]
+
+                tokens = tokens[tokens.find(end_token) +
+                                len(end_token):].strip()
+                if tokens[:6] == r'<sep/>':  # non-leaf nodes
+                    return [output] + self.token2json(
+                        tokens[6:], is_inner_value=True)
+
+        if len(output):
+            return [output] if is_inner_value else output
+        else:
+            return [] if is_inner_value else {'text_sequence': tokens}
diff --git a/projects/Donut/donut/model/donut_preprocessor.py b/projects/Donut/donut/model/donut_preprocessor.py
new file mode 100644
index 000000000..96e5987a2
--- /dev/null
+++ b/projects/Donut/donut/model/donut_preprocessor.py
@@ -0,0 +1,64 @@
+from numbers import Number
+from typing import Dict, List, Optional, Sequence, Union
+
+import torch.nn as nn
+from mmengine.model import ImgDataPreprocessor
+
+from mmocr.registry import MODELS
+
+
+@MODELS.register_module()
+class DonutDataPreprocessor(ImgDataPreprocessor):
+
+    def __init__(self,
+                 mean: Sequence[Number] = None,
+                 std: Sequence[Number] = None,
+                 pad_size_divisor: int = 1,
+                 pad_value: Union[float, int] = 0,
+                 bgr_to_rgb: bool = False,
+                 rgb_to_bgr: bool = False,
+                 batch_augments: Optional[List[Dict]] = None) -> None:
+        super().__init__(
+            mean=mean,
+            std=std,
+            pad_size_divisor=pad_size_divisor,
+            pad_value=pad_value,
+            bgr_to_rgb=bgr_to_rgb,
+            rgb_to_bgr=rgb_to_bgr)
+        if batch_augments is not None:
+            self.batch_augments = nn.ModuleList(
+                [MODELS.build(aug) for aug in batch_augments])
+        else:
+            self.batch_augments = None
+
+    def forward(self, data: Dict, training: bool = False) -> Dict:
+        """Perform normalization, padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (dict): Data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            dict: Data in the same format as the model input.
+ """ + data = super().forward(data=data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] + + if data_samples is not None: + batch_input_shape = tuple(inputs[0].size()[-2:]) + for data_sample in data_samples: + + # valid_ratio = data_sample.valid_ratio * \ + # data_sample.img_shape[1] / batch_input_shape[1] + data_sample.set_metainfo( + dict( + # valid_ratio=valid_ratio, + batch_input_shape=batch_input_shape)) + + if training and self.batch_augments is not None: + for batch_aug in self.batch_augments: + inputs, data_samples = batch_aug(inputs, data_samples) + + # load decoder_input_ids, decoder_labels + return data diff --git a/projects/Donut/donut/model/swin_encoder.py b/projects/Donut/donut/model/swin_encoder.py new file mode 100644 index 000000000..c2c7f16dd --- /dev/null +++ b/projects/Donut/donut/model/swin_encoder.py @@ -0,0 +1,91 @@ +import math +from typing import List + +import timm +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel +from timm.models.swin_transformer import SwinTransformer + +from mmocr.registry import MODELS + + +@MODELS.register_module() +class SwinEncoder(BaseModel): + r""" + Donut encoder based on SwinTransformer + Set the initial weights and configuration with a pretrained + SwinTransformer and then modify the detailed configurations + as a Donut Encoder + + Args: + input_size: Input image size (width, height) + align_long_axis: Whether to rotate image if height is + greater than width + window_size: Window size(=patch size) of SwinTransformer + encoder_layer: Number of layers of SwinTransformer encoder + """ + + def __init__(self, + input_size: List[int], + align_long_axis: bool, + window_size: int, + encoder_layer: List[int], + init_cfg=dict()): + super().__init__(init_cfg=init_cfg) + self.input_size = input_size + self.align_long_axis = align_long_axis + self.window_size = window_size + self.encoder_layer = encoder_layer + + self.model = SwinTransformer( + img_size=self.input_size, + depths=self.encoder_layer, + window_size=self.window_size, + patch_size=4, + embed_dim=128, + num_heads=[4, 8, 16, 32], + num_classes=0, + ) + self.model.norm = None + + def init_weights(self): + super().init_weights() + # weight init with swin + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + swin_state_dict = timm.create_model( + 'swin_base_patch4_window12_384', pretrained=True).state_dict() + new_swin_state_dict = self.model.state_dict() + for x in new_swin_state_dict: + if x.endswith('relative_position_index') or x.endswith( + 'attn_mask'): + pass + elif (x.endswith('relative_position_bias_table') + and self.model.layers[0].blocks[0].attn.window_size[0] != + 12): + pos_bias = swin_state_dict[x].unsqueeze(0)[0] + old_len = int(math.sqrt(len(pos_bias))) + new_len = int(2 * self.window_size - 1) + pos_bias = pos_bias.reshape(1, old_len, old_len, + -1).permute(0, 3, 1, 2) + pos_bias = F.interpolate( + pos_bias, + size=(new_len, new_len), + mode='bicubic', + align_corners=False) + new_swin_state_dict[x] = pos_bias.permute( + 0, 2, 3, 1).reshape(1, new_len**2, -1).squeeze(0) + else: + new_swin_state_dict[x] = swin_state_dict[x] + self.model.load_state_dict(new_swin_state_dict) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch_size, num_channels, height, width) + """ + x = self.model.patch_embed(x) + x = self.model.pos_drop(x) + x = self.model.layers(x) + return x diff --git a/projects/Donut/requirements.txt b/projects/Donut/requirements.txt 
new file mode 100644 index 000000000..2206b8a34 --- /dev/null +++ b/projects/Donut/requirements.txt @@ -0,0 +1,3 @@ +nltk==3.8.1 +transformers==4.25.1 +zss==1.2.0 diff --git a/projects/Donut/tools/dataset_convert.py b/projects/Donut/tools/dataset_convert.py new file mode 100644 index 000000000..98ec68998 --- /dev/null +++ b/projects/Donut/tools/dataset_convert.py @@ -0,0 +1,59 @@ +import argparse +import json +import os +from io import BytesIO + +import tqdm +from datasets import load_dataset +from PIL import Image + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--data-dir', default='naver-clova-ix/cord-v2') + parser.add_argument('--save-dir', default='datasets/cord-v2') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + print(args) + + data_dir = args.data_dir + data_save_dir = args.save_dir + + dataset = load_dataset(data_dir, split=None) + + img_id = 0 + for split in dataset.keys(): + split_save_dir = os.path.join(data_save_dir, split) + split_image_dir = os.path.join(split_save_dir, 'images') + if not os.path.exists(split_save_dir): + os.makedirs(split_save_dir) + os.makedirs(split_image_dir, exist_ok=True) + split_meta_save_path = os.path.join(split_save_dir, 'metadata.jsonl') + + metadata = [] + for sample in tqdm.tqdm(dataset[split]): + image = sample['image'] + if isinstance(image, dict): + image = Image.open(BytesIO(image['bytes'])) + image.save(os.path.join(split_image_dir, f'{img_id}.jpg')) + image_name = f'images/{img_id}.jpg' + ground_truth = sample['ground_truth'] + metadata.append( + json.dumps( + { + 'file_name': image_name, + 'ground_truth': ground_truth + }, + ensure_ascii=False)) + img_id += 1 + + with open(split_meta_save_path, 'w') as f: + f.write('\n'.join(metadata)) + + +if __name__ == '__main__': + main() diff --git a/projects/Donut/tools/model_convert.py b/projects/Donut/tools/model_convert.py new file mode 100644 index 000000000..3037b5b22 --- /dev/null +++ b/projects/Donut/tools/model_convert.py @@ -0,0 +1,49 @@ +import argparse +import os +from collections import OrderedDict + +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('model') + parser.add_argument('--save-dir', default='data') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + print(args) + + # get model using: + # git clone --branch official \ + # https://huggingface.co/naver-clova-ix/donut-base + assert os.path.exists(args.model), args.model + assert args.model[-4:] == '.bin', 'the model name is pytorch_model.bin' + model_state_dict = torch.load(args.model) + + # extract weights + encoder_state_dict = OrderedDict() + decoder_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + if k.startswith('encoder.'): + new_k = k[len('encoder.'):] + encoder_state_dict[new_k] = v + elif k.startswith('decoder.'): + new_k = k[len('decoder.'):] + decoder_state_dict[new_k] = v + + # save weights + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + torch.save(encoder_state_dict, + os.path.join(args.save_dir, 'donut_base_encoder.pth')) + torch.save(decoder_state_dict, + os.path.join(args.save_dir, 'donut_base_decoder.pth')) + + +if __name__ == '__main__': + main()
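
For readers unfamiliar with Donut's serialization scheme, the standalone snippet below sketches the string format that `json2token` emits and `token2json` parses back. The helper names (`to_tokens`, `to_json`) and the sample annotation values are made up for illustration; this simplified sketch skips the special-token registration, nested structures, `<sep/>`-joined lists, and categorical `<value/>` tokens that the project methods additionally handle.

```python
# Simplified, standalone sketch of the Donut-style JSON <-> token round trip.
# It only illustrates the string format; it is NOT the project implementation.
import re
from typing import Any


def to_tokens(obj: Any) -> str:
    """Serialize a JSON-like object into <s_key>value</s_key> delimiters."""
    if isinstance(obj, dict):
        return ''.join(f'<s_{k}>{to_tokens(v)}</s_{k}>' for k, v in obj.items())
    if isinstance(obj, list):
        return '<sep/>'.join(to_tokens(item) for item in obj)
    return str(obj)


def to_json(tokens: str) -> dict:
    """Parse a flat (one-level) token sequence back into a dict."""
    return {
        m.group(1): m.group(2)
        for m in re.finditer(r'<s_(.*?)>(.*?)</s_\1>', tokens)
    }


if __name__ == '__main__':
    annotation = {'nm': 'ICED LATTE', 'cnt': '1', 'price': '4,500'}
    seq = to_tokens(annotation)
    # '<s_nm>ICED LATTE</s_nm><s_cnt>1</s_cnt><s_price>4,500</s_price>'
    print(seq)
    assert to_json(seq) == annotation
```

In the actual pipeline, the same kind of serialization is applied to the `ground_truth` entries written by `tools/dataset_convert.py`, and the decoder is trained to generate (and `token2json` to recover) the resulting sequences.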