Merge pull request #100 from rgbkrk/ruff
Ruff
rgbkrk authored Oct 26, 2023
2 parents 087b02a + b8bfc5e commit 016af40
Showing 23 changed files with 406 additions and 223 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,8 @@
+name: Ruff
+on: [ push, pull_request ]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: chartboost/ruff-action@v1
27 changes: 23 additions & 4 deletions chatlab/__init__.py
@@ -13,15 +13,24 @@
 """
 
 __author__ = """Kyle Kelley"""
-__email__ = '[email protected]'
+__email__ = "[email protected]"
 
 from deprecation import deprecated
 
 from . import models
 from ._version import __version__
 from .conversation import Chat
 from .decorators import ChatlabMetadata, expose_exception_to_llm
-from .messaging import ai, assistant, assistant_function_call, function_result, human, narrate, system, user
+from .messaging import (
+    ai,
+    assistant,
+    assistant_function_call,
+    function_result,
+    human,
+    narrate,
+    system,
+    user,
+)
 from .registry import FunctionRegistry
 from .views.markdown import Markdown
 
@@ -33,7 +42,12 @@ class Session(Chat):
     Session is deprecated. Use `Chat` instead.
     """
 
-    @deprecated(deprecated_in="0.13.0", removed_in="1.0.0", current_version=__version__, details="Use `Chat` instead.")
+    @deprecated(
+        deprecated_in="0.13.0",
+        removed_in="1.0.0",
+        current_version=__version__,
+        details="Use `Chat` instead.",
+    )
     def __init__(self, *args, **kwargs):
         """Initialize a Session with an optional initial context of messages.
@@ -48,7 +62,12 @@ class Conversation(Chat):
     Conversation is deprecated. Use `Chat` instead.
     """
 
-    @deprecated(deprecated_in="1.0.0", removed_in="1.1.0", current_version=__version__, details="Use `Chat` instead.")
+    @deprecated(
+        deprecated_in="1.0.0",
+        removed_in="1.1.0",
+        current_version=__version__,
+        details="Use `Chat` instead.",
+    )
     def __init__(self, *args, **kwargs):
         """Initialize a Session with an optional initial context of messages.
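A note on the two decorators above: the `deprecation` package emits a subclass of DeprecationWarning whenever the wrapped callable is invoked. A minimal sketch of that behavior, reusing the keyword arguments from this diff (the function name here is hypothetical, not part of chatlab):

import warnings

from deprecation import deprecated


@deprecated(
    deprecated_in="1.0.0",
    removed_in="1.1.0",
    current_version="1.0.0",
    details="Use `Chat` instead.",
)
def old_entry_point():
    """Hypothetical deprecated function."""


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_entry_point()

# Calling the wrapped function should have recorded a DeprecationWarning subclass.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)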
2 changes: 1 addition & 1 deletion chatlab/_version.py
@@ -1 +1 @@
-__version__ = '1.0.0-alpha.27'
+__version__ = "1.0.0-alpha.27"
9 changes: 8 additions & 1 deletion chatlab/builtins/__init__.py
@@ -17,4 +17,11 @@
 # compose all the file, shell, and python functions into one list for ease of use
 os_functions = file_functions + shell_functions + [run_python, get_python_docs]
 
-__all__ = ["run_python", "get_python_docs", "run_cell", "file_functions", "shell_functions", "os_functions"]
+__all__ = [
+    "run_python",
+    "get_python_docs",
+    "run_cell",
+    "file_functions",
+    "shell_functions",
+    "os_functions",
+]
48 changes: 24 additions & 24 deletions chatlab/builtins/_mediatypes.py
@@ -8,18 +8,18 @@
 # Prioritized formats to show to large language models
 formats_for_llm = [
     # Repr LLM is the richest text
-    'text/llm+plain',
+    "text/llm+plain",
     # Assume that if we get markdown we know it's rich for an LLM
-    'text/markdown',
+    "text/markdown",
     # Same with LaTeX
-    'text/latex',
+    "text/latex",
     # All the normal ones
-    'application/vnd.jupyter.error+json',
+    "application/vnd.jupyter.error+json",
     # 'application/vdom.v1+json',
-    'application/json',
+    "application/json",
     # Since every object has a text/plain repr, even though the LLM would understand `text/plain` well,
     # bumping this priority up would override more rich media types we definitely want to show.
-    'text/plain',
+    "text/plain",
     # Both HTML and SVG should be conditional on size, considering many libraries
     # Will emit giant JavaScript blobs for interactivity
     # For now, we'll assume not to show these
@@ -29,20 +29,20 @@
 
 # Prioritized formats to redisplay for the user, since we capture the output during execution
 formats_to_redisplay = [
-    'application/vnd.jupyter.widget-view+json',
-    'application/vnd.dex.v1+json',
-    'application/vnd.dataresource+json',
-    'application/vnd.plotly.v1+json',
-    'text/vnd.plotly.v1+html',
-    'application/vdom.v1+json',
-    'application/json',
-    'application/javascript',
-    'image/svg+xml',
-    'image/png',
-    'image/jpeg',
-    'image/gif',
-    'text/html',
-    'image/svg+xml',
+    "application/vnd.jupyter.widget-view+json",
+    "application/vnd.dex.v1+json",
+    "application/vnd.dataresource+json",
+    "application/vnd.plotly.v1+json",
+    "text/vnd.plotly.v1+html",
+    "application/vdom.v1+json",
+    "application/json",
+    "application/javascript",
+    "image/svg+xml",
+    "image/png",
+    "image/jpeg",
+    "image/gif",
+    "text/html",
+    "image/svg+xml",
 ]
 
 
@@ -63,14 +63,14 @@ def redisplay_superrich(output: RichOutput):
     metadata.pop(richest_format, None)
 
     # Check to see if it already has a text/llm+plain representation
-    if 'text/llm+plain' in data:
+    if "text/llm+plain" in data:
         return
 
-    if richest_format.startswith('image/'):
+    if richest_format.startswith("image/"):
         # Allow the LLM to see that we displayed for the user
-        data['text/llm+plain'] = data['text/plain']
+        data["text/llm+plain"] = data["text/plain"]
     else:
-        data['text/llm+plain'] = f"<Displayed {richest_format}>"
+        data["text/llm+plain"] = f"<Displayed {richest_format}>"
 
 
 def pluck_richest_text(output: RichOutput):
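An aside on the two lists above: they are priority orders, and the surrounding functions (redisplay_superrich, pluck_richest_text) walk a Jupyter display-data bundle front to back, taking the first mimetype that is present. A rough sketch of that selection pattern — the helper name and sample bundle are illustrative, not chatlab's actual code:

from typing import List, Optional


def first_present(data: dict, priority: List[str]) -> Optional[str]:
    """Return the highest-priority mimetype the bundle actually contains."""
    for mimetype in priority:
        if mimetype in data:
            return mimetype
    return None


# A markdown repr outranks the plain-text repr that every object carries.
bundle = {"text/plain": "DataFrame(...)", "text/markdown": "| a | b |"}
assert first_present(bundle, ["text/markdown", "text/plain"]) == "text/markdown"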
8 changes: 4 additions & 4 deletions chatlab/builtins/colors.py
@@ -23,24 +23,24 @@ def colors(self):
     @colors.setter
     def colors(self, colors: List[str]):
         self._colors = colors
-        self.html = '<div>'
+        self.html = "<div>"
         for color in colors:
             self.html += f'<div style="background-color:{color}; width:50px; height:50px; display:inline-block;"></div>'
-        self.html += '</div>'
+        self.html += "</div>"
 
     def _repr_html_(self):
         return self.html
 
     def __repr__(self):
         """Returns a string representation of the palette."""
-        return f'Palette({self.colors}, {self.name})'
+        return f"Palette({self.colors}, {self.name})"
 
 
 palettes: Dict[str, Palette] = {}
 
 
 def _generate_palette_name(colors: List[str]) -> str:
-    hash_object = hashlib.sha1(''.join(colors).encode())
+    hash_object = hashlib.sha1("".join(colors).encode())
     return f"palette-{hash_object.hexdigest()}"
 
 
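An aside on _generate_palette_name above: keying the palettes dict by a SHA-1 of the concatenated colors makes the name deterministic, so the same color list always maps to the same registry entry. Illustrative usage, assuming the function as shown:

name = _generate_palette_name(["#ff0000", "#00ff00"])
# sha1 of "#ff0000#00ff00" -> a stable 40-character hex digest
assert name == _generate_palette_name(["#ff0000", "#00ff00"])
assert name.startswith("palette-")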
13 changes: 10 additions & 3 deletions chatlab/builtins/files.py
@@ -59,7 +59,7 @@ async def is_file(file_path: str) -> bool:
 
 
 @expose_exception_to_llm
-async def write_file(file_path: str, content: str, mode: str = 'w') -> None:
+async def write_file(file_path: str, content: str, mode: str = "w") -> None:
     """Write content to a file.
 
     Args:
@@ -75,7 +75,7 @@ async def write_file(file_path: str, content: str, mode: str = 'w') -> None:
 
 
 @expose_exception_to_llm
-async def read_file(file_path: str, mode: str = 'r') -> str:
+async def read_file(file_path: str, mode: str = "r") -> str:
     """Read content from a file.
 
     Args:
@@ -104,4 +104,11 @@ async def is_directory(directory: str) -> bool:
     return is_directory
 
 
-chat_functions = [list_files, get_file_size, is_file, is_directory, write_file, read_file]
+chat_functions = [
+    list_files,
+    get_file_size,
+    is_file,
+    is_directory,
+    write_file,
+    read_file,
+]
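An aside on chat_functions above: these coroutines are plain callables meant for chatlab's function-calling machinery. A hedged sketch of wiring them up — FunctionRegistry is imported in chatlab/__init__.py in this same diff, but the exact register() call pattern is an assumption, not something this commit shows:

from chatlab import FunctionRegistry
from chatlab.builtins.files import chat_functions

registry = FunctionRegistry()
for fn in chat_functions:
    # Assumed API: make each file helper invocable by the model.
    registry.register(fn)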
47 changes: 33 additions & 14 deletions chatlab/builtins/noteable.py
@@ -102,7 +102,9 @@ async def create(cls, file_name=None, token=None, file_id=None, project_id=None)
         # We have to track the kernel_session for now
         kernel_session = await api_client.launch_kernel(file_id)
 
-        cn = NotebookClient(api_client, rtu_client, file_id=file_id, kernel_session=kernel_session)
+        cn = NotebookClient(
+            api_client, rtu_client, file_id=file_id, kernel_session=kernel_session
+        )
 
         return cn
 
@@ -165,7 +167,9 @@ async def create_cell(
         )
 
         if cell is None:
-            return f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql."
+            return (
+                f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql."
+            )
 
         logger.info(f"Adding cell {cell_id} to notebook")
         cell = await rtu_client.add_cell(cell=cell, after_id=after_cell_id)
@@ -197,7 +201,7 @@ async def update_cell(
         _, cell = rtu_client.builder.get_cell(cell_id)
 
         # Pull the old cell type, as long as it's not a "raw" cell
-        if cell_type is None and cell.cell_type != 'raw':
+        if cell_type is None and cell.cell_type != "raw":
             cell_type = cell.cell_type
 
         if cell_type is not None:
@@ -209,7 +213,10 @@ async def update_cell(
                 conn = f"@{conn}"
 
             await rtu_client.change_cell_type(
-                cell_id, cell_type, db_connection=conn, assign_results_to=assign_results_to
+                cell_id,
+                cell_type,
+                db_connection=conn,
+                assign_results_to=assign_results_to,
             )
         except Exception as e:
             return f"Error replacing cell content: {e}"
@@ -223,7 +230,9 @@
 
     async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
         """Get the outputs for a given output collection ID."""
-        output_collection = await self.api_client.get_output_collection(output_collection_id)
+        output_collection = await self.api_client.get_output_collection(
+            output_collection_id
+        )
 
         outputs = output_collection.outputs
 
@@ -243,7 +252,9 @@ async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
         return llm_friendly_outputs
 
     async def _extract_llm_plain(self, output: KernelOutput):
-        resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"})
+        resp = await self.api_client.client.get(
+            f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"}
+        )
         resp.raise_for_status()
 
         output_for_llm = KernelOutput.parse_obj(resp.json())
@@ -254,7 +265,9 @@ async def _extract_llm_plain(self, output: KernelOutput):
         return output_for_llm.content.raw
 
     async def _extract_specific_mediatype(self, output: KernelOutput, mimetype: str):
-        resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": mimetype})
+        resp = await self.api_client.client.get(
+            f"/outputs/{output.id}", params={"mimetype": mimetype}
+        )
         resp.raise_for_status()
 
         output_for_llm = KernelOutput.parse_obj(resp.json())
@@ -296,16 +309,16 @@ async def _get_llm_friendly_output(self, output: KernelOutput):
         if output.type == "error":
             return await self._extract_error(content)
 
-        if content.mimetype == 'text/html':
-            result = await self._extract_specific_mediatype(output, 'text/plain')
+        if content.mimetype == "text/html":
+            result = await self._extract_specific_mediatype(output, "text/plain")
             if result is not None:
                 return result
 
-        if content.mimetype == 'application/vnd.dataresource+json':
+        if content.mimetype == "application/vnd.dataresource+json":
             # TODO: Bring back a smaller representation to allow the LLM to do analysis
             return "<!-- DataFrame shown in notebook that user can see -->"
 
-        if content.mimetype == 'application/vnd.plotly.v1+json':
+        if content.mimetype == "application/vnd.plotly.v1+json":
             return "<!-- Plotly shown in notebook that user can see -->"
         if content.url is not None:
             return "<!-- Large output too large for chat. It is available in the notebook that the user can see -->"
@@ -317,7 +330,9 @@ async def _get_llm_friendly_output(self, output: KernelOutput):
 
         for format in formats_for_llm:
             if format in mimetypes:
-                resp = await self.api_client.client.get(f"/outputs/{output.id}?mimetype={format}")
+                resp = await self.api_client.client.get(
+                    f"/outputs/{output.id}?mimetype={format}"
+                )
                 resp.raise_for_status()
                 if resp.status_code == 200:
                     return
@@ -406,7 +421,9 @@ async def get_cell(self, cell_id: str, with_outputs: bool = True):
             response += cell.source
             return response
 
-        source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get("language", "")
+        source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get(
+            "language", ""
+        )
 
         if cell.metadata.get("noteable", {}).get("cell_type") == "sql":
             source_type = "sql"
@@ -526,7 +543,9 @@ def chat_functions(self):
         ]
 
 
-def provide_notebook_creation(registry: FunctionRegistry, project_id: Optional[str] = None):
+def provide_notebook_creation(
+    registry: FunctionRegistry, project_id: Optional[str] = None
+):
     """Register the notebook client with the registry.
 
     >>> from chatlab import FunctionRegistry, Chat
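The docstring above is cut off by the collapsed diff view. Going only by the signature shown, usage would look roughly like this (import path inferred from the file path in this diff; the project UUID is a placeholder):

from chatlab import FunctionRegistry
from chatlab.builtins.noteable import provide_notebook_creation

registry = FunctionRegistry()
# project_id is Optional per the signature; a real Noteable project UUID goes here.
provide_notebook_creation(registry, project_id="00000000-0000-0000-0000-000000000000")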