Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions litecli/clistyle.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from __future__ import annotations

import logging

from typing import cast

import pygments.styles
from pygments.token import string_to_tokentype, Token
from pygments.style import Style as PygmentsStyle
from pygments.util import ClassNotFound
from prompt_toolkit.styles import Style, merge_styles
from prompt_toolkit.styles.pygments import style_from_pygments_cls
from prompt_toolkit.styles import merge_styles, Style
from prompt_toolkit.styles.style import _MergedStyle
from pygments.style import Style as PygmentsStyle
from pygments.token import Token, _TokenType, string_to_tokentype
from pygments.util import ClassNotFound

logger = logging.getLogger(__name__)

# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
TOKEN_TO_PROMPT_STYLE: dict[Token, str] = {
TOKEN_TO_PROMPT_STYLE: dict[_TokenType, str] = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mapping to accurate type so ty doesn't complain.

Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
Token.Menu.Completions.Completion: "completion-menu.completion",
Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
Expand Down Expand Up @@ -43,10 +43,10 @@
}

# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN: dict[str, Token] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
PROMPT_STYLE_TO_TOKEN: dict[str, _TokenType] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}


def parse_pygments_style(token_name: str, style_object: PygmentsStyle | dict, style_dict: dict[str, str]) -> tuple[Token, str]:
def parse_pygments_style(token_name: str, style_object: PygmentsStyle | dict, style_dict: dict[str, str]) -> tuple[_TokenType, str]:
"""Parse token type and style string.

:param token_name: str name of Pygments token. Example: "Token.String"
Expand Down Expand Up @@ -111,4 +111,5 @@ class OutputStyle(PygmentsStyle):
default_style = ""
styles = style

return OutputStyle
# mypy does not complain only ty complains: error[invalid-return-type]: Return type does not match returned value. Hence added cast.
return cast(OutputStyle, PygmentsStyle)
5 changes: 2 additions & 3 deletions litecli/completion_refresher.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import threading
from collections import OrderedDict
from typing import Callable

from .packages.special.main import COMMANDS
from collections import OrderedDict

from .sqlcompleter import SQLCompleter
from .sqlexecute import SQLExecute

Expand Down Expand Up @@ -94,7 +93,7 @@ def _bg_refresh(
# break statement.
continue

for callback in callbacks:
for callback in callbacks: # ty: ignore[not-iterable]
callback(completer)


Expand Down
7 changes: 3 additions & 4 deletions litecli/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import errno
import shutil
import os
import platform
from os.path import expanduser, exists, dirname

import shutil
from os.path import dirname, exists, expanduser

from configobj import ConfigObj

Expand Down Expand Up @@ -55,7 +54,7 @@ def upgrade_config(config: str, def_config: str) -> None:
def get_config(liteclirc_file: str | None = None) -> ConfigObj:
from litecli import __file__ as package_root

package_root = os.path.dirname(package_root)
package_root = os.path.dirname(str(package_root))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty doesn't like package_root as it comes from import reference, hence converting to string.


liteclirc_file = liteclirc_file or f"{config_location()}config"

Expand Down
21 changes: 11 additions & 10 deletions litecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
from io import open

try:
from sqlean import OperationalError, sqlite_version
from sqlean import OperationalError, sqlite_version # ty: ignore[unresolved-import]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqlean has no type definitions, hence need to ignore all places. At runtime, the methods are available.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe create types for sqlean.

except ImportError:
from sqlite3 import OperationalError, sqlite_version
from time import time
from typing import Any, Iterable
from typing import Any, Iterable, cast

import click
import sqlparse
from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import DynamicCompleter
from prompt_toolkit.completion import Completion, DynamicCompleter
from prompt_toolkit.document import Document
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.filters import HasFocus, IsDone
Expand All @@ -35,8 +35,6 @@
)
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
from typing import cast
from prompt_toolkit.completion import Completion

from .__init__ import __version__
from .clibuffer import cli_is_multiline
Expand All @@ -53,8 +51,6 @@
from .sqlcompleter import SQLCompleter
from .sqlexecute import SQLExecute

click.disable_unicode_literals_warning = True
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer needed in Python 3.9


# Query tuples are used for maintaining history
Query = namedtuple("Query", ["query", "successful", "mutating"])

Expand Down Expand Up @@ -84,7 +80,7 @@ def __init__(
self.key_bindings = c["main"]["key_bindings"]
special.set_favorite_queries(self.config)
self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
self.formatter.litecli = self
self.formatter.litecli = self # ty: ignore[unresolved-attribute]
self.syntax_style = c["main"]["syntax_style"]
self.less_chatty = c["main"].as_bool("less_chatty")
self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
Expand Down Expand Up @@ -200,11 +196,14 @@ def change_db(self, arg: str | None, **_: Any) -> Iterable[tuple]:
self.sqlexecute.connect(database=arg)

self.refresh_completions()
# guard so that ty doesn't complain
dbname = self.sqlexecute.dbname if self.sqlexecute is not None else ""

yield (
None,
None,
None,
'You are now connected to database "%s"' % (self.sqlexecute.dbname),
'You are now connected to database "%s"' % (dbname),
)

def execute_from_file(self, arg: str | None, **_: Any) -> Iterable[tuple[Any, ...]]:
Expand Down Expand Up @@ -522,7 +521,8 @@ def one_iteration(text: str | None = None) -> None:
raise e
except KeyboardInterrupt:
try:
sqlexecute.conn.interrupt()
# since connection can sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning.
sqlexecute.conn.interrupt() # ty: ignore[possibly-missing-attribute]
except Exception as e:
self.echo(
"Encountered error while cancelling query: {}".format(e),
Expand Down Expand Up @@ -755,6 +755,7 @@ def refresh_completions(self, reset: bool = False) -> list[tuple]:
if reset:
with self._completer_lock:
self.completer.reset_completions()
assert self.sqlexecute is not None
self.completion_refresher.refresh(
self.sqlexecute,
self._on_completions_refreshed,
Expand Down
10 changes: 5 additions & 5 deletions litecli/packages/parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Generator, Iterable, Literal

import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function, Token, TokenList
from sqlparse.tokens import Keyword, DML, Punctuation
from sqlparse.sql import Function, Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import DML, Keyword, Punctuation

cleanup_regex: dict[str, re.Pattern[str]] = {
# This matches only alphanumerics and underscores.
Expand All @@ -18,10 +18,10 @@
"all_punctuations": re.compile(r"([^\s]+)$"),
}

LAST_WORD_INCLUDE_TYPE = Literal["alphanum_underscore", "many_punctuations", "most_punctuations", "all_punctuations"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ty later complains the value passed is str and so declare and import later.


def last_word(
text: str, include: Literal["alphanum_underscore", "many_punctuations", "most_punctuations", "all_punctuations"] = "alphanum_underscore"
) -> str:

def last_word(text: str, include: LAST_WORD_INCLUDE_TYPE = "alphanum_underscore") -> str:
R"""
Find the last word in a sentence.
Expand Down
28 changes: 26 additions & 2 deletions litecli/packages/special/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa

