diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 528f639..987ddde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: rev: v2.17.0 hooks: - id: pylint - additional_dependencies: [click,dataclasses_json,httpx,lorem-text,textual,websockets] + additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets] args: ['--source-roots', 'src'] - repo: https://github.com/pycqa/isort rev: 5.12.0 diff --git a/pyproject.toml b/pyproject.toml index a103a4e..1460525 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "httpx", "lorem-text", "platformdirs", - "textual", + "textual>=0.18.0", "websockets", ] classifiers = [ diff --git a/src/chap/commands/tui.py b/src/chap/commands/tui.py index feb7231..d10c44c 100644 --- a/src/chap/commands/tui.py +++ b/src/chap/commands/tui.py @@ -7,6 +7,7 @@ import sys import click +from markdown_it import MarkdownIt from textual.app import App from textual.binding import Binding from textual.containers import Container @@ -16,14 +17,21 @@ from ..session import Assistant, Session, User +def parser_factory(): + parser = MarkdownIt() + parser.options["html"] = False + return parser + + class Markdown(Markdown, can_focus=True): # pylint: disable=function-redefined pass def markdown_for_step(step): return Markdown( - step.content.strip().replace("<", "<"), + step.content.strip() or "…", classes="role_" + step.role, + parser_factory=parser_factory, ) @@ -73,7 +81,7 @@ async def on_input_submitted(self, event) -> None: async def render_fun(): while await update.get(): if tokens: - await output.update("".join(tokens).replace("<", "<")) + output.update("".join(tokens).strip()) self.container.scroll_end() await asyncio.sleep(0.1) @@ -90,10 +98,10 @@ async def get_token_fun(): await asyncio.gather(render_fun(), get_token_fun()) self.input.value = "" finally: - all_output = self.session.session[-1].content.replace("<", "<") - await output.update(all_output) - self.container.scroll_end() + all_output = self.session.session[-1].content + output.update(all_output) output._markdown = all_output # pylint: disable=protected-access + self.container.scroll_end() self.input.disabled = False def scroll_end(self):