From 957e8050de67588bb31fa64155871209abcdb69a Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Thu, 2 Nov 2023 14:57:03 -0400 Subject: [PATCH] update `repr_llm` and add `DataFrameSummarizer` with customizable summarizing function (#323) * update repr_llm * wip * getter * compat * cleanup --- poetry.lock | 14 ++++----- pyproject.toml | 2 +- src/dx/formatters/main.py | 4 +-- src/dx/formatters/summarizing.py | 54 ++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 src/dx/formatters/summarizing.py diff --git a/poetry.lock b/poetry.lock index 3710d4c6..2346b60d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -3682,8 +3682,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.1" pytz = ">=2020.1" @@ -4884,13 +4884,13 @@ tune = ["pandas", "requests", "tabulate", "tensorboardX (>=1.9)"] [[package]] name = "repr-llm" -version = "0.2.1" +version = "0.3.0" description = "Creating lightweight representations of objects for Large Language Model consumption" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "repr_llm-0.2.1-py3-none-any.whl", hash = "sha256:0df19f96c6a6b03143cdb65179bb95041ced9b85955e3ed6e74954f7b10e4f83"}, - {file = "repr_llm-0.2.1.tar.gz", hash = "sha256:6762ae7d20c64c507af6084f9ca565ddfbc6b5bf771558dd2b3fda4d220c8aa7"}, + {file = "repr_llm-0.3.0-py3-none-any.whl", hash = "sha256:0dc7f15d4f2649fa79341d562af76986b5efb199606997f665230524eb9bab35"}, + {file = "repr_llm-0.3.0.tar.gz", hash = "sha256:5ed183ba09f365692a9ea00ab54c13118ddcabc889b8fd5429c32c24a25b520c"}, ] [package.extras] @@ -5196,7 +5196,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -6102,4 +6102,4 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings", "mkdocstr [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "fc501337bc125fba1e54091aa7aacff2f34d59e2ee79dd404af0269c2dd7f8b0" +content-hash = "3144caf401c63c6593ea414bc86ca0729a55799303e141cdb1baee365c745a02" diff --git a/pyproject.toml b/pyproject.toml index 120a082d..d270f02d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ mkdocstrings = { version = ">=0.19,<0.22", optional = true } mkdocstrings-python = { version = ">=0.7.1,<0.10.0", optional = true } duckdb-engine = "^0.9.2" exceptiongroup = "^1.0.4" -repr-llm = "^0.2.1" +repr-llm = "^0.3.0" structlog = "^23.2.0" [tool.poetry.group.dev.dependencies] diff --git a/src/dx/formatters/main.py b/src/dx/formatters/main.py index a2b1da21..56658da0 100644 --- a/src/dx/formatters/main.py +++ b/src/dx/formatters/main.py @@ -9,8 +9,8 @@ from IPython.core.interactiveshell import InteractiveShell from IPython.display import display as ipydisplay from pandas.io.json import build_table_schema -from repr_llm.pandas import summarize_dataframe +from dx.formatters.summarizing import make_df_summary from dx.sampling import get_column_string_lengths, get_df_dimensions, sample_if_too_big from dx.settings import get_settings from dx.types.main import DXDisplayMode @@ -216,7 +216,7 @@ def format_output( # add additional payload for LLM consumption; if any parsing/summarizing errors occur, we # shouldn't block displaying the bundle try: - payload["text/llm+plain"] = summarize_dataframe(df) + payload["text/llm+plain"] = make_df_summary(df) except Exception as e: logger.debug(f"Error in summarize_dataframe: {e}") diff --git a/src/dx/formatters/summarizing.py b/src/dx/formatters/summarizing.py new file mode 100644 index 00000000..bdda8e9f --- /dev/null +++ b/src/dx/formatters/summarizing.py @@ -0,0 +1,54 @@ +from typing import Callable, Optional + +import pandas as pd + + +class DataFrameSummarizer: + _instance: "DataFrameSummarizer" = None + summarizing_func: Optional[Callable] = None + + def __init__(self, summarizing_func: Optional[Callable] = None): + if summarizing_func is None: + self._try_to_load_repr_llm() + else: + self.summarizing_func = summarizing_func + + def _try_to_load_repr_llm(self) -> None: + """Load repr_llm's summarize_dataframe into the summarizing_func if it's available.""" + try: + from repr_llm.pandas import summarize_dataframe + + self.summarizing_func = summarize_dataframe + except ImportError: + return + + @classmethod + def instance(cls) -> "DataFrameSummarizer": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def summarize(self, df: pd.DataFrame) -> str: + """Generate a summary of a dataframe using the configured summarizing_func.""" + if not isinstance(df, pd.DataFrame): + raise ValueError("`df` must be a pandas DataFrame") + + if self.summarizing_func is None: + return df.describe().to_string() + + return self.summarizing_func(df) + + +def get_summarizing_function() -> Optional[Callable]: + """Get the function to use for summarizing dataframes.""" + return DataFrameSummarizer.instance().summarizing_func + + +def set_summarizing_function(func: Callable) -> None: + """Set the function to use for summarizing dataframes.""" + DataFrameSummarizer.instance().summarizing_func = func + + +def make_df_summary(df: pd.DataFrame) -> str: + """Generate a summary of a dataframe using the configured summarizing_func.""" + return DataFrameSummarizer.instance().summarize(df)