diff --git a/src/chap/commands/ask.py b/src/chap/commands/ask.py index e426160..cdfb45c 100644 --- a/src/chap/commands/ask.py +++ b/src/chap/commands/ask.py @@ -10,11 +10,8 @@ from ..core import command_uses_new_session -if sys.stdout.isatty(): - bold = "\033[1m" - nobold = "\033[m" -else: - bold = nobold = "" +bold = "\033[1m" +nobold = "\033[m" def ipartition(s, sep): @@ -24,6 +21,14 @@ def ipartition(s, sep): yield (first, opt_sep) +class DumbPrinter: + def raw(self, s): + pass + + def add(self, s): + print(s, end="") + + class WrappingPrinter: def __init__(self, width=None): self._width = width or rich.get_console().width @@ -58,29 +63,34 @@ def add(self, s): self._sp = "" -def verbose_ask(api, session, q, **kw): - printer = WrappingPrinter() +def verbose_ask(api, session, q, print_prompt, **kw): + if sys.stdout.isatty(): + printer = WrappingPrinter() + else: + printer = DumbPrinter() tokens = [] async def work(): async for token in api.aask(session, q, **kw): printer.add(token) - printer.raw(bold) - printer.add(q) - printer.raw(nobold) - printer.add("\n") - printer.add("\n") + if print_prompt: + printer.raw(bold) + printer.add(q) + printer.raw(nobold) + printer.add("\n") + printer.add("\n") + asyncio.run(work()) printer.add("\n") - printer.add("\n") result = "".join(tokens) return result @command_uses_new_session +@click.option("--print-prompt/--no-print-prompt", default=True) @click.argument("prompt", nargs=-1, required=True) -def main(obj, prompt): +def main(obj, prompt, print_prompt): """Ask a question (command-line argument is passed as prompt)""" session = obj.session session_filename = obj.session_filename @@ -88,7 +98,7 @@ def main(obj, prompt): # symlink_session_filename(session_filename) - response = verbose_ask(api, session, " ".join(prompt)) + response = verbose_ask(api, session, " ".join(prompt), print_prompt=print_prompt) print(f"Saving session to {session_filename}", file=sys.stderr) if response is not None: diff --git a/src/chap/commands/tui.py b/src/chap/commands/tui.py index d38e045..33f71d4 100644 --- a/src/chap/commands/tui.py +++ b/src/chap/commands/tui.py @@ -12,7 +12,7 @@ from textual.containers import Container, VerticalScroll from textual.widgets import Footer, Input, Markdown -from ..core import command_uses_new_session, get_api +from ..core import command_uses_new_session, get_api, new_session_path from ..session import Assistant, Session, User @@ -22,8 +22,14 @@ def parser_factory(): return parser -class Markdown(Markdown, can_focus=True): # pylint: disable=function-redefined - pass +class Markdown( + Markdown, can_focus=True, can_focus_children=False +): # pylint: disable=function-redefined + BINDINGS = [ + Binding("ctrl+y", "yank", "Yank text", show=True), + Binding("ctrl+r", "resubmit", "resubmit", show=True), + Binding("ctrl+x", "delete", "delete to end", show=True), + ] def markdown_for_step(step): @@ -37,8 +43,7 @@ def markdown_for_step(step): class Tui(App): CSS_PATH = "tui.css" BINDINGS = [ - Binding("ctrl+y", "yank", "Yank text", show=True), - Binding("ctrl+q", "app.quit", "Quit", show=True), + Binding("ctrl+q", "app.quit", "Quit", show=True, priority=True), ] def __init__(self, api=None, session=None): @@ -113,6 +118,37 @@ def action_yank(self): content = widget._markdown # pylint: disable=protected-access subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False) + async def action_resubmit(self): + await self.delete_or_resubmit(True) + + async def action_delete(self): + await self.delete_or_resubmit(False) + + async def delete_or_resubmit(self, resubmit): + widget = self.focused + if not isinstance(widget, Markdown): + return + children = self.container.children + idx = children.index(widget) + while idx > 1 and not "role_user" in children[idx].classes: + idx -= 1 + widget = children[idx] + + # Save a copy of the discussion before this deletion + with open(new_session_path(), "w", encoding="utf-8") as f: + f.write(self.session.to_json()) + + query = self.session.session[idx].content + self.input.value = query + + del self.session.session[idx:] + for child in self.container.children[idx:-1]: + await child.remove() + + self.input.focus() + if resubmit: + await self.input.action_submit() + @command_uses_new_session def main(obj): diff --git a/src/chap/core.py b/src/chap/core.py index a41dd61..44c93b3 100644 --- a/src/chap/core.py +++ b/src/chap/core.py @@ -106,6 +106,9 @@ def colonstr(arg): def set_system_message(ctx, param, value): # pylint: disable=unused-argument + if value and value.startswith("@"): + with open(value[1:], "r", encoding="utf-8") as f: + value = f.read().rstrip() ctx.obj.system_message = value