
Commit 60a1f4f

Remove unnecessary image preprocessing steps from the VI-LayoutXLM configs to improve performance. VI-LayoutXLM drops LayoutXLM's visual backbone, so the network never consumes the decoded, resized, and normalized image tensor; only the page size is still needed, to normalize the bounding boxes.

1 parent 5120a2a

3 files changed: +19, -60 lines

configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml (+5, -29)

@@ -59,12 +59,9 @@ train:
     label_file: XFUND/zh_train/train.json
     sample_ratio: 1.0
     transform_pipeline:
-      - DecodeImage:
-          img_mode: RGB
-          to_float32: False
       - VQATokenLabelEncode:
           contains_re: True
-          algorithm: &algorithm LayoutXLM
+          algorithm: &algorithm VI-LayoutXLM
           class_path: *class_path
           order_method: tb-yx
       - VQATokenPad:
@@ -75,30 +72,21 @@ train:
           max_seq_len: *max_seq_len
       - TensorizeEntitiesRelations:
           max_relation_len: 5000
-      - LayoutResize:
-          size: [224, 224]
-      - NormalizeImage:
-          bgr_to_rgb: False
-          is_hwc: True
-          mean: imagenet
-          std: imagenet
-      - ToCHWImage:
     # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
     output_columns:
       [
         "input_ids",
         "bbox",
         "attention_mask",
         "token_type_ids",
-        "image",
         "question",
         "question_label",
         "answer",
         "answer_label",
         "relation_label",
       ]
-    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
-    label_column_index: [9] # input indices marked as label
+    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
+    label_column_index: [8] # input indices marked as label

   loader:
     shuffle: True
@@ -117,9 +105,6 @@ eval:
     sample_ratio: 1.0
     shuffle: False
     transform_pipeline:
-      - DecodeImage:
-          img_mode: RGB
-          to_float32: False
       - VQATokenLabelEncode:
           contains_re: True
           algorithm: *algorithm
@@ -133,30 +118,21 @@ eval:
           max_seq_len: *max_seq_len
       - TensorizeEntitiesRelations:
           max_relation_len: 5000
-      - LayoutResize:
-          size: [224, 224]
-      - NormalizeImage:
-          bgr_to_rgb: False
-          is_hwc: True
-          mean: imagenet
-          std: imagenet
-      - ToCHWImage:
     # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
     output_columns:
       [
         "input_ids",
         "bbox",
         "attention_mask",
         "token_type_ids",
-        "image",
         "question",
         "question_label",
         "answer",
         "answer_label",
         "relation_label",
       ]
-    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
-    label_column_index: [9] # input indices marked as label
+    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7] # input indices for network forward func in output_columns
+    label_column_index: [8] # input indices marked as label

   loader:
     shuffle: False
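With the image tensor gone from the pipeline, every column after "token_type_ids" shifts down by one, which is exactly why net_input_column_index drops its last entry and label_column_index moves from 9 to 8. A minimal sketch of how these indices partition a batch; the column names and indices come from the config above, while the selection code is an illustrative assumption about how an index-based trainer consumes them, not MindOCR's actual implementation:

```python
# Illustrative only: shows how net_input_column_index / label_column_index
# split one dataloader batch after "image" is removed. Names and indices are
# from the RE config above; the selection logic itself is an assumption.
output_columns = [
    "input_ids", "bbox", "attention_mask", "token_type_ids",
    "question", "question_label", "answer", "answer_label", "relation_label",
]
net_input_column_index = [0, 1, 2, 3, 4, 5, 6, 7]  # network forward inputs
label_column_index = [8]                           # loss labels

batch = [f"<tensor:{name}>" for name in output_columns]  # dummy batch columns
net_inputs = [batch[i] for i in net_input_column_index]
labels = [batch[i] for i in label_column_index]

print(labels)  # ['<tensor:relation_label>'] -- index 9 would now be out of range
```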

configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml (+7, -29)

@@ -57,31 +57,20 @@ train:
     label_file: XFUND/zh_train/train.json
     sample_ratio: 1.0
     transform_pipeline:
-      - DecodeImage:
-          img_mode: RGB
-          to_float32: False
       - VQATokenLabelEncode:
           contains_re: False
-          algorithm: &algorithm LayoutXLM
+          algorithm: &algorithm VI-LayoutXLM
           class_path: *class_path
           order_method: tb-yx
       - VQATokenPad:
           max_seq_len: &max_seq_len 512
           return_attention_mask: True
       - VQASerTokenChunk:
           max_seq_len: *max_seq_len
-      - LayoutResize:
-          size: [ 224, 224 ]
-      - NormalizeImage:
-          bgr_to_rgb: False
-          is_hwc: True
-          mean: imagenet
-          std: imagenet
-      - ToCHWImage:
     # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
-    output_columns: [ 'input_ids', 'bbox','attention_mask','token_type_ids', 'image', 'labels' ]
-    net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-    label_column_index: [ 2, 5 ] # input indices marked as label
+    output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
+    net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
+    label_column_index: [ 2, 4 ] # input indices marked as label

   loader:
     shuffle: True
@@ -100,9 +89,6 @@ eval:
     sample_ratio: 1.0
     shuffle: False
     transform_pipeline:
-      - DecodeImage:
-          img_mode: RGB
-          to_float32: False
       - VQATokenLabelEncode:
           contains_re: False
           algorithm: *algorithm
@@ -113,18 +99,10 @@ eval:
           return_attention_mask: True
       - VQASerTokenChunk:
           max_seq_len: *max_seq_len
-      - LayoutResize:
-          size: [ 224, 224 ]
-      - NormalizeImage:
-          bgr_to_rgb: False
-          is_hwc: True
-          mean: imagenet
-          std: imagenet
-      - ToCHWImage:
     # the order of the dataloader list, matching the network input and the labels for evaluation
-    output_columns: [ 'input_ids', 'bbox', 'attention_mask','token_type_ids','image', 'labels' ]
-    net_input_column_index: [ 0, 1, 2, 3, 4 ] # input indices for network forward func in output_columns
-    label_column_index: [ 2, 5 ] # input indices marked as label
+    output_columns: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'labels' ]
+    net_input_column_index: [ 0, 1, 2, 3 ] # input indices for network forward func in output_columns
+    label_column_index: [ 2, 4 ] # input indices marked as label

   loader:
     shuffle: False
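The same one-position shift applies here, with one subtlety: index 2 (attention_mask) appears in both net_input_column_index and label_column_index, presumably so the SER loss can mask out padded tokens when comparing predictions against labels. A hypothetical sanity check for the trimmed column layout (not a helper from the repo):

```python
# Hypothetical sanity check for the trimmed SER column layout (not repo code).
output_columns = ["input_ids", "bbox", "attention_mask", "token_type_ids", "labels"]
net_input_column_index = [0, 1, 2, 3]
label_column_index = [2, 4]  # attention_mask is shared with the network inputs

n = len(output_columns)
assert all(0 <= i < n for i in net_input_column_index + label_column_index)
assert "image" not in output_columns  # VI-LayoutXLM takes no image tensor

print([output_columns[i] for i in label_column_index])  # ['attention_mask', 'labels']
```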

mindocr/data/transforms/layoutlm_transforms.py (+7, -2)

@@ -4,6 +4,7 @@

 import cv2
 import numpy as np
+from PIL import Image

 from mindspore import nn

@@ -65,8 +66,10 @@ def __init__(
         super(VQATokenLabelEncode, self).__init__()
         tokenizer_dict = {
             "LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
+            "VI-LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
         }
         self.contains_re = contains_re
+        self.algorithm = algorithm
         tokenizer_config = tokenizer_dict[algorithm]
         self.tokenizer = tokenizer_config["class"].from_pretrained(tokenizer_config["pretrained_model"])  # to replace
         self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
@@ -141,8 +144,10 @@ def __call__(self, data):
         train_re = self.contains_re and not self.infer_mode
         if train_re:
             ocr_info = self.filter_empty_contents(ocr_info)
-
-        height, width, _ = data["image"].shape
+        if self.algorithm == "VI-LayoutXLM":
+            width, height = Image.open(data["img_path"]).size
+        elif self.algorithm == "LayoutXLM":
+            height, width, _ = data["image"].shape

         words_list = []
         bbox_list = []
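Note that the two branches report the page size in opposite orders: PIL's Image.size is (width, height), while a decoded HWC array's shape is (height, width, channels). Reading the size via PIL lets the VI-LayoutXLM path skip decoding the full image, since only the page dimensions are needed to normalize the bounding boxes. A standalone illustration of the two conventions, where an in-memory image stands in for Image.open on a real path:

```python
# Standalone illustration of the two size conventions handled above.
# An in-memory image stands in for Image.open(data["img_path"]).
import numpy as np
from PIL import Image

img = Image.new("RGB", (640, 480))
width, height = img.size   # PIL order: (width, height) -> (640, 480)

arr = np.asarray(img)      # HWC array, like a decoded data["image"]
h, w, _ = arr.shape        # numpy order: (height, width, channels) -> (480, 640)

assert (width, height) == (w, h)  # same page, opposite ordering
```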
