Skip to content

Commit

Permalink
Merge pull request #13 from jepler/interactivity
Browse files Browse the repository at this point in the history
More Interactivity in the TUI
  • Loading branch information
jepler authored Sep 25, 2023
2 parents 1b700aa + 0f9c6f1 commit 9919c9a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 20 deletions.
40 changes: 25 additions & 15 deletions src/chap/commands/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@

from ..core import command_uses_new_session

if sys.stdout.isatty():
bold = "\033[1m"
nobold = "\033[m"
else:
bold = nobold = ""
bold = "\033[1m"
nobold = "\033[m"


def ipartition(s, sep):
Expand All @@ -24,6 +21,14 @@ def ipartition(s, sep):
yield (first, opt_sep)


class DumbPrinter:
def raw(self, s):
pass

def add(self, s):
print(s, end="")


class WrappingPrinter:
def __init__(self, width=None):
self._width = width or rich.get_console().width
Expand Down Expand Up @@ -58,37 +63,42 @@ def add(self, s):
self._sp = ""


def verbose_ask(api, session, q, **kw):
printer = WrappingPrinter()
def verbose_ask(api, session, q, print_prompt, **kw):
if sys.stdout.isatty():
printer = WrappingPrinter()
else:
printer = DumbPrinter()
tokens = []

async def work():
async for token in api.aask(session, q, **kw):
printer.add(token)

printer.raw(bold)
printer.add(q)
printer.raw(nobold)
printer.add("\n")
printer.add("\n")
if print_prompt:
printer.raw(bold)
printer.add(q)
printer.raw(nobold)
printer.add("\n")
printer.add("\n")

asyncio.run(work())
printer.add("\n")
printer.add("\n")
result = "".join(tokens)
return result


@command_uses_new_session
@click.option("--print-prompt/--no-print-prompt", default=True)
@click.argument("prompt", nargs=-1, required=True)
def main(obj, prompt):
def main(obj, prompt, print_prompt):
"""Ask a question (command-line argument is passed as prompt)"""
session = obj.session
session_filename = obj.session_filename
api = obj.api

# symlink_session_filename(session_filename)

response = verbose_ask(api, session, " ".join(prompt))
response = verbose_ask(api, session, " ".join(prompt), print_prompt=print_prompt)

print(f"Saving session to {session_filename}", file=sys.stderr)
if response is not None:
Expand Down
46 changes: 41 additions & 5 deletions src/chap/commands/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from textual.containers import Container, VerticalScroll
from textual.widgets import Footer, Input, Markdown

from ..core import command_uses_new_session, get_api
from ..core import command_uses_new_session, get_api, new_session_path
from ..session import Assistant, Session, User


Expand All @@ -22,8 +22,14 @@ def parser_factory():
return parser


class Markdown(Markdown, can_focus=True): # pylint: disable=function-redefined
pass
class Markdown(
Markdown, can_focus=True, can_focus_children=False
): # pylint: disable=function-redefined
BINDINGS = [
Binding("ctrl+y", "yank", "Yank text", show=True),
Binding("ctrl+r", "resubmit", "resubmit", show=True),
Binding("ctrl+x", "delete", "delete to end", show=True),
]


def markdown_for_step(step):
Expand All @@ -37,8 +43,7 @@ def markdown_for_step(step):
class Tui(App):
CSS_PATH = "tui.css"
BINDINGS = [
Binding("ctrl+y", "yank", "Yank text", show=True),
Binding("ctrl+q", "app.quit", "Quit", show=True),
Binding("ctrl+q", "app.quit", "Quit", show=True, priority=True),
]

def __init__(self, api=None, session=None):
Expand Down Expand Up @@ -113,6 +118,37 @@ def action_yank(self):
content = widget._markdown # pylint: disable=protected-access
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)

async def action_resubmit(self):
await self.delete_or_resubmit(True)

async def action_delete(self):
await self.delete_or_resubmit(False)

async def delete_or_resubmit(self, resubmit):
widget = self.focused
if not isinstance(widget, Markdown):
return
children = self.container.children
idx = children.index(widget)
while idx > 1 and not "role_user" in children[idx].classes:
idx -= 1
widget = children[idx]

# Save a copy of the discussion before this deletion
with open(new_session_path(), "w", encoding="utf-8") as f:
f.write(self.session.to_json())

query = self.session.session[idx].content
self.input.value = query

del self.session.session[idx:]
for child in self.container.children[idx:-1]:
await child.remove()

self.input.focus()
if resubmit:
await self.input.action_submit()


@command_uses_new_session
def main(obj):
Expand Down
3 changes: 3 additions & 0 deletions src/chap/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def colonstr(arg):


def set_system_message(ctx, param, value): # pylint: disable=unused-argument
if value and value.startswith("@"):
with open(value[1:], "r", encoding="utf-8") as f:
value = f.read().rstrip()
ctx.obj.system_message = value


Expand Down

0 comments on commit 9919c9a

Please sign in to comment.