Skip to content

Commit

Permalink
Merge pull request #3 from jepler/use-markdown-flavor
Browse files Browse the repository at this point in the history
Use markdown flavor
  • Loading branch information
jepler authored Apr 4, 2023
2 parents b34b752 + d94eb73 commit f81fc86
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
rev: v2.17.0
hooks:
- id: pylint
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,textual,websockets]
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets]
args: ['--source-roots', 'src']
- repo: https://github.com/pycqa/isort
rev: 5.12.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"httpx",
"lorem-text",
"platformdirs",
"textual",
"textual>=0.18.0",
"websockets",
]
classifiers = [
Expand Down
18 changes: 13 additions & 5 deletions src/chap/commands/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys

import click
from markdown_it import MarkdownIt
from textual.app import App
from textual.binding import Binding
from textual.containers import Container
Expand All @@ -16,14 +17,21 @@
from ..session import Assistant, Session, User


def parser_factory():
parser = MarkdownIt()
parser.options["html"] = False
return parser


class Markdown(Markdown, can_focus=True): # pylint: disable=function-redefined
pass


def markdown_for_step(step):
return Markdown(
step.content.strip().replace("<", "&lt;"),
step.content.strip() or "…",
classes="role_" + step.role,
parser_factory=parser_factory,
)


Expand Down Expand Up @@ -73,7 +81,7 @@ async def on_input_submitted(self, event) -> None:
async def render_fun():
while await update.get():
if tokens:
await output.update("".join(tokens).replace("<", "&lt;"))
output.update("".join(tokens).strip())
self.container.scroll_end()
await asyncio.sleep(0.1)

Expand All @@ -90,10 +98,10 @@ async def get_token_fun():
await asyncio.gather(render_fun(), get_token_fun())
self.input.value = ""
finally:
all_output = self.session.session[-1].content.replace("<", "&lt;")
await output.update(all_output)
self.container.scroll_end()
all_output = self.session.session[-1].content
output.update(all_output)
output._markdown = all_output # pylint: disable=protected-access
self.container.scroll_end()
self.input.disabled = False

def scroll_end(self):
Expand Down

0 comments on commit f81fc86

Please sign in to comment.