Skip to content

Commit

Permalink
change to dict type and support to save multi result files
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Jan 10, 2025
1 parent 36b138a commit 07969ad
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 91 deletions.
177 changes: 109 additions & 68 deletions paddlex/inference/common/result/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class StrMixin:
"""Mixin class for adding string conversion capabilities."""

@property
def str(self) -> str:
def str(self) -> Dict[str, str]:
"""Property to get the string representation of the result.
Returns:
str: The str type string representation of the result.
Dict[str, str]: The string representation of the result.
"""

return self._to_str(self)
Expand All @@ -54,7 +54,7 @@ def _to_str(
json_format: bool = False,
indent: int = 4,
ensure_ascii: bool = False,
) -> str:
):
"""Convert the given result data to a string representation.
Args:
Expand All @@ -64,12 +64,14 @@ def _to_str(
ensure_ascii (bool): If True, ensure all characters are ASCII. Default is False.
Returns:
str: The string representation of the data.
Dict[str, str]: The string representation of the result.
"""
if json_format:
return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
return {
"res": json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
}
else:
return str(data)
return {"res": str(data)}

def print(
self, json_format: bool = False, indent: int = 4, ensure_ascii: bool = False
Expand All @@ -84,7 +86,7 @@ def print(
str_ = self._to_str(
self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
)
logging.info(str_)
logging.info(str_["res"])


class JsonMixin:
Expand All @@ -94,11 +96,11 @@ def __init__(self) -> None:
self._json_writer = JsonWriter()
self._save_funcs.append(self.save_to_json)

def _to_json(self) -> Dict[str, Any]:
def _to_json(self) -> Dict[str, Dict[str, Any]]:
"""Convert the object to a JSON-serializable format.
Returns:
Dict[str, Any]: A dictionary representation of the object that is JSON-serializable.
Dict[str, Dict[str, Any]]: A dictionary representation of the object that is JSON-serializable.
"""

def _format_data(obj):
Expand All @@ -125,14 +127,14 @@ def _format_data(obj):
else:
return obj

return _format_data(copy.deepcopy(self))
return {"res": _format_data(copy.deepcopy(self))}

@property
def json(self) -> Dict[str, Any]:
def json(self) -> Dict[str, Dict[str, Any]]:
"""Property to get the JSON representation of the result.
Returns:
Dict[str, Any]: The dict type JSON representation of the result.
Dict[str, Dict[str, Any]]: The dict type JSON representation of the result.
"""

return self._to_json()
Expand Down Expand Up @@ -160,16 +162,28 @@ def _is_json_file(file_path):
return mime_type is not None and mime_type == "application/json"

if not _is_json_file(save_path):
save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
save_path = save_path.as_posix()
self._json_writer.write(
save_path,
self.json,
indent=indent,
ensure_ascii=ensure_ascii,
*args,
**kwargs,
)
fp = Path(self["input_path"])
stem = fp.stem
suffix = fp.suffix
base_save_path = Path(save_path)
for key in self.json:
save_path = base_save_path / f"{stem}_{key}.json"
self._json_writer.write(
save_path.as_posix(), self.json[key], *args, **kwargs
)
else:
if len(self.json) > 1:
logging.warning(
f"The result has multiple json files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
)
self._json_writer.write(
save_path,
self.json[list(self.json.keys())[0]],
indent=indent,
ensure_ascii=ensure_ascii,
*args,
**kwargs,
)


class Base64Mixin:
Expand All @@ -186,21 +200,21 @@ def __init__(self, *args: List, **kwargs: Dict) -> None:
self._save_funcs.append(self.save_to_base64)

@abstractmethod
def _to_base64(self) -> str:
def _to_base64(self) -> Dict[str, str]:
"""Abstract method to convert the result to Base64.
Returns:
str: The str type Base64 representation result.
Dict[str, str]: The str type Base64 representation result.
"""
raise NotImplementedError

@property
def base64(self) -> str:
def base64(self) -> Dict[str, str]:
"""
Property that returns the Base64 encoded content.
Returns:
str: The base64 representation of the result.
Dict[str, str]: The base64 representation of the result.
"""
return self._to_base64()

Expand All @@ -213,13 +227,24 @@ def save_to_base64(self, save_path: str, *args: List, **kwargs: Dict) -> None:
*args: Additional positional arguments that will be passed to the base64 writer.
**kwargs: Additional keyword arguments that will be passed to the base64 writer.
"""

if not str(save_path).lower().endswith((".b64")):
fp = Path(self["input_path"])
save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
stem = fp.stem
suffix = fp.suffix
base_save_path = Path(save_path)
for key in self.base64:
save_path = base_save_path / f"{stem}_{key}.b64"
self._base64_writer.write(
save_path.as_posix(), self.base64[key], *args, **kwargs
)
else:
save_path = Path(save_path)
self._base64_writer.write(save_path.as_posix(), self.base64, *args, **kwargs)
if len(self.base64) > 1:
logging.warning(
f"The result has multiple base64 files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
)
self._base64_writer.write(
save_path, self.base64[list(self.base64.keys())[0]], *args, **kwargs
)


class ImgMixin:
Expand All @@ -237,20 +262,20 @@ def __init__(self, backend: str = "pillow", *args: List, **kwargs: Dict) -> None
self._save_funcs.append(self.save_to_img)

@abstractmethod
def _to_img(self) -> Union[Image.Image, Dict[str, Image.Image]]:
def _to_img(self) -> Dict[str, Image.Image]:
"""Abstract method to convert the result to an image.
Returns:
Union[Image.Image, Dict[str, Image.Image]]: The image representation result.
Dict[str, Image.Image]: The image representation result.
"""
raise NotImplementedError

@property
def img(self) -> Union[Image.Image, Dict[str, Image.Image]]:
def img(self) -> Dict[str, Image.Image]:
"""Property to get the image representation of the result.
Returns:
Union[Image.Image, Dict[str, Image.Image]]: The image representation of the result.
Dict[str, Image.Image]: The image representation of the result.
"""
return self._to_img()

Expand All @@ -267,24 +292,24 @@ def _is_image_file(file_path):
mime_type, _ = mimetypes.guess_type(file_path)
return mime_type is not None and mime_type.startswith("image/")

img = self.img
if isinstance(img, dict):
if not _is_image_file(save_path):
fp = Path(self["input_path"])
stem = fp.stem
suffix = fp.suffix
else:
stem = save_path.stem
suffix = save_path.suffix
if not _is_image_file(save_path):
fp = Path(self["input_path"])
stem = fp.stem
suffix = fp.suffix
base_save_path = Path(save_path)
for key in img:
for key in self.img:
save_path = base_save_path / f"{stem}_{key}{suffix}"
self._img_writer.write(save_path.as_posix(), img[key], *args, **kwargs)
self._img_writer.write(
save_path.as_posix(), self.img[key], *args, **kwargs
)
else:
if not _is_image_file(save_path):
fp = Path(self["input_path"])
save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
self._img_writer.write(save_path.as_posix(), img, *args, **kwargs)
if len(self.img) > 1:
logging.warning(
f"The result has multiple img files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
)
self._img_writer.write(
save_path, self.img[list(self.img.keys())[0]], *args, **kwargs
)


class CSVMixin:
Expand All @@ -304,20 +329,20 @@ def __init__(self, backend: str = "pandas", *args: List, **kwargs: Dict) -> None
self._save_funcs.append(self.save_to_csv)

@property
def csv(self) -> pd.DataFrame:
def csv(self) -> Dict[str, pd.DataFrame]:
"""Property to get the pandas Dataframe representation of the result.
Returns:
pandas.DataFrame: The pandas.DataFrame representation of the result.
Dict[str, pd.DataFrame]: The pandas.DataFrame representation of the result.
"""
return self._to_csv()

@abstractmethod
def _to_csv(self) -> pd.DataFrame:
def _to_csv(self) -> Dict[str, pd.DataFrame]:
"""Abstract method to convert the result to pandas.DataFrame.
Returns:
pandas.DataFrame: The pandas.DataFrame representation result.
Dict[str, pd.DataFrame]: The pandas.DataFrame representation result.
"""
raise NotImplementedError

Expand All @@ -330,11 +355,28 @@ def save_to_csv(self, save_path: str, *args: List, **kwargs: Dict) -> None:
*args: Optional positional arguments to pass to the CSV writer's write method.
**kwargs: Optional keyword arguments to pass to the CSV writer's write method.
"""
if not str(save_path).endswith(".csv"):
save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"

def _is_csv_file(file_path):
mime_type, _ = mimetypes.guess_type(file_path)
return mime_type is not None and mime_type == "text/csv"

if not _is_csv_file(save_path):
fp = Path(self["input_path"])
stem = fp.stem
base_save_path = Path(save_path)
for key in self.csv:
save_path = base_save_path / f"{stem}_{key}.csv"
self._csv_writer.write(
save_path.as_posix(), self.csv[key], *args, **kwargs
)
else:
save_path = Path(save_path)
self._csv_writer.write(save_path.as_posix(), self.csv, *args, **kwargs)
if len(self.csv) > 1:
logging.warning(
f"The result has multiple csv files need to be saved. But the `save_path` has been specfied as `{save_path}`!"
)
self._csv_writer.write(
save_path, self.csv[list(self.csv.keys())[0]], *args, **kwargs
)


class HtmlMixin:
Expand All @@ -352,7 +394,7 @@ def __init__(self, *args: List, **kwargs: Dict) -> None:
self._save_funcs.append(self.save_to_html)

@property
def html(self) -> str:
def html(self) -> Dict[str, str]:
"""Property to get the HTML representation of the result.
Returns:
Expand All @@ -361,11 +403,11 @@ def html(self) -> str:
return self._to_html()

@abstractmethod
def _to_html(self) -> str:
def _to_html(self) -> Dict[str, str]:
"""Abstract method to convert the result to str type HTML representation.
Returns:
str: The str type HTML representation result.
Dict[str, str]: The str type HTML representation result.
"""
raise NotImplementedError

Expand All @@ -381,7 +423,7 @@ def save_to_html(self, save_path: str, *args: List, **kwargs: Dict) -> None:
save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
else:
save_path = Path(save_path)
self._html_writer.write(save_path.as_posix(), self.html, *args, **kwargs)
self._html_writer.write(save_path.as_posix(), self.html["res"], *args, **kwargs)


class XlsxMixin:
Expand All @@ -398,20 +440,20 @@ def __init__(self, *args: List, **kwargs: Dict) -> None:
self._save_funcs.append(self.save_to_xlsx)

@property
def xlsx(self) -> str:
def xlsx(self) -> Dict[str, str]:
"""Property to get the XLSX representation of the result.
Returns:
str: The str type XLSX representation of the result.
Dict[str, str]: The str type XLSX representation of the result.
"""
return self._to_xlsx()

@abstractmethod
def _to_xlsx(self) -> str:
def _to_xlsx(self) -> Dict[str, str]:
"""Abstract method to convert the result to str type XLSX representation.
Returns:
str: The str type HTML representation result.
Dict[str, str]: The str type HTML representation result.
"""
raise NotImplementedError

Expand Down Expand Up @@ -442,12 +484,11 @@ def _to_video(self):

@property
def video(self):
video = self._to_video()
return video
return self._to_video()

def save_to_video(self, save_path, *args, **kwargs):
video_writer = VideoWriter(backend=self._backend, *args, **kwargs)
if not str(save_path).lower().endswith((".mp4", ".avi", ".mkv", ".webm")):
fp = Path(self["input_path"])
save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
_save_list_data(video_writer.write, save_path, self.video, *args, **kwargs)
video_writer.write(save_path.as_posix(), self.video["video"], *args, **kwargs)
2 changes: 1 addition & 1 deletion paddlex/inference/models_new/anomaly_detection/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _to_img(self):
"""apply"""
seg_map = self["pred"]
pc_map = self.get_pseudo_color_map(seg_map[0])
return pc_map
return {"res": pc_map}

def get_pseudo_color_map(self, pred):
"""get_pseudo_color_map"""
Expand Down
Loading

0 comments on commit 07969ad

Please sign in to comment.