diff --git a/api_examples/pipelines/test_formula_recognition.py b/api_examples/pipelines/test_formula_recognition.py new file mode 100644 index 000000000..8073f3f8b --- /dev/null +++ b/api_examples/pipelines/test_formula_recognition.py @@ -0,0 +1,43 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlex import create_pipeline + +pipeline = create_pipeline(pipeline="formula_recognition") + +output = pipeline.predict( + "./test_samples/general_formula_recognition01.png", use_layout_detection=True +) + +# output = pipeline.predict( +# "./test_samples/general_formula_recognition01.pdf", +# use_layout_detection=True, +# ) + +# output = pipeline.predict( +# "./test_samples/general_formula_recognition02.png", +# use_layout_detection=False, +# ) + +# img_list = [ "./test_samples/general_formula_recognition03.png", \ +# "./test_samples/general_formula_recognition04.png", \ +# "./test_samples/general_formula_recognition05.png",] +# output = pipeline.predict( +# img_list, +# use_layout_detection=True, +# ) + +for res in output: + # res.save_to_img("./output/") + res.save_results("./output") diff --git a/api_examples/pipelines/test_image_classification.py b/api_examples/pipelines/test_image_classification.py index 771c7c7a2..740e67c2d 100644 --- a/api_examples/pipelines/test_image_classification.py +++ b/api_examples/pipelines/test_image_classification.py @@ -16,7 +16,7 @@ pipeline = create_pipeline(pipeline="image_classification") -output = pipeline.predict("./test_samples/general_image_classification_001.jpg") +output = pipeline.predict("./test_samples/general_image_classification_001.jpg", topk=5) # output = pipeline.predict("./test_samples/财报1.pdf") diff --git a/api_examples/pipelines/test_ocr.py b/api_examples/pipelines/test_ocr.py index 41775f723..ae4505a39 100644 --- a/api_examples/pipelines/test_ocr.py +++ b/api_examples/pipelines/test_ocr.py @@ -14,13 +14,15 @@ from paddlex import create_pipeline -pipeline = create_pipeline(pipeline="OCR") +pipeline = create_pipeline(pipeline="OCR", limit_side_len=320) output = pipeline.predict( "./test_samples/general_ocr_002.png", use_doc_orientation_classify=True, - use_doc_unwarping=True, - use_textline_orientation=True, + use_doc_unwarping=False, + use_textline_orientation=False, + unclip_ratio=3.0, + limit_side_len=1920, ) # output = pipeline.predict( # "./test_samples/general_ocr_002.png", diff --git a/api_examples/pipelines/test_pedestrian_attribute_rec.py b/api_examples/pipelines/test_pedestrian_attribute_rec.py new file mode 100644 index 000000000..ad9b4b9cf --- /dev/null +++ b/api_examples/pipelines/test_pedestrian_attribute_rec.py @@ -0,0 +1,26 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="pedestrian_attribute_recognition")
+
+output = pipeline.predict(
+    "./test_samples/pedestrian_attribute_002.jpg", det_threshold=0.7, cls_threshold=0.7
+)
+
+for res in output:
+    res.print()  ## print the structured prediction output
+    res.save_to_img("./output")  ## save the visualized result image
+    res.save_to_json("./output/")  ## save the structured prediction output
diff --git a/api_examples/pipelines/test_vehicle_attribute_rec.py b/api_examples/pipelines/test_vehicle_attribute_rec.py
new file mode 100644
index 000000000..7a6167e81
--- /dev/null
+++ b/api_examples/pipelines/test_vehicle_attribute_rec.py
@@ -0,0 +1,26 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="vehicle_attribute_recognition")
+
+output = pipeline.predict(
+    "./test_samples/vehicle_attribute_002.jpg", det_threshold=0.7, cls_threshold=0.7
+)
+
+for res in output:
+    res.print()  ## print the structured prediction output
+    res.save_to_img("./output")  ## save the visualized result image
+    res.save_to_json("./output/")  ## save the structured prediction output
diff --git a/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md b/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md
index d47a9146b..da370d60b 100644
--- a/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md
+++ b/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md
@@ -48,7 +48,7 @@ Pedestrian attribute recognition is a key function in computer vision systems, u
 ModelModel Download Link
-mA (%)
+mAP (%)
 GPU Inference Time (ms)
 CPU Inference Time (ms)
 Model Size (M)
diff --git a/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md b/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md
index 3f0aada48..55cada58b 100644
--- a/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md
+++ b/docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md
@@ -48,7 +48,7 @@ comments: true
 模型模型下载链接
-mA(%)
+mAP(%)
 GPU推理耗时(ms)
 CPU推理耗时 (ms)
 模型存储大小(M)
diff --git a/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.en.md b/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.en.md
index de48d9bc4..103d436bb 100644
--- a/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.en.md
+++ b/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.en.md
@@ -44,7 +44,7 @@ Vehicle attribute recognition is a crucial component in computer
vision systems. ModelModel Download Link -mA (%) +mAP (%) GPU Inference Time (ms) CPU Inference Time (ms) Model Size (M) diff --git a/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.md b/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.md index 9df99ffef..795ec5957 100644 --- a/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.md +++ b/docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.md @@ -45,7 +45,7 @@ comments: true 模型模型下载链接 -mA(%) +mAP(%) GPU推理耗时(ms) CPU推理耗时 (ms) 模型存储大小(M) diff --git a/paddlex/configs/pipelines/OCR.yaml b/paddlex/configs/pipelines/OCR.yaml index e1f60bb48..e97ebb143 100644 --- a/paddlex/configs/pipelines/OCR.yaml +++ b/paddlex/configs/pipelines/OCR.yaml @@ -3,8 +3,8 @@ pipeline_name: OCR text_type: general -use_doc_preprocessor: True -use_textline_orientation: True +use_doc_preprocessor: False +use_textline_orientation: False SubPipelines: DocPreprocessor: @@ -29,6 +29,13 @@ SubModules: model_name: PP-OCRv4_mobile_det model_dir: null batch_size: 1 + limit_side_len: 960 + limit_type: max + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 2.0 + use_dilation: False TextLineOrientation: module_name: textline_orientation model_name: PP-LCNet_x0_25_textline_ori diff --git a/paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml b/paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml index fa8f1b564..4a720f527 100644 --- a/paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +++ b/paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml @@ -18,6 +18,13 @@ SubModules: ak: "api_key" # Set this to a real API key sk: "secret_key" # Set this to a real secret key + MLLM_Chat: + module_name: chat_bot + model_name: PP-DocBee + base_url: "http://127.0.0.1/v1/chat/completions" + api_type: openai + api_key: "api_key" + PromptEngneering: KIE_CommonText: module_name: prompt_engneering diff --git a/paddlex/configs/pipelines/formula_recognition.yaml b/paddlex/configs/pipelines/formula_recognition.yaml new file mode 100644 index 000000000..4cb859652 --- /dev/null +++ b/paddlex/configs/pipelines/formula_recognition.yaml @@ -0,0 +1,35 @@ + +pipeline_name: formula_recognition + +use_layout_detection: True +use_doc_preprocessor: True + +SubModules: + LayoutDetection: + module_name: layout_detection + model_name: RT-DETR-H_layout_17cls + model_dir: null + batch_size: 1 + + FormulaRecognition: + module_name: formula_recognition + model_name: PP-FormulaNet-L + model_dir: null + batch_size: 5 + +SubPipelines: + DocPreprocessor: + pipeline_name: doc_preprocessor + use_doc_orientation_classify: True + use_doc_unwarping: True + SubModules: + DocOrientationClassify: + module_name: doc_text_orientation + model_name: PP-LCNet_x1_0_doc_ori + model_dir: null + batch_size: 1 + DocUnwarping: + module_name: image_unwarping + model_name: UVDoc + model_dir: null + batch_size: 1 diff --git a/paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml b/paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml new file mode 100644 index 000000000..8e205942e --- /dev/null +++ b/paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml @@ -0,0 +1,15 @@ +pipeline_name: pedestrian_attribute_recognition + +SubModules: + Detection: + module_name: object_detection + model_name: PP-YOLOE-L_human + model_dir: null + batch_size: 1 + threshold: 0.5 + Classification: + module_name: multilabel_classification + model_name: PP-LCNet_x1_0_pedestrian_attribute + model_dir: null + batch_size: 1 + threshold: 0.5 diff --git 
a/paddlex/configs/pipelines/vehicle_attribute_recognition.yaml b/paddlex/configs/pipelines/vehicle_attribute_recognition.yaml
new file mode 100644
index 000000000..fbe19d1dc
--- /dev/null
+++ b/paddlex/configs/pipelines/vehicle_attribute_recognition.yaml
@@ -0,0 +1,15 @@
+pipeline_name: vehicle_attribute_recognition
+
+SubModules:
+  Detection:
+    module_name: object_detection
+    model_name: PP-YOLOE-L_vehicle
+    model_dir: null
+    batch_size: 1
+    threshold: 0.5
+  Classification:
+    module_name: multilabel_classification
+    model_name: PP-LCNet_x1_0_vehicle_attribute
+    model_dir: null
+    batch_size: 1
+    threshold: 0.5
diff --git a/paddlex/inference/models_new/formula_recognition/result.py b/paddlex/inference/models_new/formula_recognition/result.py
index f40a3280b..16edaf7f2 100644
--- a/paddlex/inference/models_new/formula_recognition/result.py
+++ b/paddlex/inference/models_new/formula_recognition/result.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import os
 import sys
+from typing import Any, Dict, Optional, List
 import cv2
 import PIL
+import fitz
 import math
 import random
 import tempfile
@@ -27,6 +28,7 @@
 from ...common.result import BaseCVResult
 from ....utils import logging
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ....utils.file_interface import custom_open
 
 
 class FormulaRecResult(BaseCVResult):
@@ -35,8 +37,18 @@ def _to_str(self, *args, **kwargs):
     def _to_img(
         self,
-    ):
-        """Draw formula on image"""
+    ) -> Dict[str, Image.Image]:
+        """
+        Draws a recognized formula on an image.
+
+        This method renders the recognized LaTeX formula and overlays it
+        alongside the input image. If the LaTeX rendering engine is not
+        installed or a syntax error is detected, it logs a warning and
+        falls back to the original image.
+
+        Returns:
+            Dict[str, Image.Image]: A dict whose "res" key holds the image with
+            the rendered formula placed alongside the original image.
+        """
         image = Image.fromarray(self["input_img"])
         try:
             env_valid()
@@ -77,7 +89,19 @@ def _to_img(
         return {"res": image}
 
 
-def get_align_equation(equation):
+def get_align_equation(equation: str) -> str:
+    """
+    Wraps an equation in LaTeX environment tags if not already aligned.
+
+    This function checks if a given LaTeX equation contains any alignment tags (`align` or `align*`).
+    If the equation does not contain these tags, it wraps the equation in `equation` and `nonumber` tags.
+
+    Args:
+        equation (str): The LaTeX equation to be checked and potentially modified.
+
+    Returns:
+        str: The modified equation with appropriate LaTeX tags for alignment.
+    """
     is_align = False
     equation = str(equation) + "\n"
     begin_dict = [
@@ -101,8 +125,19 @@ def get_align_equation(equation):
     return equation
 
 
-def generate_tex_file(tex_file_path, equation):
-    with open(tex_file_path, "w") as fp:
+def generate_tex_file(tex_file_path: str, equation: str) -> None:
+    """
+    Generates a LaTeX file containing a specific equation.
+
+    This function creates a LaTeX file at the specified file path, writing the necessary
+    LaTeX preamble and wrapping the provided equation in a document structure. The equation
+    is processed to ensure it includes alignment tags if necessary.
+
+    Args:
+        tex_file_path (str): The file path where the LaTeX file will be saved.
+        equation (str): The LaTeX equation to be written into the file.
+    """
+    with custom_open(tex_file_path, "w") as fp:
         start_template = (
             r"\documentclass{article}" + "\n"
             r"\usepackage{cite}" + "\n"
@@ -121,7 +156,24 @@
         fp.write(end_template)
 
 
-def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
+def generate_pdf_file(tex_path: str, pdf_dir: str, is_debug: bool = False) -> None:
+    """
+    Generates a PDF file from a LaTeX file using pdflatex.
+
+    This function checks if the specified LaTeX file exists, and then runs pdflatex to
+    generate a PDF file in the specified directory. It can run in debug mode to show
+    detailed output or in silent mode.
+
+    Args:
+        tex_path (str): The path to the LaTeX file.
+        pdf_dir (str): The directory where the PDF file will be saved.
+        is_debug (bool, optional): If True, runs pdflatex with detailed output. Defaults to False.
+
+    Returns:
+        None: The function silently does nothing if the LaTeX file does not exist;
+        pdflatex failures surface as subprocess.CalledProcessError.
+    """
     if os.path.exists(tex_path):
         command = "pdflatex -halt-on-error -output-directory={} {}".format(
             pdf_dir, tex_path
         )
         if is_debug:
             subprocess.check_call(command, shell=True)
         else:
-            devNull = open(os.devnull, "w")
             subprocess.check_call(
                 command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
             )
 
 
-def crop_white_area(image):
+def crop_white_area(image: np.ndarray) -> Optional[List[int]]:
+    """
+    Finds and returns the bounding box of the non-white area in an image.
+
+    This function converts an image to grayscale and uses binary thresholding to
+    find contours. It then calculates the bounding rectangle around the non-white
+    areas of the image.
+
+    Args:
+        image (np.ndarray): The input image as a NumPy array.
+
+    Returns:
+        Optional[List[int]]: A list [x, y, w, h] representing the bounding box of
+        the non-white area, or None if no such area is found.
+    """
     image = np.array(image).astype("uint8")
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
     _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
@@ -147,8 +213,18 @@ def crop_white_area(image):
     return None
 
 
-def pdf2img(pdf_path, img_path, is_padding=False):
-    import fitz
+def pdf2img(pdf_path: str, img_path: str, is_padding: bool = False) -> Optional[np.ndarray]:
+    """
+    Converts a single-page PDF to an image, optionally cropping white areas and adding padding.
+
+    Args:
+        pdf_path (str): The path to the PDF file.
+        img_path (str): The path where the image will be saved.
+        is_padding (bool): If True, adds a 30-pixel white padding around the image.
+
+    Returns:
+        Optional[np.ndarray]: The resulting image as a NumPy array, or None if the PDF is not single-page.
+ """ pdfDoc = fitz.open(pdf_path) if pdfDoc.page_count != 1: @@ -160,11 +236,10 @@ def pdf2img(pdf_path, img_path, is_padding=False): zoom_y = 2 mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate) pix = page.get_pixmap(matrix=mat, alpha=False) - if not os.path.exists(img_path): - os.makedirs(img_path) - - pix._writeIMG(img_path, 7, 100) - img = cv2.imread(img_path) + getpngdata = pix.tobytes(output="png") + # decode as np.uint8 + image_array = np.frombuffer(getpngdata, dtype=np.uint8) + img = cv2.imdecode(image_array, cv2.IMREAD_ANYCOLOR) xywh = crop_white_area(img) if xywh is not None: @@ -178,8 +253,21 @@ def pdf2img(pdf_path, img_path, is_padding=False): return None -def draw_formula_module(img_size, box, formula, is_debug=False): - """draw box formula for module""" +def draw_formula_module( + img_size: tuple, box: list, formula: str, is_debug: bool = False +): + """ + Draw box formula for module. + + Args: + img_size (tuple): The size of the image as (width, height). + box (list): The coordinates for the bounding box. + formula (str): The LaTeX formula to render. + is_debug (bool): If True, retains intermediate files for debugging purposes. + + Returns: + np.ndarray: The resulting image with the formula or an error message. + """ box_width, box_height = img_size with tempfile.TemporaryDirectory() as td: tex_file_path = os.path.join(td, "temp.tex") @@ -200,7 +288,13 @@ def draw_formula_module(img_size, box, formula, is_debug=False): return img_right_text -def env_valid(): +def env_valid() -> bool: + """ + Validates if the environment is correctly set up to convert LaTeX formulas to images. + + Returns: + bool: True if the environment is valid and the conversion is successful, False otherwise. + """ with tempfile.TemporaryDirectory() as td: tex_file_path = os.path.join(td, "temp.tex") pdf_file_path = os.path.join(td, "temp.pdf") @@ -214,55 +308,19 @@ def env_valid(): formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False) -def draw_box_formula_fine(img_size, box, formula, is_debug=False): - """draw box formula for pipeline""" - box_height = int( - math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) - ) - box_width = int( - math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) - ) - with tempfile.TemporaryDirectory() as td: - tex_file_path = os.path.join(td, "temp.tex") - pdf_file_path = os.path.join(td, "temp.pdf") - img_file_path = os.path.join(td, "temp.jpg") - generate_tex_file(tex_file_path, formula) - if os.path.exists(tex_file_path): - generate_pdf_file(tex_file_path, td, is_debug) - formula_img = None - if os.path.exists(pdf_file_path): - formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False) - if formula_img is not None: - formula_h, formula_w = formula_img.shape[:-1] - resize_height = box_height - resize_width = formula_w * resize_height / formula_h - formula_img = cv2.resize( - formula_img, (int(resize_width), int(resize_height)) - ) - formula_h, formula_w = formula_img.shape[:-1] - pts1 = np.float32( - [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]] - ) - pts2 = np.array(box, dtype=np.float32) - M = cv2.getPerspectiveTransform(pts1, pts2) - formula_img = np.array(formula_img, dtype=np.uint8) - img_right_text = cv2.warpPerspective( - formula_img, - M, - img_size, - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, - borderValue=(255, 255, 255), - ) - else: - img_right_text = draw_box_txt_fine( - img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH - ) - return img_right_text +def 
draw_box_txt_fine(img_size: tuple, box: list, txt: str, font_path: str): + """ + Draw box text. + Args: + img_size (tuple): Size of the image as (width, height). + box (list): List of four points defining the box, each point is a tuple (x, y). + txt (str): The text to draw inside the box. + font_path (str): Path to the font file to be used for drawing text. -def draw_box_txt_fine(img_size, box, txt, font_path): - """draw box text""" + Returns: + np.ndarray: Image array with the text drawn and transformed to fit the box. + """ box_height = int( math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) ) @@ -302,8 +360,18 @@ def draw_box_txt_fine(img_size, box, txt, font_path): return img_right_text -def create_font(txt, sz, font_path): - """create font""" +def create_font(txt: str, sz: tuple, font_path: str) -> ImageFont.FreeTypeFont: + """ + Creates a font object with a size that ensures the text fits within the specified dimensions. + + Args: + txt (str): The text to fit. + sz (tuple): The target size as (width, height). + font_path (str): The path to the font file. + + Returns: + ImageFont.FreeTypeFont: A PIL font object at the appropriate size. + """ font_size = int(sz[1] * 0.8) font = ImageFont.truetype(font_path, font_size, encoding="utf-8") if int(PIL.__version__.split(".")[0]) < 10: diff --git a/paddlex/inference/models_new/image_classification/predictor.py b/paddlex/inference/models_new/image_classification/predictor.py index a01f9daf2..fec69bae0 100644 --- a/paddlex/inference/models_new/image_classification/predictor.py +++ b/paddlex/inference/models_new/image_classification/predictor.py @@ -114,7 +114,8 @@ def process( """ batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data) batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs) - batch_imgs = self.preprocessors["Crop"](imgs=batch_imgs) + if "Crop" in self.preprocessors: + batch_imgs = self.preprocessors["Crop"](imgs=batch_imgs) batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs) batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs) x = self.preprocessors["ToBatch"](imgs=batch_imgs) diff --git a/paddlex/inference/models_new/table_structure_recognition/processors.py b/paddlex/inference/models_new/table_structure_recognition/processors.py index b214dd77a..63fc12398 100644 --- a/paddlex/inference/models_new/table_structure_recognition/processors.py +++ b/paddlex/inference/models_new/table_structure_recognition/processors.py @@ -22,11 +22,6 @@ class Pad: """Pad the image.""" - INPUT_KEYS = "img" - OUTPUT_KEYS = ["img", "img_size"] - DEAULT_INPUTS = {"img": "img"} - DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"} - def __init__(self, target_size, val=127.5): """ Initialize the instance. 
diff --git a/paddlex/inference/models_new/table_structure_recognition/result.py b/paddlex/inference/models_new/table_structure_recognition/result.py index f1aa1300e..3e7c577e3 100644 --- a/paddlex/inference/models_new/table_structure_recognition/result.py +++ b/paddlex/inference/models_new/table_structure_recognition/result.py @@ -51,58 +51,3 @@ def draw_bbox(self, image, boxes): box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64) image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) return image - - -class StructureTableResult(TableRecResult, HtmlMixin, XlsxMixin): - """StructureTableResult""" - - def __init__(self, data): - super().__init__(data) - HtmlMixin.__init__(self) - XlsxMixin.__init__(self) - - def _to_html(self): - return self["html"] - - -class TableResult(BaseCVResult, HtmlMixin, XlsxMixin): - """TableResult""" - - def __init__(self, data): - super().__init__(data) - HtmlMixin.__init__(self) - XlsxMixin.__init__(self) - - def save_to_html(self, save_path): - if not save_path.lower().endswith(("html")): - input_path = self["input_path"] - save_path = Path(save_path) / f"{Path(input_path).stem}" - else: - save_path = Path(save_path).stem - for table_result in self["table_result"]: - table_result.save_to_html(save_path) - - def save_to_xlsx(self, save_path): - if not save_path.lower().endswith(("xlsx")): - input_path = self["input_path"] - save_path = Path(save_path) / f"{Path(input_path).stem}" - else: - save_path = Path(save_path).stem - for table_result in self["table_result"]: - table_result.save_to_xlsx(save_path) - - def save_to_img(self, save_path): - if not save_path.lower().endswith((".jpg", ".png")): - input_path = self["input_path"] - save_path = Path(save_path) / f"{Path(input_path).stem}" - else: - save_path = Path(save_path).stem - layout_save_path = f"{save_path}_layout.jpg" - ocr_save_path = f"{save_path}_ocr.jpg" - table_save_path = f"{save_path}_table" - layout_result = self["layout_result"] - layout_result.save_to_img(layout_save_path) - ocr_result = self["ocr_result"] - ocr_result.save_to_img(ocr_save_path) - for idx, table_result in enumerate(self["table_result"]): - table_result.save_to_img(f"{table_save_path}_{idx}.jpg") diff --git a/paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py b/paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py index 0365f5cfd..f62570c11 100644 --- a/paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py +++ b/paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py @@ -547,7 +547,7 @@ def chat( logging.debug(prompt) res = self.get_llm_result(llm_api, prompt) # TODO: why use one html but the whole table_text in next step - if list(res.values())[0] in failed_results: + if not res or list(res.values())[0] in failed_results: logging.debug( "table html sequence is too much longer, using ocr directly!" 
) diff --git a/paddlex/inference/pipelines_new/__init__.py b/paddlex/inference/pipelines_new/__init__.py index 7702aa17f..1170d4620 100644 --- a/paddlex/inference/pipelines_new/__init__.py +++ b/paddlex/inference/pipelines_new/__init__.py @@ -25,11 +25,16 @@ from .image_classification import ImageClassificationPipeline from .seal_recognition import SealRecognitionPipeline from .table_recognition import TableRecognitionPipeline +from .formula_recognition import FormulaRecognitionPipeline from .video_classification import VideoClassificationPipeline from .anomaly_detection import AnomalyDetectionPipeline from .ts_forecasting import TSFcPipeline from .ts_anomaly_detection import TSAnomalyDetPipeline from .ts_classification import TSClsPipeline +from .attribute_recognition import ( + PedestrianAttributeRecPipeline, + VehicleAttributeRecPipeline, +) def get_pipeline_path(pipeline_name: str) -> str: @@ -135,8 +140,8 @@ def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat: Returns: BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config. """ - model_name = config["model_name"] - chat_bot = BaseChat.get(model_name)(config) + api_type = config["api_type"] + chat_bot = BaseChat.get(api_type)(config) return chat_bot @@ -156,8 +161,8 @@ def create_retriever( Returns: BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config. """ - model_name = config["model_name"] - retriever = BaseRetriever.get(model_name)(config) + api_type = config["api_type"] + retriever = BaseRetriever.get(api_type)(config) return retriever diff --git a/paddlex/inference/pipelines_new/attribute_recognition/__init__.py b/paddlex/inference/pipelines_new/attribute_recognition/__init__.py new file mode 100644 index 000000000..e8bb7826f --- /dev/null +++ b/paddlex/inference/pipelines_new/attribute_recognition/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline import PedestrianAttributeRecPipeline, VehicleAttributeRecPipeline diff --git a/paddlex/inference/pipelines_new/attribute_recognition/pipeline.py b/paddlex/inference/pipelines_new/attribute_recognition/pipeline.py new file mode 100644 index 000000000..6d6974d94 --- /dev/null +++ b/paddlex/inference/pipelines_new/attribute_recognition/pipeline.py @@ -0,0 +1,100 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, Optional + +import pickle +from pathlib import Path +import numpy as np + +from ...utils.pp_option import PaddlePredictorOption +from ...common.reader import ReadImage +from ...common.batch_sampler import ImageBatchSampler +from ..components import CropByBoxes +from ..base import BasePipeline +from .result import AttributeRecResult + + +class AttributeRecPipeline(BasePipeline): + """Attribute Rec Pipeline""" + + def __init__( + self, + config: Dict, + device: str = None, + pp_option: PaddlePredictorOption = None, + use_hpip: bool = False, + hpi_params: Optional[Dict[str, Any]] = None, + ): + super().__init__( + device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params + ) + + self.det_model = self.create_model(config["SubModules"]["Detection"]) + self.cls_model = self.create_model(config["SubModules"]["Classification"]) + self._crop_by_boxes = CropByBoxes() + self._img_reader = ReadImage(format="BGR") + + self.det_threshold = config["SubModules"]["Detection"].get("threshold", 0.7) + self.cls_threshold = config["SubModules"]["Classification"].get( + "threshold", 0.7 + ) + + self.batch_sampler = ImageBatchSampler( + batch_size=config["SubModules"]["Detection"]["batch_size"] + ) + self.img_reader = ReadImage(format="BGR") + + def predict(self, input, **kwargs): + det_threshold = kwargs.pop("det_threshold", self.det_threshold) + cls_threshold = kwargs.pop("cls_threshold", self.cls_threshold) + for img_id, batch_data in enumerate(self.batch_sampler(input)): + raw_imgs = self.img_reader(batch_data) + all_det_res = list(self.det_model(raw_imgs, threshold=det_threshold)) + for input_data, raw_img, det_res in zip(batch_data, raw_imgs, all_det_res): + cls_res = self.get_cls_result(raw_img, det_res, cls_threshold) + yield self.get_final_result(input_data, raw_img, det_res, cls_res) + + def get_cls_result(self, raw_img, det_res, cls_threshold): + subs_of_img = list(self._crop_by_boxes(raw_img, det_res["boxes"])) + img_list = [img["img"] for img in subs_of_img] + all_cls_res = list(self.cls_model(img_list, threshold=cls_threshold)) + output = {"label": [], "score": []} + for res in all_cls_res: + output["label"].append(res["label_names"]) + output["score"].append(res["scores"]) + return output + + def get_final_result(self, input_data, raw_img, det_res, rec_res): + single_img_res = {"input_path": input_data, "input_img": raw_img, "boxes": []} + for i, obj in enumerate(det_res["boxes"]): + rec_scores = rec_res["score"][i] + labels = rec_res["label"][i] + single_img_res["boxes"].append( + { + "labels": labels, + "rec_scores": rec_scores, + "det_score": obj["score"], + "coordinate": obj["coordinate"], + } + ) + return AttributeRecResult(single_img_res) + + +class PedestrianAttributeRecPipeline(AttributeRecPipeline): + entities = "pedestrian_attribute_recognition" + + +class VehicleAttributeRecPipeline(AttributeRecPipeline): + entities = "vehicle_attribute_recognition" diff --git a/paddlex/inference/pipelines_new/attribute_recognition/result.py b/paddlex/inference/pipelines_new/attribute_recognition/result.py new file mode 100644 index 000000000..2703f7279 --- /dev/null +++ b/paddlex/inference/pipelines_new/attribute_recognition/result.py @@ -0,0 +1,90 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import numpy as np +import PIL +from PIL import Image, ImageDraw, ImageFont + +from ....utils.fonts import PINGFANG_FONT_FILE_PATH +from ...utils.io import ImageReader +from ...common.result import BaseCVResult +from ...utils.color_map import get_colormap, font_colormap + + +def draw_attribute_result(img, boxes): + """ + Args: + img (PIL.Image.Image): PIL image + boxes (list): a list of dictionaries representing detection box information. + Returns: + img (PIL.Image.Image): visualized image + """ + font_size = int((0.024 * int(img.width) + 2) * 0.7) + font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8") + + draw_thickness = int(max(img.size) * 0.005) + draw = ImageDraw.Draw(img) + label2color = {} + catid2fontcolor = {} + color_list = get_colormap(rgb=True) + + for i, dt in enumerate(boxes): + text_lines, bbox, score = dt["label"], dt["coordinate"], dt["score"] + if i not in label2color: + color_index = i % len(color_list) + label2color[i] = color_list[color_index] + catid2fontcolor[i] = font_colormap(color_index) + color = tuple(label2color[i]) + (255,) + font_color = tuple(catid2fontcolor[i]) + + xmin, ymin, xmax, ymax = bbox + # draw box + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)], + width=draw_thickness, + fill=color, + ) + # draw label + current_y = ymin + for line in text_lines: + if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0): + tw, th = draw.textsize(line, font=font) + else: + left, top, right, bottom = draw.textbbox((0, 0), line, font) + tw, th = right - left, bottom - top + 4 + + draw.text((5 + xmin + 1, current_y + 1), line, fill=(0, 0, 0), font=font) + draw.text((5 + xmin, current_y), line, fill=color, font=font) + current_y += th + return img + + +class AttributeRecResult(BaseCVResult): + + def _to_img(self): + """apply""" + img_reader = ImageReader(backend="pillow") + image = img_reader.read(self["input_path"]) + boxes = [ + { + "coordinate": box["coordinate"], + "label": box["labels"], + "score": box["det_score"], + } + for box in self["boxes"] + ] + image = draw_attribute_result(image, boxes) + return image diff --git a/paddlex/inference/pipelines_new/base.py b/paddlex/inference/pipelines_new/base.py index dd96340ec..051b571a5 100644 --- a/paddlex/inference/pipelines_new/base.py +++ b/paddlex/inference/pipelines_new/base.py @@ -67,12 +67,13 @@ def predict(self, input, **kwargs): """ raise NotImplementedError("The method `predict` has not been implemented yet.") - def create_model(self, config: Dict) -> BasePredictor: + def create_model(self, config: Dict, **kwargs) -> BasePredictor: """ Create a model instance based on the given configuration. Args: config (Dict): A dictionary containing configuration settings. + **kwargs: The model arguments that needed to be pass. Returns: BasePredictor: An instance of the model. @@ -82,14 +83,15 @@ def create_model(self, config: Dict) -> BasePredictor: if model_dir == None: model_dir = config["model_name"] - from ...model import create_model + from .. 
import create_predictor - model = create_model( + model = create_predictor( model=model_dir, device=self.device, pp_option=self.pp_option, use_hpip=self.use_hpip, hpi_params=self.hpi_params, + **kwargs, ) # [TODO] Support initializing with additional parameters diff --git a/paddlex/inference/pipelines_new/components/chat_server/__init__.py b/paddlex/inference/pipelines_new/components/chat_server/__init__.py index a08bf1933..5149d5c15 100644 --- a/paddlex/inference/pipelines_new/components/chat_server/__init__.py +++ b/paddlex/inference/pipelines_new/components/chat_server/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .ernie_bot_chat import ErnieBotChat +from .openai_bot_chat import OpenAIBotChat diff --git a/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py b/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py index c086abb2b..de1d2682a 100644 --- a/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py +++ b/paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict import re import json import erniebot +from typing import Dict from .....utils import logging from .base import BaseChat @@ -24,6 +24,11 @@ class ErnieBotChat(BaseChat): """Ernie Bot Chat""" entities = [ + "aistudio", + "qianfan", + ] + + MODELS = [ "ernie-4.0", "ernie-3.5", "ernie-3.5-8k", @@ -53,8 +58,8 @@ def __init__(self, config: Dict) -> None: sk = config.get("sk", None) access_token = config.get("access_token", None) - if model_name not in self.entities: - raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.") + if model_name not in self.MODELS: + raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.") if api_type not in ["aistudio", "qianfan"]: raise ValueError("api_type must be one of ['aistudio', 'qianfan']") @@ -127,6 +132,12 @@ def fix_llm_result_format(self, llm_result: str) -> dict: return {} if "json" in llm_result or "```" in llm_result: + index = llm_result.find("{") + if index != -1: + llm_result = llm_result[index:] + index = llm_result.rfind("}") + if index != -1: + llm_result = llm_result[: index + 1] llm_result = ( llm_result.replace("```", "").replace("json", "").replace("/n", "") ) @@ -135,6 +146,15 @@ def fix_llm_result_format(self, llm_result: str) -> dict: try: llm_result = json.loads(llm_result) llm_result_final = {} + if "问题" in llm_result.keys() and "答案" in llm_result.keys(): + key = llm_result["问题"] + value = llm_result["答案"] + if isinstance(value, list): + if len(value) > 0: + llm_result_final[key] = value[0].strip(f"{key}:").strip(key) + else: + llm_result_final[key] = value.strip(f"{key}:").strip(key) + return llm_result_final for key in llm_result: value = llm_result[key] if isinstance(value, list): @@ -157,6 +177,16 @@ def fix_llm_result_format(self, llm_result: str) -> dict: matches = re.findall(pattern, str(results)) if len(matches) > 0: llm_result = {k: v for k, v in matches} + if "问题" in llm_result.keys() and "答案" in llm_result.keys(): + llm_result_final = {} + key = llm_result["问题"] + value = llm_result["答案"] + if isinstance(value, list): + if len(value) > 0: + llm_result_final[key] = value[0].strip(f"{key}:").strip(key) + else: + llm_result_final[key] = value.strip(f"{key}:").strip(key) + return llm_result_final return llm_result else: return {} diff --git 
a/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py b/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py
new file mode 100644
index 000000000..16c0a9466
--- /dev/null
+++ b/paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py
@@ -0,0 +1,204 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import json
+from typing import Dict, Optional
+from .....utils import logging
+from .base import BaseChat
+
+
+class OpenAIBotChat(BaseChat):
+    """OpenAI Bot Chat"""
+
+    entities = [
+        "openai",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """Initializes the OpenAIBotChat with the given configuration.
+
+        Args:
+            config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.
+
+        Raises:
+            ValueError: If api_type is not one of ['openai'],
+                if base_url is None when api_type is 'openai',
+                or if api_key is None when api_type is 'openai'.
+        """
+        super().__init__()
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+
+        if api_type not in ["openai"]:
+            raise ValueError("api_type must be one of ['openai']")
+
+        if api_type == "openai" and api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is openai.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is openai.")
+
+        try:
+            from openai import OpenAI
+        except ImportError:
+            raise Exception("openai is not installed, please install it first.")
+
+        self.client = OpenAI(base_url=base_url, api_key=api_key)
+
+        self.model_name = model_name
+        self.config = config
+
+    def generate_chat_results(
+        self,
+        prompt: str,
+        image: Optional[str] = None,
+        temperature: float = 0.001,
+        max_retries: int = 1,
+    ) -> Optional[str]:
+        """
+        Generate chat results using the specified model and configuration.
+
+        Args:
+            prompt (str): The user's input prompt.
+            image (Optional[str]): The user's input image for the MLLM as a base64-encoded string, defaults to None.
+            temperature (float, optional): The temperature parameter for the LLM, defaults to 0.001.
+            max_retries (int, optional): The maximum number of retries for LLM API calls, defaults to 1.
+
+        Returns:
+            Optional[str]: The chat completion text from the model, or None if the call failed.
+ """ + try: + if image: + chat_completion = self.client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "system", + # XXX: give a basic prompt for common + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image}" + }, + }, + ], + }, + ], + stream=False, + temperature=temperature, + top_p=0.001, + ) + llm_result = chat_completion.choices[0].message.content + return llm_result + else: + chat_completion = self.client.completions.create( + model=self.model_name, + prompt=prompt, + max_tokens=self.config.get("max_tokens", 1024), + temperature=float(temperature), + stream=False, + ) + if isinstance(chat_completion, str): + chat_completion = json.loads(chat_completion) + llm_result = chat_completion["choices"][0]["text"] + else: + llm_result = chat_completion.choices[0].text + return llm_result + except Exception as e: + logging.error(e) + self.ERROR_MASSAGE = "大模型调用失败" + return None + + def fix_llm_result_format(self, llm_result: str) -> dict: + """ + Fix the format of the LLM result. + + Args: + llm_result (str): The result from the LLM (Large Language Model). + + Returns: + dict: A fixed format dictionary from the LLM result. + """ + if not llm_result: + return {} + + if "json" in llm_result or "```" in llm_result: + index = llm_result.find("{") + if index != -1: + llm_result = llm_result[index:] + index = llm_result.rfind("}") + if index != -1: + llm_result = llm_result[: index + 1] + llm_result = ( + llm_result.replace("```", "").replace("json", "").replace("/n", "") + ) + llm_result = llm_result.replace("[", "").replace("]", "") + + try: + llm_result = json.loads(llm_result) + llm_result_final = {} + if "问题" in llm_result.keys() and "答案" in llm_result.keys(): + key = llm_result["问题"] + value = llm_result["答案"] + if isinstance(value, list): + if len(value) > 0: + llm_result_final[key] = value[0].strip(f"{key}:").strip(key) + else: + llm_result_final[key] = value.strip(f"{key}:").strip(key) + return llm_result_final + for key in llm_result: + value = llm_result[key] + if isinstance(value, list): + if len(value) > 0: + llm_result_final[key] = value[0] + else: + llm_result_final[key] = value + return llm_result_final + + except: + results = ( + llm_result.replace("\n", "") + .replace(" ", "") + .replace("{", "") + .replace("}", "") + ) + if not results.endswith('"'): + results = results + '"' + pattern = r'"(.*?)": "([^"]*)"' + matches = re.findall(pattern, str(results)) + if len(matches) > 0: + llm_result = {k: v for k, v in matches} + if "问题" in llm_result.keys() and "答案" in llm_result.keys(): + llm_result_final = {} + key = llm_result["问题"] + value = llm_result["答案"] + if isinstance(value, list): + if len(value) > 0: + llm_result_final[key] = value[0].strip(f"{key}:").strip(key) + else: + llm_result_final[key] = value.strip(f"{key}:").strip(key) + return llm_result_final + return llm_result + else: + return {} diff --git a/paddlex/inference/pipelines_new/components/retriever/__init__.py b/paddlex/inference/pipelines_new/components/retriever/__init__.py index f829efe0c..8d3f52e25 100644 --- a/paddlex/inference/pipelines_new/components/retriever/__init__.py +++ b/paddlex/inference/pipelines_new/components/retriever/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. 
from .ernie_bot_retriever import ErnieBotRetriever +from .openai_bot_retriever import OpenAIBotRetriever diff --git a/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py b/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py index 0b299ce50..d3a3c94f6 100644 --- a/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py +++ b/paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py @@ -28,6 +28,11 @@ class ErnieBotRetriever(BaseRetriever): """Ernie Bot Retriever""" entities = [ + "aistudio", + "qianfan", + ] + + MODELS = [ "ernie-4.0", "ernie-3.5", "ernie-3.5-8k", @@ -45,7 +50,7 @@ def __init__(self, config: Dict) -> None: Args: config (Dict): A dictionary containing configuration settings. - model_name (str): The name of the model to use. - - api_type (str): The type of API to use ('aistudio' or 'qianfan'). + - api_type (str): The type of API to use ('aistudio', 'qianfan' or 'openai'). - ak (str, optional): The access key for 'qianfan' API. - sk (str, optional): The secret key for 'qianfan' API. - access_token (str, optional): The access token for 'aistudio' API. @@ -64,8 +69,8 @@ def __init__(self, config: Dict) -> None: sk = config.get("sk", None) access_token = config.get("access_token", None) - if model_name not in self.entities: - raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.") + if model_name not in self.MODELS: + raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.") if api_type not in ["aistudio", "qianfan"]: raise ValueError("api_type must be one of ['aistudio', 'qianfan']") diff --git a/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py b/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py new file mode 100644 index 000000000..69ff797f9 --- /dev/null +++ b/paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py @@ -0,0 +1,181 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BaseRetriever + +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import FAISS +from langchain_community import vectorstores + +import time + +from typing import Dict + + +class OpenAIBotRetriever(BaseRetriever): + """OpenAI Bot Retriever""" + + entities = [ + "openai", + ] + + def __init__(self, config: Dict) -> None: + """ + Initializes the OpenAIBotRetriever instance with the provided configuration. + + Args: + config (Dict): A dictionary containing configuration settings. + - model_name (str): The name of the model to use. + - api_type (str): The type of API to use ('aistudio', 'qianfan' or 'openai'). + - api_key (str, optional): The API key for 'openai' API. + - base_url (str, optional): The base URL for 'openai' API. 
+ + Raises: + ValueError: If api_type is not one of ['openai'], + base_url is None for api_type is openai, + api_key is None for api_type is openai. + """ + super().__init__() + + model_name = config.get("model_name", None) + api_type = config.get("api_type", None) + api_key = config.get("api_key", None) + base_url = config.get("base_url", None) + tiktoken_enabled = config.get("tiktoken_enabled", False) + + if api_type not in ["openai"]: + raise ValueError("api_type must be one of ['openai']") + + if api_type == "openai" and api_key is None: + raise ValueError("api_key cannot be empty when api_type is openai.") + + if base_url is None: + raise ValueError("base_url cannot be empty when api_type is openai.") + + try: + from langchain_openai import OpenAIEmbeddings + except: + raise Exception( + "langchain-openai is not installed, please install it first." + ) + + self.embedding = OpenAIEmbeddings( + model=model_name, + api_key=api_key, + base_url=base_url, + tiktoken_enabled=tiktoken_enabled, + ) + + self.model_name = model_name + self.config = config + + # Generates a vector database from a list of texts using different embeddings based on the configured API type. + + def generate_vector_database( + self, + text_list: list[str], + block_size: int = 300, + separators: list[str] = ["\t", "\n", "。", "\n\n", ""], + sleep_time: float = 0.5, + ) -> FAISS: + """ + Generates a vector database from a list of texts. + + Args: + text_list (list[str]): A list of texts to generate the vector database from. + block_size (int): The size of each chunk to split the text into. + separators (list[str]): A list of separators to use when splitting the text. + sleep_time (float): The time to sleep between embedding generations to avoid rate limiting. + + Returns: + FAISS: The generated vector database. + + Raises: + ValueError: If an unsupported API type is configured. + """ + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=block_size, chunk_overlap=20, separators=separators + ) + texts = text_splitter.split_text("\t".join(text_list)) + all_splits = [Document(page_content=text) for text in texts] + + api_type = self.config["api_type"] + + vectorstore = FAISS.from_documents( + documents=all_splits, embedding=self.embedding + ) + + return vectorstore + + def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str: + """ + Encode the vector store serialized to bytes. + + Args: + vectorstore (FAISS): The vector store to be serialized and encoded. + + Returns: + str: The encoded vector store. + """ + vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes()) + return vectorstore + + def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS: + """ + Decode a vector store from bytes according to the specified API type. + + Args: + vectorstore (str): The serialized vector store string. + + Returns: + FAISS: Deserialized vector store object. + + Raises: + ValueError: If the retrieved vector store is not for PaddleX + or if an unsupported API type is specified. + """ + if not self.is_vector_store(vectorstore): + raise ValueError("The retrieved vectorstore is not for PaddleX.") + + vector = vectorstores.FAISS.deserialize_from_bytes( + self.decode_vector_store(vectorstore), self.embedding + ) + return vector + + def similarity_retrieval( + self, query_text_list: list[str], vectorstore: FAISS, sleep_time: float = 0.5 + ) -> str: + """ + Retrieve similar contexts based on a list of query texts. 
+ + Args: + query_text_list (list[str]): A list of query texts to search for similar contexts. + vectorstore (FAISS): The vector store where to perform the similarity search. + sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5. + + Returns: + str: A concatenated string of all unique contexts found. + """ + C = [] + for query_text in query_text_list: + QUESTION = query_text + time.sleep(sleep_time) + docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2) + context = [(document.page_content, score) for document, score in docs] + context = sorted(context, key=lambda x: x[1]) + C.extend([x[0] for x in context[::-1]]) + C = list(set(C)) + all_C = " ".join(C) + return all_C diff --git a/paddlex/inference/pipelines_new/formula_recognition/__init__.py b/paddlex/inference/pipelines_new/formula_recognition/__init__.py new file mode 100644 index 000000000..655bbedea --- /dev/null +++ b/paddlex/inference/pipelines_new/formula_recognition/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline import FormulaRecognitionPipeline diff --git a/paddlex/inference/pipelines_new/formula_recognition/pipeline.py b/paddlex/inference/pipelines_new/formula_recognition/pipeline.py new file mode 100644 index 000000000..8eb0ded67 --- /dev/null +++ b/paddlex/inference/pipelines_new/formula_recognition/pipeline.py @@ -0,0 +1,259 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os, sys +from typing import Any, Dict, Optional +import numpy as np +import cv2 +from ..base import BasePipeline +from ..components import CropByBoxes +from ..layout_parsing.utils import convert_points_to_boxes + +from .result import FormulaRecognitionResult +from ...models_new.formula_recognition.result import ( + FormulaRecResult as SingleFormulaRecognitionResult, +) +from ....utils import logging +from ...utils.pp_option import PaddlePredictorOption +from ...common.reader import ReadImage +from ...common.batch_sampler import ImageBatchSampler +from ..ocr.result import OCRResult +from ..doc_preprocessor.result import DocPreprocessorResult + +# [TODO] 待更新models_new到models +from ...models_new.object_detection.result import DetResult + + +class FormulaRecognitionPipeline(BasePipeline): + """Formula Recognition Pipeline""" + + entities = ["formula_recognition"] + + def __init__( + self, + config: Dict, + device: str = None, + pp_option: PaddlePredictorOption = None, + use_hpip: bool = False, + hpi_params: Optional[Dict[str, Any]] = None, + ) -> None: + """Initializes the layout parsing pipeline. + + Args: + config (Dict): Configuration dictionary containing various settings. + device (str, optional): Device to run the predictions on. Defaults to None. + pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None. + use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False. + hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None. + """ + + super().__init__( + device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params + ) + + self.use_doc_preprocessor = False + if "use_doc_preprocessor" in config: + self.use_doc_preprocessor = config["use_doc_preprocessor"] + + if self.use_doc_preprocessor: + doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"] + self.doc_preprocessor_pipeline = self.create_pipeline( + doc_preprocessor_config + ) + + self.use_layout_detection = True + if "use_layout_detection" in config: + self.use_layout_detection = config["use_layout_detection"] + if self.use_layout_detection: + layout_det_config = config["SubModules"]["LayoutDetection"] + self.layout_det_model = self.create_model(layout_det_config) + + formula_recognition_config = config["SubModules"]["FormulaRecognition"] + self.formula_recognition_model = self.create_model(formula_recognition_config) + + self._crop_by_boxes = CropByBoxes() + + self.batch_sampler = ImageBatchSampler(batch_size=1) + self.img_reader = ReadImage(format="BGR") + + def check_input_params_valid( + self, input_params: Dict, layout_det_res: DetResult + ) -> bool: + """ + Check if the input parameters are valid based on the initialized models. + + Args: + input_params (Dict): A dictionary containing input parameters. + layout_det_res (DetResult): The layout detection result. + Returns: + bool: True if all required models are initialized according to input parameters, False otherwise. + """ + + if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor: + logging.error( + "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized." 
+            )
+            return False
+
+        if input_params["use_layout_detection"]:
+            if layout_det_res is not None:
+                logging.error(
+                    "An external layout_det_res was provided, please set use_layout_detection=False"
+                )
+                return False
+
+            if not self.use_layout_detection:
+                logging.error(
+                    "Set use_layout_detection, but the models for layout detection are not initialized."
+                )
+                return False
+
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+            result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def predict_single_formula_recognition_res(
+        self,
+        image_array: np.ndarray,
+    ) -> SingleFormulaRecognitionResult:
+        """
+        Predict the formula recognition result for a single image array.
+
+        Args:
+            image_array (np.ndarray): The input image represented as a numpy array.
+
+        Returns:
+            SingleFormulaRecognitionResult: The single formula recognition result.
+        """
+
+        formula_recognition_pred = next(self.formula_recognition_model(image_array))
+
+        return formula_recognition_pred
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_layout_detection: bool = True,
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        layout_det_res: DetResult = None,
+        **kwargs
+    ) -> FormulaRecognitionResult:
+        """
+        This function predicts the formula recognition result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or PDF(s) to be processed.
+            use_layout_detection (bool): Whether to use layout detection.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            layout_det_res (DetResult): The layout detection result.
+                It will be used if it is not None and use_layout_detection is False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            FormulaRecognitionResult: The predicted formula recognition result.
+        """
+
+        input_params = {
+            "use_layout_detection": use_layout_detection,
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(input_params, layout_det_res):
+            yield None
+            return
+
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
+            input_path = batch_data[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            formula_res_list = []
+            formula_region_id = 1
+
+            if not input_params["use_layout_detection"] and layout_det_res is None:
+                layout_det_res = {}
+                img_height, img_width = doc_preprocessor_image.shape[:2]
+                single_formula_rec_res = self.predict_single_formula_recognition_res(
+                    doc_preprocessor_image,
+                )
+                single_formula_rec_res["formula_region_id"] = formula_region_id
+                formula_res_list.append(single_formula_rec_res)
+                formula_region_id += 1
+            else:
+                if input_params["use_layout_detection"]:
+                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["formula"]:
+                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        single_formula_rec_res = (
+                            self.predict_single_formula_recognition_res(
+                                crop_img_info["img"]
+                            )
+                        )
+                        single_formula_rec_res["formula_region_id"] = formula_region_id
+                        single_formula_rec_res["dt_polys"] = box_info["coordinate"]
+                        formula_res_list.append(single_formula_rec_res)
+                        formula_region_id += 1
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "formula_res_list": formula_res_list,
+                "input_params": input_params,
+                "img_id": img_id,
+                "img_name": input_path,
+            }
+            yield FormulaRecognitionResult(single_img_res)
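Two aspects of the `predict` generator above deserve a note: an invalid parameter combination makes it yield `None` instead of raising, and an externally computed layout result can be reused by passing `layout_det_res` together with `use_layout_detection=False`. A hedged sketch of that path, where `precomputed_layout` stands in for a `DetResult` produced by an earlier layout-detection run (it is hypothetical here, not defined by the patch):

    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="formula_recognition")

    # precomputed_layout: a DetResult obtained elsewhere (hypothetical).
    for res in pipeline.predict(
        "./test_samples/general_formula_recognition01.png",
        use_layout_detection=False,  # skip the built-in layout model
        layout_det_res=precomputed_layout,
    ):
        if res is None:  # rejected by check_input_params_valid
            break
        res.save_results("./output")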
diff --git a/paddlex/inference/pipelines_new/formula_recognition/result.py b/paddlex/inference/pipelines_new/formula_recognition/result.py
new file mode 100644
index 000000000..a816cd20a
--- /dev/null
+++ b/paddlex/inference/pipelines_new/formula_recognition/result.py
@@ -0,0 +1,216 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Tuple
+import cv2
+import PIL
+import math
+import random
+import tempfile
+import subprocess
+import numpy as np
+from pathlib import Path
+from PIL import Image, ImageDraw, ImageFont
+
+from ...common.result import BaseCVResult
+from ....utils import logging
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...models_new.formula_recognition.result import (
+    get_align_equation,
+    generate_tex_file,
+    generate_pdf_file,
+    env_valid,
+    pdf2img,
+    create_font,
+    crop_white_area,
+    draw_box_txt_fine,
+)
+
+
+class FormulaRecognitionResult(dict):
+    """Formula Recognition Result"""
+
+    def __init__(self, data) -> None:
+        """Initializes a new instance of the class with the specified data."""
+        super().__init__(data)
+
+    def save_to_img(self, save_path: str) -> None:
+        """
+        Saves an image with overlaid formula recognition results.
+
+        This function attempts to save an image with recognized formulas highlighted
+        and annotated. It verifies the environment setup before proceeding and logs
+        a warning if the necessary rendering engine is not installed. The output image
+        consists of two halves: the left side shows the original image with bounding
+        boxes, and the right side shows the recognized formulas.
+
+        Args:
+            save_path (str): The directory path where the output image will be saved.
+
+        Returns:
+            None
+        """
+        try:
+            env_valid()
+        except subprocess.CalledProcessError as e:
+            logging.warning(
+                "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
+            )
+            return None
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+        img_id = self["img_id"]
+        img_name = self["img_name"]
+        if len(self["layout_det_res"]) <= 0:
+            return
+        image = Image.fromarray(self["layout_det_res"]["input_img"])
+        h, w = image.height, image.width
+        img_left = image.copy()
+        img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+
+        formula_save_path = os.path.join(save_path, "formula_img_{}.jpg".format(img_id))
+        formula_res_list = self["formula_res_list"]
+        for tno in range(len(self["formula_res_list"])):
+            formula_res = self["formula_res_list"][tno]
+            formula_region_id = formula_res["formula_region_id"]
+            formula = str(formula_res["rec_formula"])
+            dt_polys = formula_res["dt_polys"]
+            x1, y1, x2, y2 = list(dt_polys)
+            try:
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = [x1, y1, x2, y1, x2, y2, x1, y2]
+                box = np.array(box).reshape([-1, 2])
+                pts = [(x, y) for x, y in box.tolist()]
+                draw_left.polygon(pts, outline=color, width=8)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_formula_fine(
+                    (w, h),
+                    box,
+                    formula,
+                    is_debug=False,
+                )
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except subprocess.CalledProcessError as e:
+                logging.warning("Syntax error detected in formula, rendering failed.")
+                continue
+        img_left = Image.blend(image, img_left, 0.5)
+        img_show = Image.new("RGB", (int(w * 2), h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+        img_show.save(formula_save_path)
+
+    def save_results(self, save_path: str) -> None:
+        """Save the formula recognition results to the specified directory.
+
+        Args:
+            save_path (str): The directory path to save the results.
+        """
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+        if not os.path.isdir(save_path):
+            return
+
+        img_id = self["img_id"]
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
+            layout_det_res.save_to_img(save_img_path)
+        self.save_to_img(save_path)
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+        for tno in range(len(self["formula_res_list"])):
+            formula_res = self["formula_res_list"][tno]
+            formula_region_id = formula_res["formula_region_id"]
+            save_img_path = (
+                Path(save_path)
+                / f"formula_res_img{img_id}_region{formula_region_id}.jpg"
+            )
+            formula_res.save_to_img(save_img_path)
+        return
+
+
+def draw_box_formula_fine(
+    img_size: Tuple[int, int], box: np.ndarray, formula: str, is_debug: bool = False
+) -> np.ndarray:
+    """
+    Draw box formula for pipeline.
+
+    This function generates a LaTeX formula image and transforms it to fit
+    within a specified bounding box on a larger image. If the rendering fails,
+    it will write "Rendering Failed" inside the box.
+
+    Args:
+        img_size (Tuple[int, int]): The size of the image (width, height).
+        box (np.ndarray): A numpy array representing the four corners of the bounding box.
+        formula (str): The LaTeX formula to render.
+        is_debug (bool, optional): If True, enables debug mode. Defaults to False.
+
+    Returns:
+        np.ndarray: An image array with the rendered formula inside the specified box.
+    """
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+    with tempfile.TemporaryDirectory() as td:
+        tex_file_path = os.path.join(td, "temp.tex")
+        pdf_file_path = os.path.join(td, "temp.pdf")
+        img_file_path = os.path.join(td, "temp.jpg")
+        generate_tex_file(tex_file_path, formula)
+        if os.path.exists(tex_file_path):
+            generate_pdf_file(tex_file_path, td, is_debug)
+        formula_img = None
+        if os.path.exists(pdf_file_path):
+            formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
+        if formula_img is not None:
+            formula_h, formula_w = formula_img.shape[:-1]
+            resize_height = box_height
+            resize_width = formula_w * resize_height / formula_h
+            formula_img = cv2.resize(
+                formula_img, (int(resize_width), int(resize_height))
+            )
+            formula_h, formula_w = formula_img.shape[:-1]
+            pts1 = np.float32(
+                [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+            )
+            pts2 = np.array(box, dtype=np.float32)
+            M = cv2.getPerspectiveTransform(pts1, pts2)
+            formula_img = np.array(formula_img, dtype=np.uint8)
+            img_right_text = cv2.warpPerspective(
+                formula_img,
+                M,
+                img_size,
+                flags=cv2.INTER_NEAREST,
+                borderMode=cv2.BORDER_CONSTANT,
+                borderValue=(255, 255, 255),
+            )
+        else:
+            img_right_text = draw_box_txt_fine(
+                img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
+            )
+        return img_right_text
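The warp at the heart of `draw_box_formula_fine` is OpenCV's standard four-point perspective mapping; the standalone sketch below reproduces just that step with synthetic data (all names and values are illustrative, not taken from the patch):

    import cv2
    import numpy as np

    canvas_size = (640, 480)  # (width, height), the order cv2 expects
    rendered = np.full((60, 200, 3), 255, dtype=np.uint8)  # stand-in formula image
    cv2.putText(rendered, "E = mc^2", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

    h, w = rendered.shape[:2]
    src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    # Four corners of the target formula box on the page, clockwise from top-left.
    dst = np.float32([[100, 100], [300, 120], [295, 180], [95, 160]])

    M = cv2.getPerspectiveTransform(src, dst)
    warped = cv2.warpPerspective(
        rendered,
        M,
        canvas_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255),
    )
    cv2.imwrite("warped_formula.jpg", warped)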
diff --git a/paddlex/inference/pipelines_new/image_classification/pipeline.py b/paddlex/inference/pipelines_new/image_classification/pipeline.py
index 062d71472..7e593e662 100644
--- a/paddlex/inference/pipelines_new/image_classification/pipeline.py
+++ b/paddlex/inference/pipelines_new/image_classification/pipeline.py
@@ -51,13 +51,16 @@ def __init__(
         )
 
         image_classification_model_config = config["SubModules"]["ImageClassification"]
+        model_kwargs = {}
+        if (topk := image_classification_model_config.get("topk", None)) is not None:
+            model_kwargs = {"topk": topk}
         self.image_classification_model = self.create_model(
-            image_classification_model_config
+            image_classification_model_config, **model_kwargs
         )
-        self.topk = image_classification_model_config["topk"]
+        self.topk = image_classification_model_config.get("topk", 5)
 
     def predict(
-        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+        self, input: str | list[str] | np.ndarray | list[np.ndarray], topk=None
     ) -> TopkResult:
         """Predicts image classification results for the given input.
@@ -68,4 +71,6 @@ def predict(
         Returns:
             TopkResult: The predicted top k results.
         """
-        yield from self.image_classification_model(input, topk=self.topk)
+
+        if topk is None:
+            topk = self.topk
+        yield from self.image_classification_model(input, topk=topk)
""" super().__init__( device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params ) - self.inintial_predictor(config) + self.use_textline_orientation = ( + use_textline_orientation + if use_textline_orientation is not None + else config.get("use_textline_orientation", False) + ) + self.use_doc_preprocessor = self.get_preprocessor_value( + use_doc_orientation_classify, use_doc_unwarping, config, False + ) + + text_det_default_params = { + "limit_side_len": 960, + "limit_type": "max", + "thresh": 0.3, + "box_thresh": 0.6, + "max_candidates": 1000, + "unclip_ratio": 2.0, + "use_dilation": False, + } + + text_det_config = config["SubModules"]["TextDetection"] + for key, default_params in text_det_default_params.items(): + text_det_config[key] = locals().get( + key, text_det_config.get(key, default_params) + ) + self.text_det_model = self.create_model(text_det_config) + + text_rec_config = config["SubModules"]["TextRecognition"] + self.text_rec_model = self.create_model(text_rec_config) + + if self.use_textline_orientation: + textline_orientation_config = config["SubModules"]["TextLineOrientation"] + self.textline_orientation_model = self.create_model( + textline_orientation_config + ) + + if self.use_doc_preprocessor: + doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"] + self.doc_preprocessor_pipeline = self.create_pipeline( + doc_preprocessor_config + ) self.text_type = config["text_type"] @@ -68,60 +127,15 @@ def __init__( self.batch_sampler = ImageBatchSampler(batch_size=1) self.img_reader = ReadImage(format="BGR") - def set_used_models_flag(self, config: Dict) -> None: - """ - Set the flags for which models to use based on the configuration. - - Args: - config (Dict): A dictionary containing configuration settings. - - Returns: - None - """ - pipeline_name = config["pipeline_name"] - - self.pipeline_name = pipeline_name - - self.use_doc_preprocessor = False - - if "use_doc_preprocessor" in config: - self.use_doc_preprocessor = config["use_doc_preprocessor"] - - self.use_textline_orientation = False - - if "use_textline_orientation" in config: - self.use_textline_orientation = config["use_textline_orientation"] - - def inintial_predictor(self, config: Dict) -> None: - """Initializes the predictor based on the provided configuration. - - Args: - config (Dict): A dictionary containing the configuration for the predictor. 
-
-        Returns:
-            None
-        """
-
-        self.set_used_models_flag(config)
-
-        text_det_model_config = config["SubModules"]["TextDetection"]
-        self.text_det_model = self.create_model(text_det_model_config)
-
-        text_rec_model_config = config["SubModules"]["TextRecognition"]
-        self.text_rec_model = self.create_model(text_rec_model_config)
-
-        if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
-            self.doc_preprocessor_pipeline = self.create_pipeline(
-                doc_preprocessor_config
-            )
-        # Just for initialize the predictor
-        if self.use_textline_orientation:
-            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
-            self.textline_orientation_model = self.create_model(
-                textline_orientation_config
-            )
-        return
+    @staticmethod
+    def get_preprocessor_value(orientation, unwarping, config, default):
+        """Resolve the doc-preprocessor flag from explicit arguments or config."""
+        if orientation is None and unwarping is None:
+            return config.get("use_doc_preprocessor", default)
+        if orientation is False and unwarping is False:
+            return False
+        return True
 
     def rotate_image(
         self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
@@ -159,25 +173,25 @@ def rotate_image(
 
         return rotated_images
 
-    def check_input_params_valid(self, input_params: Dict) -> bool:
+    def check_model_settings_valid(self, model_settings: Dict) -> bool:
         """
-        Check if the input parameters are valid based on the initialized models.
+        Check if the model settings are valid based on the initialized models.
 
         Args:
-            input_params (Dict): A dictionary containing input parameters.
+            model_settings (Dict): A dictionary containing the model settings.
 
         Returns:
             bool: True if all required models are initialized according to input parameters, False otherwise.
         """
 
-        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             logging.error(
                 "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
             )
             return False
 
         if (
-            input_params["use_textline_orientation"]
+            model_settings["use_textline_orientation"]
             and not self.use_textline_orientation
         ):
             logging.error(
@@ -211,11 +225,9 @@ def predict_doc_preprocessor_res(
                     use_doc_unwarping=use_doc_unwarping,
                 )
             )
-            doc_preprocessor_image = doc_preprocessor_res["output_img"]
         else:
-            doc_preprocessor_res = {}
-            doc_preprocessor_image = image_array
-        return doc_preprocessor_res, doc_preprocessor_image
+            doc_preprocessor_res = {"output_img": image_array}
+        return doc_preprocessor_res
 
     def predict(
         self,
@@ -223,6 +235,13 @@ def predict(
         use_doc_orientation_classify: bool = False,
         use_doc_unwarping: bool = False,
         use_textline_orientation: bool = False,
+        limit_side_len: int = 960,
+        limit_type: str = "max",
+        thresh: float = 0.3,
+        box_thresh: float = 0.6,
+        max_candidates: int = 1000,
+        unclip_ratio: float = 2.0,
+        use_dilation: bool = False,
         **kwargs,
     ) -> OCRResult:
         """Predicts OCR results for the given input.
@@ -235,44 +254,56 @@ def predict(
             OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
""" - input_params = { - "use_doc_preprocessor": self.use_doc_preprocessor, - "use_doc_orientation_classify": False, + model_settings = { + "use_doc_orientation_classify": use_doc_orientation_classify, "use_doc_unwarping": use_doc_unwarping, - "use_textline_orientation": False, + "use_textline_orientation": use_textline_orientation, } if use_doc_orientation_classify or use_doc_unwarping: - input_params["use_doc_preprocessor"] = True + model_settings["use_doc_preprocessor"] = True else: - input_params["use_doc_preprocessor"] = False + model_settings["use_doc_preprocessor"] = False - if not self.check_input_params_valid(input_params): + if not self.check_model_settings_valid(model_settings): yield None + text_det_params = { + "limit_side_len": limit_side_len, + "limit_type": limit_type, + "thresh": thresh, + "box_thresh": box_thresh, + "max_candidates": max_candidates, + "unclip_ratio": unclip_ratio, + "use_dilation": use_dilation, + } + for img_id, batch_data in enumerate(self.batch_sampler(input)): image_array = self.img_reader(batch_data)[0] img_id += 1 - doc_preprocessor_res, doc_preprocessor_image = ( - self.predict_doc_preprocessor_res(image_array, input_params) + doc_preprocessor_res = self.predict_doc_preprocessor_res( + image_array, model_settings ) + doc_preprocessor_image = doc_preprocessor_res["output_img"] - det_res = next(self.text_det_model(doc_preprocessor_image)) + det_res = next( + self.text_det_model(doc_preprocessor_image, **text_det_params) + ) dt_polys = det_res["dt_polys"] dt_scores = det_res["dt_scores"] - ########## [TODO] Need to confirm filtering thresholds for detection and recognition modules - dt_polys = self._sort_boxes(dt_polys) single_img_res = { "input_path": input, + # TODO: `doc_preprocessor_image` parameter does not need to be retained here, it requires further confirmation. 
"doc_preprocessor_image": doc_preprocessor_image, "doc_preprocessor_res": doc_preprocessor_res, "dt_polys": dt_polys, "img_id": img_id, - "input_params": input_params, + "input_params": model_settings, + "text_det_params": text_det_params, "text_type": self.text_type, } @@ -283,13 +314,14 @@ def predict( self._crop_by_polys(doc_preprocessor_image, dt_polys) ) # use textline orientation model - if input_params["use_textline_orientation"]: + if model_settings["use_textline_orientation"]: angles = [ textline_angle_info["class_ids"][0] for textline_angle_info in self.textline_orientation_model( all_subs_of_img ) ] + single_img_res["textline_orientation_angle"] = angles all_subs_of_img = self.rotate_image(all_subs_of_img, angles) for rec_res in self.text_rec_model(all_subs_of_img): diff --git a/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py b/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py index b930ef569..489dc06b9 100644 --- a/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py +++ b/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py @@ -480,7 +480,7 @@ def chat( key_list = self.format_key(key_list) key_list_ori = key_list.copy() if len(key_list) == 0: - return {"error": "输入的key_list无效!"} + return {"chat_res": "Error:输入的key_list无效!"} if not isinstance(visual_info, list): visual_info_list = [visual_info] diff --git a/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py b/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py index e3b80c467..85e30e6ae 100644 --- a/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py +++ b/paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py @@ -14,7 +14,9 @@ from typing import Any, Dict, Optional import re +import cv2 import json +import base64 import numpy as np import copy from .pipeline_base import PP_ChatOCR_Pipeline @@ -99,6 +101,8 @@ def inintial_predictor(self, config: dict) -> None: if "use_mllm_predict" in config: self.use_mllm_predict = config["use_mllm_predict"] if self.use_mllm_predict: + mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"] + self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config) ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"] self.ensemble_pe = create_prompt_engeering(ensemble_pe_config) return @@ -380,6 +384,47 @@ def format_key(self, key_list: str | list[str]) -> list[str]: return [] + def mllm_pred( + self, + input: str | np.ndarray, + key_list, + **kwargs, + ) -> dict: + key_list = self.format_key(key_list) + if len(key_list) == 0: + return {"mllm_res": "Error:输入的key_list无效!"} + + if isinstance(input, list): + logging.error("Input is a list, but it's not supported here.") + return {"mllm_res": "Error:Input is a list, but it's not supported here!"} + image_array_list = self.img_reader([input]) + if ( + isinstance(input, str) + and input.endswith(".pdf") + and len(image_array_list) > 1 + ): + logging.error("The input with PDF should have only one page.") + return {"mllm_res": "Error:The input with PDF should have only one page!"} + + for image_array in image_array_list: + + assert len(image_array.shape) == 3 + image_string = cv2.imencode(".jpg", image_array)[1].tostring() + image_base64 = base64.b64encode(image_string).decode("utf-8") + result = {} + for key in key_list: + prompt = ( + str(key) + + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。" + ) + mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results( + prompt=prompt, image=image_base64 + ) + if mllm_chat_bot_result is None: + return {"mllm_res": 
"大模型调用失败"} + result[key] = mllm_chat_bot_result + return {"mllm_res": result} + def generate_and_merge_chat_results( self, prompt: str, key_list: list, final_results: dict, failed_results: list ) -> None: @@ -524,6 +569,7 @@ def chat( table_few_shot_demo_text_content: str = None, table_few_shot_demo_key_value_list: str = None, mllm_predict_dict: dict = None, + mllm_integration_strategy: str = "integration", ) -> dict: """ Generates chat results based on the provided key list and visual information. @@ -545,6 +591,7 @@ def chat( table_few_shot_demo_text_content (str): The text content for table few-shot demos. table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos. mllm_predict_dict (dict): The dictionary of mLLM predicts. + mllm_integration_strategy(str): The integration strategy of mLLM and LLM, defaults to "integration", options are "integration", "llm_only" and "mllm_only". Returns: dict: A dictionary containing the chat results. """ @@ -552,7 +599,7 @@ def chat( key_list = self.format_key(key_list) key_list_ori = key_list.copy() if len(key_list) == 0: - return {"error": "输入的key_list无效!"} + return {"chat_res": "Error:输入的key_list无效!"} if not isinstance(visual_info, list): visual_info_list = [visual_info] @@ -620,10 +667,17 @@ def chat( prompt, key_list, final_results, failed_results ) - if self.use_mllm_predict: - final_predict_dict = self.ensemble_ocr_llm_mllm( - key_list_ori, final_results, mllm_predict_dict - ) + if self.use_mllm_predict and mllm_predict_dict != "llm_only": + if mllm_integration_strategy == "integration": + final_predict_dict = self.ensemble_ocr_llm_mllm( + key_list_ori, final_results, mllm_predict_dict + ) + elif mllm_integration_strategy == "mllm_only": + final_predict_dict = mllm_predict_dict + else: + return { + "chat_res": f"Error:Unsupported mllm_integration_strategy {mllm_integration_strategy}, only support 'integration', 'llm_only' and 'mllm_only'!" + } else: final_predict_dict = final_results return {"chat_res": final_predict_dict} diff --git a/paddlex/inference/utils/io/readers.py b/paddlex/inference/utils/io/readers.py index f6aae1578..2d62875e5 100644 --- a/paddlex/inference/utils/io/readers.py +++ b/paddlex/inference/utils/io/readers.py @@ -22,7 +22,6 @@ import numpy as np import yaml import soundfile -import decord import random import platform import importlib diff --git a/paddlex/repo_manager/utils.py b/paddlex/repo_manager/utils.py index fc7313d85..66e2f4de9 100644 --- a/paddlex/repo_manager/utils.py +++ b/paddlex/repo_manager/utils.py @@ -103,16 +103,16 @@ def install_packages_using_pip( def install_external_deps(repo_name, repo_root): """install paddle repository custom dependencies""" - gcc_version = ( - subprocess.check_output(["gcc", "--version"]).decode("utf-8").split()[2] - ) + + def get_gcc_version(): + return subprocess.check_output(["gcc", "--version"]).decode("utf-8").split()[2] if repo_name == "PaddleDetection": if os.path.exists(os.path.join(repo_root, "ppdet", "ext_op")): """Install custom op for rotated object detection""" if ( PLATFORM == "Linux" - and _compare_version(gcc_version, "8.2.0") >= 0 + and _compare_version(get_gcc_version(), "8.2.0") >= 0 and "gpu" in get_device_type() and ( paddle.is_compiled_with_cuda()