Support key parameters for OCR pipeline (#2810)
* Support key parameters for OCR pipeline

* update OCR.yaml
cuicheng01 authored Jan 10, 2025
1 parent 7d0d727 commit 9b6dded
Showing 3 changed files with 128 additions and 87 deletions.
8 changes: 5 additions & 3 deletions api_examples/pipelines/test_ocr.py
@@ -14,13 +14,15 @@

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="OCR")
pipeline = create_pipeline(pipeline="OCR", limit_side_len=320)

output = pipeline.predict(
"./test_samples/general_ocr_002.png",
use_doc_orientation_classify=True,
use_doc_unwarping=True,
use_textline_orientation=True,
use_doc_unwarping=False,
use_textline_orientation=False,
unclip_ratio=3.0,
limit_side_len=1920,
)
# output = pipeline.predict(
# "./test_samples/general_ocr_002.png",
11 changes: 9 additions & 2 deletions paddlex/configs/pipelines/OCR.yaml
@@ -3,8 +3,8 @@ pipeline_name: OCR

text_type: general

use_doc_preprocessor: True
use_textline_orientation: True
use_doc_preprocessor: False
use_textline_orientation: False

SubPipelines:
DocPreprocessor:
@@ -29,6 +29,13 @@ SubModules:
model_name: PP-OCRv4_mobile_det
model_dir: null
batch_size: 1
limit_side_len: 960
limit_type: max
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 2.0
use_dilation: False
TextLineOrientation:
module_name: textline_orientation
model_name: PP-LCNet_x0_25_textline_ori
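The keys added to the TextDetection sub-module above become the pipeline's built-in defaults, and the new constructor parameters introduced in pipeline.py (below) allow them to be overridden when the pipeline is instantiated. A minimal sketch, mirroring the updated api_examples test and assuming that create_pipeline forwards extra keyword arguments to the pipeline class; the thresh override is illustrative only and is not part of this commit's test:

from paddlex import create_pipeline

# Creation-time keyword arguments are merged into the TextDetection
# sub-module configuration, replacing the OCR.yaml defaults
# (here limit_side_len: 960 -> 320, thresh: 0.3 -> 0.2).
pipeline = create_pipeline(
    pipeline="OCR",
    limit_side_len=320,
    thresh=0.2,  # illustrative override; the commit's own test only changes limit_side_len here
)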
196 changes: 114 additions & 82 deletions paddlex/inference/pipelines_new/ocr/pipeline.py
@@ -33,8 +33,18 @@ class OCRPipeline(BasePipeline):
def __init__(
self,
config: Dict,
device: str = None,
pp_option: PaddlePredictorOption = None,
device: Optional[str] = None,
use_doc_orientation_classify: Optional[bool] = None,
use_doc_unwarping: Optional[bool] = None,
use_textline_orientation: Optional[bool] = None,
limit_side_len: Optional[int] = None,
limit_type: Optional[str] = None,
thresh: Optional[float] = None,
box_thresh: Optional[float] = None,
max_candidates: Optional[int] = None,
unclip_ratio: Optional[float] = None,
use_dilation: Optional[bool] = None,
pp_option: Optional[PaddlePredictorOption] = None,
use_hpip: bool = False,
hpi_params: Optional[Dict[str, Any]] = None,
) -> None:
@@ -43,16 +43,65 @@ def __init__(
Args:
config (Dict): Configuration dictionary containing model and other parameters.
device (str): The device to run the prediction on. Default is None.
pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
device (Union[str, None]): The device to run the prediction on.
use_textline_orientation (Union[bool, None]): Whether to use textline orientation.
use_doc_orientation_classify (Union[bool, None]): Whether to use document orientation classification.
use_doc_unwarping (Union[bool, None]): Whether to use document unwarping.
limit_side_len (Union[int, None]): Limit of side length.
limit_type (Union[str, None]): Type of limit.
thresh (Union[float, None]): Threshold value.
box_thresh (Union[float, None]): Box threshold value.
max_candidates (Union[int, None]): Maximum number of candidates.
unclip_ratio (Union[float, None]): Unclip ratio.
use_dilation (Union[bool, None]): Whether to use dilation.
pp_option (Union[PaddlePredictorOption, None]): Options for PaddlePaddle predictor.
use_hpip (Union[bool, None]): Whether to use high-performance inference.
hpi_params (Union[Dict[str, Any], None]): HPIP specific parameters.
"""
super().__init__(
device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
)

self.inintial_predictor(config)
self.use_textline_orientation = (
use_textline_orientation
if use_textline_orientation is not None
else config.get("use_textline_orientation", False)
)
self.use_doc_preprocessor = self.get_preprocessor_value(
use_doc_orientation_classify, use_doc_unwarping, config, False
)

text_det_default_params = {
"limit_side_len": 960,
"limit_type": "max",
"thresh": 0.3,
"box_thresh": 0.6,
"max_candidates": 1000,
"unclip_ratio": 2.0,
"use_dilation": False,
}

text_det_config = config["SubModules"]["TextDetection"]
for key, default_params in text_det_default_params.items():
text_det_config[key] = locals().get(
key, text_det_config.get(key, default_params)
)
self.text_det_model = self.create_model(text_det_config)

text_rec_config = config["SubModules"]["TextRecognition"]
self.text_rec_model = self.create_model(text_rec_config)

if self.use_textline_orientation:
textline_orientation_config = config["SubModules"]["TextLineOrientation"]
self.textline_orientation_model = self.create_model(
textline_orientation_config
)

if self.use_doc_preprocessor:
doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
self.doc_preprocessor_pipeline = self.create_pipeline(
doc_preprocessor_config
)

self.text_type = config["text_type"]

@@ -68,60 +127,15 @@ def __init__(
self.batch_sampler = ImageBatchSampler(batch_size=1)
self.img_reader = ReadImage(format="BGR")

def set_used_models_flag(self, config: Dict) -> None:
"""
Set the flags for which models to use based on the configuration.
Args:
config (Dict): A dictionary containing configuration settings.
Returns:
None
"""
pipeline_name = config["pipeline_name"]

self.pipeline_name = pipeline_name

self.use_doc_preprocessor = False

if "use_doc_preprocessor" in config:
self.use_doc_preprocessor = config["use_doc_preprocessor"]

self.use_textline_orientation = False

if "use_textline_orientation" in config:
self.use_textline_orientation = config["use_textline_orientation"]

def inintial_predictor(self, config: Dict) -> None:
"""Initializes the predictor based on the provided configuration.
Args:
config (Dict): A dictionary containing the configuration for the predictor.
Returns:
None
"""

self.set_used_models_flag(config)

text_det_model_config = config["SubModules"]["TextDetection"]
self.text_det_model = self.create_model(text_det_model_config)

text_rec_model_config = config["SubModules"]["TextRecognition"]
self.text_rec_model = self.create_model(text_rec_model_config)

if self.use_doc_preprocessor:
doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
self.doc_preprocessor_pipeline = self.create_pipeline(
doc_preprocessor_config
)
# Just for initialize the predictor
if self.use_textline_orientation:
textline_orientation_config = config["SubModules"]["TextLineOrientation"]
self.textline_orientation_model = self.create_model(
textline_orientation_config
)
return
@staticmethod
def get_preprocessor_value(orientation, unwarping, config, default):
if orientation is None and unwarping is None:
return config.get("use_doc_preprocessor", default)
else:
if orientation is False and unwarping is False:
return False
else:
return True

def rotate_image(
self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
@@ -159,25 +173,25 @@ def rotate_image(

return rotated_images

def check_input_params_valid(self, input_params: Dict) -> bool:
def check_model_settings_valid(self, model_settings: Dict) -> bool:
"""
Check if the input parameters are valid based on the initialized models.
Args:
input_params (Dict): A dictionary containing input parameters.
model_info_params(Dict): A dictionary containing input parameters.
Returns:
bool: True if all required models are initialized according to input parameters, False otherwise.
"""

if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
logging.error(
"Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
)
return False

if (
input_params["use_textline_orientation"]
model_settings["use_textline_orientation"]
and not self.use_textline_orientation
):
logging.error(
@@ -211,18 +225,23 @@ def predict_doc_preprocessor_res(
use_doc_unwarping=use_doc_unwarping,
)
)
doc_preprocessor_image = doc_preprocessor_res["output_img"]
else:
doc_preprocessor_res = {}
doc_preprocessor_image = image_array
return doc_preprocessor_res, doc_preprocessor_image
doc_preprocessor_res = {"output_img": image_array}
return doc_preprocessor_res

def predict(
self,
input: str | list[str] | np.ndarray | list[np.ndarray],
use_doc_orientation_classify: bool = False,
use_doc_unwarping: bool = False,
use_textline_orientation: bool = False,
limit_side_len: int = 960,
limit_type: str = "max",
thresh: float = 0.3,
box_thresh: float = 0.6,
max_candidates: int = 1000,
unclip_ratio: float = 2.0,
use_dilation: bool = False,
**kwargs,
) -> OCRResult:
"""Predicts OCR results for the given input.
Expand All @@ -235,44 +254,56 @@ def predict(
OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
"""

input_params = {
"use_doc_preprocessor": self.use_doc_preprocessor,
model_settings = {
"use_doc_orientation_classify": use_doc_orientation_classify,
"use_doc_unwarping": use_doc_unwarping,
"use_textline_orientation": self.use_textline_orientation,
"use_textline_orientation": use_textline_orientation,
}
if use_doc_orientation_classify or use_doc_unwarping:
input_params["use_doc_preprocessor"] = True
model_settings["use_doc_preprocessor"] = True
else:
input_params["use_doc_preprocessor"] = False
model_settings["use_doc_preprocessor"] = False

if not self.check_input_params_valid(input_params):
if not self.check_model_settings_valid(model_settings):
yield None

text_det_params = {
"limit_side_len": limit_side_len,
"limit_type": limit_type,
"thresh": thresh,
"box_thresh": box_thresh,
"max_candidates": max_candidates,
"unclip_ratio": unclip_ratio,
"use_dilation": use_dilation,
}

for img_id, batch_data in enumerate(self.batch_sampler(input)):
image_array = self.img_reader(batch_data)[0]
img_id += 1

doc_preprocessor_res, doc_preprocessor_image = (
self.predict_doc_preprocessor_res(image_array, input_params)
doc_preprocessor_res = self.predict_doc_preprocessor_res(
image_array, model_settings
)
doc_preprocessor_image = doc_preprocessor_res["output_img"]

det_res = next(self.text_det_model(doc_preprocessor_image))
det_res = next(
self.text_det_model(doc_preprocessor_image, **text_det_params)
)

dt_polys = det_res["dt_polys"]
dt_scores = det_res["dt_scores"]

########## [TODO] Need to confirm filtering thresholds for detection and recognition modules

dt_polys = self._sort_boxes(dt_polys)

single_img_res = {
"input_path": input,
# TODO: `doc_preprocessor_image` parameter does not need to be retained here, it requires further confirmation.
"doc_preprocessor_image": doc_preprocessor_image,
"doc_preprocessor_res": doc_preprocessor_res,
"dt_polys": dt_polys,
"img_id": img_id,
"input_params": input_params,
"input_params": model_settings,
"text_det_params": text_det_params,
"text_type": self.text_type,
}

Expand All @@ -283,13 +314,14 @@ def predict(
self._crop_by_polys(doc_preprocessor_image, dt_polys)
)
# use textline orientation model
if input_params["use_textline_orientation"]:
if model_settings["use_textline_orientation"]:
angles = [
textline_angle_info["class_ids"][0]
for textline_angle_info in self.textline_orientation_model(
all_subs_of_img
)
]
single_img_res["textline_orientation_angle"] = angles
all_subs_of_img = self.rotate_image(all_subs_of_img, angles)

for rec_res in self.text_rec_model(all_subs_of_img):
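At prediction time the same detection parameters can be supplied per call: predict() gathers them into text_det_params and forwards them to the text detection model, alongside the use_doc_orientation_classify, use_doc_unwarping, and use_textline_orientation switches. A minimal sketch based on the updated api_examples test; the result-handling loop is not part of this commit and assumes the usual PaddleX result interface:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="OCR")

output = pipeline.predict(
    "./test_samples/general_ocr_002.png",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_textline_orientation=False,
    unclip_ratio=3.0,      # per-call post-processing override of the YAML default (2.0)
    limit_side_len=1920,   # per-call resize-limit override of the YAML default (960)
)

for res in output:
    res.print()  # assumed: the standard PaddleX result object interface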
