From 07969ad6fb5587ed298417065ab8dd1384e6e85a Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 9 Jan 2025 12:42:51 +0000 Subject: [PATCH] change to dict type and support to save multi result files --- paddlex/inference/common/result/mixin.py | 177 +++++++++++------- .../models_new/anomaly_detection/result.py | 2 +- .../models_new/formula_recognition/result.py | 6 +- .../models_new/image_classification/result.py | 2 +- .../models_new/image_feature/result.py | 10 +- .../image_multilabel_classification/result.py | 2 +- .../models_new/image_unwarping/result.py | 2 +- .../instance_segmentation/result.py | 2 +- .../models_new/object_detection/result.py | 2 +- .../semantic_segmentation/result.py | 2 +- .../table_structure_recognition/result.py | 2 +- .../models_new/text_detection/result.py | 2 +- .../models_new/text_recognition/result.py | 2 +- .../models_new/ts_anomaly_detection/result.py | 2 +- .../models_new/ts_classification/result.py | 2 +- .../models_new/ts_forecasting/result.py | 2 +- 16 files changed, 128 insertions(+), 91 deletions(-) diff --git a/paddlex/inference/common/result/mixin.py b/paddlex/inference/common/result/mixin.py index c618aed6e..978a50b7b 100644 --- a/paddlex/inference/common/result/mixin.py +++ b/paddlex/inference/common/result/mixin.py @@ -39,11 +39,11 @@ class StrMixin: """Mixin class for adding string conversion capabilities.""" @property - def str(self) -> str: + def str(self) -> Dict[str, str]: """Property to get the string representation of the result. Returns: - str: The str type string representation of the result. + Dict[str, str]: The string representation of the result. """ return self._to_str(self) @@ -54,7 +54,7 @@ def _to_str( json_format: bool = False, indent: int = 4, ensure_ascii: bool = False, - ) -> str: + ): """Convert the given result data to a string representation. Args: @@ -64,12 +64,14 @@ def _to_str( ensure_ascii (bool): If True, ensure all characters are ASCII. Default is False. 
Returns: - str: The string representation of the data. + Dict[str, str]: The string representation of the result. """ if json_format: - return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii) + return { + "res": json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii) + } else: - return str(data) + return {"res": str(data)} def print( self, json_format: bool = False, indent: int = 4, ensure_ascii: bool = False @@ -84,7 +86,7 @@ def print( str_ = self._to_str( self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii ) - logging.info(str_) + logging.info(str_["res"]) class JsonMixin: @@ -94,11 +96,11 @@ def __init__(self) -> None: self._json_writer = JsonWriter() self._save_funcs.append(self.save_to_json) - def _to_json(self) -> Dict[str, Any]: + def _to_json(self) -> Dict[str, Dict[str, Any]]: """Convert the object to a JSON-serializable format. Returns: - Dict[str, Any]: A dictionary representation of the object that is JSON-serializable. + Dict[str, Dict[str, Any]]: A dictionary representation of the object that is JSON-serializable. """ def _format_data(obj): @@ -125,14 +127,14 @@ def _format_data(obj): else: return obj - return _format_data(copy.deepcopy(self)) + return {"res": _format_data(copy.deepcopy(self))} @property - def json(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Dict[str, Any]]: """Property to get the JSON representation of the result. Returns: - Dict[str, Any]: The dict type JSON representation of the result. + Dict[str, Dict[str, Any]]: The dict type JSON representation of the result. 
""" return self._to_json() @@ -160,16 +162,28 @@ def _is_json_file(file_path): return mime_type is not None and mime_type == "application/json" if not _is_json_file(save_path): - save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json" - save_path = save_path.as_posix() - self._json_writer.write( - save_path, - self.json, - indent=indent, - ensure_ascii=ensure_ascii, - *args, - **kwargs, - ) + fp = Path(self["input_path"]) + stem = fp.stem + suffix = fp.suffix + base_save_path = Path(save_path) + for key in self.json: + save_path = base_save_path / f"{stem}_{key}.json" + self._json_writer.write( + save_path.as_posix(), self.json[key], *args, **kwargs + ) + else: + if len(self.json) > 1: + logging.warning( + f"The result has multiple json files need to be saved. But the `save_path` has been specfied as `{save_path}`!" + ) + self._json_writer.write( + save_path, + self.json[list(self.json.keys())[0]], + indent=indent, + ensure_ascii=ensure_ascii, + *args, + **kwargs, + ) class Base64Mixin: @@ -186,21 +200,21 @@ def __init__(self, *args: List, **kwargs: Dict) -> None: self._save_funcs.append(self.save_to_base64) @abstractmethod - def _to_base64(self) -> str: + def _to_base64(self) -> Dict[str, str]: """Abstract method to convert the result to Base64. Returns: - str: The str type Base64 representation result. + Dict[str, str]: The str type Base64 representation result. """ raise NotImplementedError @property - def base64(self) -> str: + def base64(self) -> Dict[str, str]: """ Property that returns the Base64 encoded content. Returns: - str: The base64 representation of the result. + Dict[str, str]: The base64 representation of the result. """ return self._to_base64() @@ -213,13 +227,24 @@ def save_to_base64(self, save_path: str, *args: List, **kwargs: Dict) -> None: *args: Additional positional arguments that will be passed to the base64 writer. **kwargs: Additional keyword arguments that will be passed to the base64 writer. 
""" - if not str(save_path).lower().endswith((".b64")): fp = Path(self["input_path"]) - save_path = Path(save_path) / f"{fp.stem}{fp.suffix}" + stem = fp.stem + suffix = fp.suffix + base_save_path = Path(save_path) + for key in self.base64: + save_path = base_save_path / f"{stem}_{key}.b64" + self._base64_writer.write( + save_path.as_posix(), self.base64[key], *args, **kwargs + ) else: - save_path = Path(save_path) - self._base64_writer.write(save_path.as_posix(), self.base64, *args, **kwargs) + if len(self.base64) > 1: + logging.warning( + f"The result has multiple base64 files need to be saved. But the `save_path` has been specfied as `{save_path}`!" + ) + self._base64_writer.write( + save_path, self.base64[list(self.base64.keys())[0]], *args, **kwargs + ) class ImgMixin: @@ -237,20 +262,20 @@ def __init__(self, backend: str = "pillow", *args: List, **kwargs: Dict) -> None self._save_funcs.append(self.save_to_img) @abstractmethod - def _to_img(self) -> Union[Image.Image, Dict[str, Image.Image]]: + def _to_img(self) -> Dict[str, Image.Image]: """Abstract method to convert the result to an image. Returns: - Union[Image.Image, Dict[str, Image.Image]]: The image representation result. + Dict[str, Image.Image]: The image representation result. """ raise NotImplementedError @property - def img(self) -> Union[Image.Image, Dict[str, Image.Image]]: + def img(self) -> Dict[str, Image.Image]: """Property to get the image representation of the result. Returns: - Union[Image.Image, Dict[str, Image.Image]]: The image representation of the result. + Dict[str, Image.Image]: The image representation of the result. 
""" return self._to_img() @@ -267,24 +292,24 @@ def _is_image_file(file_path): mime_type, _ = mimetypes.guess_type(file_path) return mime_type is not None and mime_type.startswith("image/") - img = self.img - if isinstance(img, dict): - if not _is_image_file(save_path): - fp = Path(self["input_path"]) - stem = fp.stem - suffix = fp.suffix - else: - stem = save_path.stem - suffix = save_path.suffix + if not _is_image_file(save_path): + fp = Path(self["input_path"]) + stem = fp.stem + suffix = fp.suffix base_save_path = Path(save_path) - for key in img: + for key in self.img: save_path = base_save_path / f"{stem}_{key}{suffix}" - self._img_writer.write(save_path.as_posix(), img[key], *args, **kwargs) + self._img_writer.write( + save_path.as_posix(), self.img[key], *args, **kwargs + ) else: - if not _is_image_file(save_path): - fp = Path(self["input_path"]) - save_path = Path(save_path) / f"{fp.stem}{fp.suffix}" - self._img_writer.write(save_path.as_posix(), img, *args, **kwargs) + if len(self.img) > 1: + logging.warning( + f"The result has multiple img files need to be saved. But the `save_path` has been specfied as `{save_path}`!" + ) + self._img_writer.write( + save_path, self.img[list(self.img.keys())[0]], *args, **kwargs + ) class CSVMixin: @@ -304,20 +329,20 @@ def __init__(self, backend: str = "pandas", *args: List, **kwargs: Dict) -> None self._save_funcs.append(self.save_to_csv) @property - def csv(self) -> pd.DataFrame: + def csv(self) -> Dict[str, pd.DataFrame]: """Property to get the pandas Dataframe representation of the result. Returns: - pandas.DataFrame: The pandas.DataFrame representation of the result. + Dict[str, pd.DataFrame]: The pandas.DataFrame representation of the result. """ return self._to_csv() @abstractmethod - def _to_csv(self) -> pd.DataFrame: + def _to_csv(self) -> Dict[str, pd.DataFrame]: """Abstract method to convert the result to pandas.DataFrame. Returns: - pandas.DataFrame: The pandas.DataFrame representation result. 
+ Dict[str, pd.DataFrame]: The pandas.DataFrame representation result. """ raise NotImplementedError @@ -330,11 +355,28 @@ def save_to_csv(self, save_path: str, *args: List, **kwargs: Dict) -> None: *args: Optional positional arguments to pass to the CSV writer's write method. **kwargs: Optional keyword arguments to pass to the CSV writer's write method. """ - if not str(save_path).endswith(".csv"): - save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv" + + def _is_csv_file(file_path): + mime_type, _ = mimetypes.guess_type(file_path) + return mime_type is not None and mime_type == "text/csv" + + if not _is_csv_file(save_path): + fp = Path(self["input_path"]) + stem = fp.stem + base_save_path = Path(save_path) + for key in self.csv: + save_path = base_save_path / f"{stem}_{key}.csv" + self._csv_writer.write( + save_path.as_posix(), self.csv[key], *args, **kwargs + ) else: - save_path = Path(save_path) - self._csv_writer.write(save_path.as_posix(), self.csv, *args, **kwargs) + if len(self.csv) > 1: + logging.warning( + f"The result has multiple csv files need to be saved. But the `save_path` has been specfied as `{save_path}`!" + ) + self._csv_writer.write( + save_path, self.csv[list(self.csv.keys())[0]], *args, **kwargs + ) class HtmlMixin: @@ -352,7 +394,7 @@ def __init__(self, *args: List, **kwargs: Dict) -> None: self._save_funcs.append(self.save_to_html) @property - def html(self) -> str: + def html(self) -> Dict[str, str]: """Property to get the HTML representation of the result. Returns: @@ -361,11 +403,11 @@ def html(self) -> str: return self._to_html() @abstractmethod - def _to_html(self) -> str: + def _to_html(self) -> Dict[str, str]: """Abstract method to convert the result to str type HTML representation. Returns: - str: The str type HTML representation result. + Dict[str, str]: The str type HTML representation result. 
""" raise NotImplementedError @@ -381,7 +423,7 @@ def save_to_html(self, save_path: str, *args: List, **kwargs: Dict) -> None: save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html" else: save_path = Path(save_path) - self._html_writer.write(save_path.as_posix(), self.html, *args, **kwargs) + self._html_writer.write(save_path.as_posix(), self.html["res"], *args, **kwargs) class XlsxMixin: @@ -398,20 +440,20 @@ def __init__(self, *args: List, **kwargs: Dict) -> None: self._save_funcs.append(self.save_to_xlsx) @property - def xlsx(self) -> str: + def xlsx(self) -> Dict[str, str]: """Property to get the XLSX representation of the result. Returns: - str: The str type XLSX representation of the result. + Dict[str, str]: The str type XLSX representation of the result. """ return self._to_xlsx() @abstractmethod - def _to_xlsx(self) -> str: + def _to_xlsx(self) -> Dict[str, str]: """Abstract method to convert the result to str type XLSX representation. Returns: - str: The str type HTML representation result. + Dict[str, str]: The str type XLSX representation result.
""" raise NotImplementedError @@ -442,12 +484,11 @@ def _to_video(self): @property def video(self): - video = self._to_video() - return video + return self._to_video() def save_to_video(self, save_path, *args, **kwargs): video_writer = VideoWriter(backend=self._backend, *args, **kwargs) if not str(save_path).lower().endswith((".mp4", ".avi", ".mkv", ".webm")): fp = Path(self["input_path"]) save_path = Path(save_path) / f"{fp.stem}{fp.suffix}" - _save_list_data(video_writer.write, save_path, self.video, *args, **kwargs) + video_writer.write(save_path.as_posix(), self.video["video"], *args, **kwargs) diff --git a/paddlex/inference/models_new/anomaly_detection/result.py b/paddlex/inference/models_new/anomaly_detection/result.py index 7ec034719..c3cdae602 100644 --- a/paddlex/inference/models_new/anomaly_detection/result.py +++ b/paddlex/inference/models_new/anomaly_detection/result.py @@ -26,7 +26,7 @@ def _to_img(self): """apply""" seg_map = self["pred"] pc_map = self.get_pseudo_color_map(seg_map[0]) - return pc_map + return {"res": pc_map} def get_pseudo_color_map(self, pred): """get_pseudo_color_map""" diff --git a/paddlex/inference/models_new/formula_recognition/result.py b/paddlex/inference/models_new/formula_recognition/result.py index 23e6ba94e..16edaf7f2 100644 --- a/paddlex/inference/models_new/formula_recognition/result.py +++ b/paddlex/inference/models_new/formula_recognition/result.py @@ -56,7 +56,7 @@ def _to_img( logging.warning( "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first." 
) - return image + return {"res": image} rec_formula = str(self["rec_formula"]) image = np.array(image.convert("RGB")) @@ -83,10 +83,10 @@ def _to_img( ) new_image.paste(image, (0, 0)) new_image.paste(img_formula, (image.width + 10, 0)) - return new_image + return {"res": new_image} except subprocess.CalledProcessError as e: logging.warning("Syntax error detected in formula, rendering failed.") - return image + return {"res": image} def get_align_equation(equation: str) -> str: diff --git a/paddlex/inference/models_new/image_classification/result.py b/paddlex/inference/models_new/image_classification/result.py index 10e2c5b1a..3bd1ce0b2 100644 --- a/paddlex/inference/models_new/image_classification/result.py +++ b/paddlex/inference/models_new/image_classification/result.py @@ -66,7 +66,7 @@ def _to_img(self): text_x = rect_left + 3 text_y = rect_top draw.text((text_x, text_y), label_str, fill=font_color, font=font) - return image + return {"res": image} def _get_font_colormap(self, color_index): """ diff --git a/paddlex/inference/models_new/image_feature/result.py b/paddlex/inference/models_new/image_feature/result.py index 4f9351882..07816f723 100644 --- a/paddlex/inference/models_new/image_feature/result.py +++ b/paddlex/inference/models_new/image_feature/result.py @@ -14,12 +14,8 @@ from PIL import Image -from ...common.result import BaseCVResult +from ...common.result import BaseResult -class IdentityResult(BaseCVResult): - - def _to_img(self): - """This module does not support visualization; it simply outputs the input images""" - image = Image.fromarray(self["input_img"]) - return image +class IdentityResult(BaseResult): + pass diff --git a/paddlex/inference/models_new/image_multilabel_classification/result.py b/paddlex/inference/models_new/image_multilabel_classification/result.py index c90773857..76d5196a2 100644 --- a/paddlex/inference/models_new/image_multilabel_classification/result.py +++ 
b/paddlex/inference/models_new/image_multilabel_classification/result.py @@ -70,7 +70,7 @@ def _to_img(self): fill=font_color, font=font, ) - return new_image + return {"res": new_image} def _get_font_colormap(self, color_index): """ diff --git a/paddlex/inference/models_new/image_unwarping/result.py b/paddlex/inference/models_new/image_unwarping/result.py index 731f4aa55..8cb181703 100644 --- a/paddlex/inference/models_new/image_unwarping/result.py +++ b/paddlex/inference/models_new/image_unwarping/result.py @@ -31,7 +31,7 @@ class DocTrResult(BaseCVResult): def _to_img(self) -> np.ndarray: result = np.array(self["doctr_img"]) - return result + return {"res": result} def _to_str(self, _, *args, **kwargs): data = copy.deepcopy(self) diff --git a/paddlex/inference/models_new/instance_segmentation/result.py b/paddlex/inference/models_new/instance_segmentation/result.py index eb4320ad4..e56976018 100644 --- a/paddlex/inference/models_new/instance_segmentation/result.py +++ b/paddlex/inference/models_new/instance_segmentation/result.py @@ -147,7 +147,7 @@ def _to_img(self): else: image = draw_segm(image, masks, boxes) - return image + return {"res": image} def _to_str(self, _, *args, **kwargs): data = copy.deepcopy(self) diff --git a/paddlex/inference/models_new/object_detection/result.py b/paddlex/inference/models_new/object_detection/result.py index 8e5965ca4..823bc5697 100644 --- a/paddlex/inference/models_new/object_detection/result.py +++ b/paddlex/inference/models_new/object_detection/result.py @@ -100,4 +100,4 @@ def _to_img(self) -> Image.Image: """apply""" boxes = self["boxes"] image = Image.fromarray(self["input_img"]) - return draw_box(image, boxes) + return {"res": draw_box(image, boxes)} diff --git a/paddlex/inference/models_new/semantic_segmentation/result.py b/paddlex/inference/models_new/semantic_segmentation/result.py index bdadf7a65..8391963e1 100644 --- a/paddlex/inference/models_new/semantic_segmentation/result.py +++ 
b/paddlex/inference/models_new/semantic_segmentation/result.py @@ -28,7 +28,7 @@ def _to_img(self): pc_map = self.get_pseudo_color_map(seg_map[0]) if pc_map.mode == "P": pc_map = pc_map.convert("RGB") - return pc_map + return {"res": pc_map} def get_pseudo_color_map(self, pred): """get_pseudo_color_map""" diff --git a/paddlex/inference/models_new/table_structure_recognition/result.py b/paddlex/inference/models_new/table_structure_recognition/result.py index 7fcbfb0cd..3e7c577e3 100644 --- a/paddlex/inference/models_new/table_structure_recognition/result.py +++ b/paddlex/inference/models_new/table_structure_recognition/result.py @@ -34,7 +34,7 @@ def _to_img(self): vis_img = self.draw_rectangle(image, bbox_res) else: vis_img = self.draw_bbox(image, bbox_res) - return vis_img + return {"res": vis_img} def draw_rectangle(self, image, boxes): """draw_rectangle""" diff --git a/paddlex/inference/models_new/text_detection/result.py b/paddlex/inference/models_new/text_detection/result.py index c7904d924..ab6c4ced2 100644 --- a/paddlex/inference/models_new/text_detection/result.py +++ b/paddlex/inference/models_new/text_detection/result.py @@ -30,4 +30,4 @@ def _to_img(self): for box in boxes: box = np.reshape(np.array(box).astype(int), [-1, 1, 2]).astype(np.int64) cv2.polylines(image, [box], True, (0, 0, 255), 2) - return image[:, :, ::-1] + return {"res": image[:, :, ::-1]} diff --git a/paddlex/inference/models_new/text_recognition/result.py b/paddlex/inference/models_new/text_recognition/result.py index a6bef2f82..74564bd7c 100644 --- a/paddlex/inference/models_new/text_recognition/result.py +++ b/paddlex/inference/models_new/text_recognition/result.py @@ -42,7 +42,7 @@ def _to_img(self): fill=(0, 0, 0), font=font, ) - return new_image + return {"res": new_image} def adjust_font_size(self, image_width, text, font_path): font_size = int(image_width * 0.06) diff --git a/paddlex/inference/models_new/ts_anomaly_detection/result.py 
b/paddlex/inference/models_new/ts_anomaly_detection/result.py index 5ef758a9f..87a97b30d 100644 --- a/paddlex/inference/models_new/ts_anomaly_detection/result.py +++ b/paddlex/inference/models_new/ts_anomaly_detection/result.py @@ -26,4 +26,4 @@ def _to_csv(self) -> Any: Returns: Any: The anomaly data formatted for CSV output, typically a DataFrame or similar structure. """ - return self["anomaly"] + return {"res": self["anomaly"]} diff --git a/paddlex/inference/models_new/ts_classification/result.py b/paddlex/inference/models_new/ts_classification/result.py index c8f9926bc..588d30116 100644 --- a/paddlex/inference/models_new/ts_classification/result.py +++ b/paddlex/inference/models_new/ts_classification/result.py @@ -26,4 +26,4 @@ def _to_csv(self) -> Any: Returns: Any: The classification data formatted for CSV output, typically a DataFrame or similar structure. """ - return self["classification"] + return {"res": self["classification"]} diff --git a/paddlex/inference/models_new/ts_forecasting/result.py b/paddlex/inference/models_new/ts_forecasting/result.py index 67996ae27..4d60b5b4e 100644 --- a/paddlex/inference/models_new/ts_forecasting/result.py +++ b/paddlex/inference/models_new/ts_forecasting/result.py @@ -26,4 +26,4 @@ def _to_csv(self) -> Any: Returns: Any: The forecast data formatted for CSV output, typically a DataFrame or similar structure. """ - return self["forecast"] + return {"res": self["forecast"]}