from __future__ import annotations
from types import FunctionType

from typing import Callable, Any

Expand All @@ -9,11 +10,34 @@

def export(defn: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to explicitly mark functions that are exposed in a lib."""
globals()[defn.__name__] = defn
__all__.append(defn.__name__)
# ty, requires explict check for callable of tyep | function type to access __name__
if isinstance(defn, (type, FunctionType)):
globals()[defn.__name__] = defn
__all__.append(defn.__name__)
return defn


from . import dbcommands
from . import iocommands
from . import llm
from . import utils
from .main import CommandNotFound, register_special_command, execute
from .iocommands import (
set_favorite_queries,
editor_command,
get_filename,
get_editor_query,
open_external_editor,
is_expanded_output,
set_expanded_output,
write_tee,
unset_once_if_written,
unset_pipe_once_if_written,
disable_pager,
set_pager,
is_pager_enabled,
write_once,
write_pipe_once,
close_tee,
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty doesn't like when these methods are accessed like special.xyz()when it is missing in the init.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There may be scope to improve over all import systems.

from .llm import is_llm_command, handle_llm, FinishIteration
3 changes: 2 additions & 1 deletion litecli/packages/special/favoritequeries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import builtins
from typing import Any, cast


Expand Down Expand Up @@ -39,7 +40,7 @@ class FavoriteQueries(object):
def __init__(self, config: Any) -> None:
self.config = config

def list(self) -> list[str]:
def list(self) -> builtins.list[str]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty is opinionated about this. astral-sh/ty#2035

section = cast(dict[str, str], self.config.get(self.section_name, {}))
return list(section.keys())

Expand Down
3 changes: 2 additions & 1 deletion litecli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
LLM_TEMPLATE_NAME = "litecli-llm-template"
LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys())
# Mapping of model_id to None used for completion tree leaves.
MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()}
# the file name is llm.py and module name is llm, hence ty is complaining that get_models is missing.
MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()} # ty: ignore[unresolved-attribute]


def run_external_cmd(
Expand Down
12 changes: 6 additions & 6 deletions litecli/sqlcompleter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import logging
from re import compile, escape
from collections import Counter
from re import compile, escape
from typing import Any, Collection, Generator, Iterable, Literal, Sequence

from prompt_toolkit.completion import CompleteEvent, Completer, Completion
from prompt_toolkit.completion.base import Document

from .packages.completion_engine import suggest_type
from .packages.parseutils import last_word
from .packages.special.iocommands import favoritequeries
from .packages.filepaths import complete_path, parse_path, suggest_path
from .packages.parseutils import LAST_WORD_INCLUDE_TYPE, last_word
from .packages.special import llm
from .packages.filepaths import parse_path, complete_path, suggest_path
from .packages.special.iocommands import favoritequeries

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -381,7 +381,7 @@ def extend_functions(self, func_data: Iterable[Sequence[str]]) -> None:
metadata[self.dbname][func[0]] = None
self.all_completions.add(func[0])

def set_dbname(self, dbname: str) -> None:
def set_dbname(self, dbname: str | None) -> None:
self.dbname = dbname

def reset_completions(self) -> None:
Expand All @@ -397,7 +397,7 @@ def find_matches(
start_only: bool = False,
fuzzy: bool = True,
casing: str | None = None,
punctuations: str = "most_punctuations",
punctuations: LAST_WORD_INCLUDE_TYPE = "most_punctuations",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the puncutations value is literal and here it was declared as str, ty complains about it. hence use the same type here.

) -> Generator[Completion, None, None]:
"""Find completion matches for the given text.

Expand Down
16 changes: 8 additions & 8 deletions litecli/sqlexecute.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
from __future__ import annotations

import logging
from typing import Any, Generator, Iterable

from contextlib import closing
from typing import Any, Generator, Iterable

try:
import sqlean as sqlite3
from sqlean import OperationalError
from sqlean import OperationalError # ty: ignore[unresolved-import]

sqlite3.extensions.enable_all()
except ImportError:
import sqlite3
from sqlite3 import OperationalError
from litecli.packages.special.utils import check_if_sqlitedotcommand

import sqlparse
import os.path
from urllib.parse import urlparse

from .packages import special
import sqlparse

from litecli.packages import special
from litecli.packages.special.utils import check_if_sqlitedotcommand

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,7 +87,8 @@ def connect(self, database: str | None = None) -> None:
if not os.path.exists(db_dir_name):
raise Exception("Path does not exist: {}".format(db_dir_name))

conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri)
# sqlean exposes the connect method during run-time
conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri) # ty: ignore[possibly-missing-attribute]
conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace")
if self.conn:
self.conn.close()
Expand Down
21 changes: 20 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ dependencies = [
"prompt-toolkit>=3.0.3,<4.0.0",
"pygments>=1.6",
"sqlparse>=0.4.4",
"setuptools", # Required by llm commands to install models
"setuptools", # Required by llm commands to install models
"pip",
"llm>=0.25.0",
"ty>=0.0.4",
]

[build-system]
Expand Down Expand Up @@ -87,3 +88,21 @@ exclude = [
'^\.pytest_cache/',
'^\.ruff_cache/',
]


[tool.ty.environment]
python-version = "3.9"
root = [".", "litecli", "litecli/packages", "litecli/packages/special"]


[tool.ty.src]
exclude = [
'**/build/',
'**/dist/',
'**/.tox/',
'**/.venv/',
'**/.mypy_cache/',
'**/.pytest_cache/',
'**/.ruff_cache/',
'tests/**'
]
Loading
Loading