Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
Upcoming (TBD)
==============

Features
--------

* Add enum value completions for WHERE/HAVING clauses. (#790)


1.43.1 (2026/01/03)
==============

Expand Down
5 changes: 5 additions & 0 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_columns(table_columns_dbresult, kind="tables")


@refresher("enum_values")
def refresh_enum_values(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_enum_values(executor.enum_values())


@refresher("users")
def refresh_users(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_users(executor.users())
Expand Down
68 changes: 63 additions & 5 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any

import sqlparse
Expand All @@ -6,6 +7,56 @@
from mycli.packages.parseutils import extract_tables, find_prev_keyword, last_word
from mycli.packages.special.main import parse_special_command

_ENUM_VALUE_RE = re.compile(
r"(?P<lhs>(?:`[^`]+`|[\w$]+)(?:\.(?:`[^`]+`|[\w$]+))?)\s*=\s*$",
re.IGNORECASE,
)


def _enum_value_suggestion(text_before_cursor: str, full_text: str) -> dict[str, Any] | None:
match = _ENUM_VALUE_RE.search(text_before_cursor)
if not match:
return None
if _is_inside_quotes(text_before_cursor, match.start("lhs")):
return None

lhs = match.group("lhs")
if "." in lhs:
parent, column = lhs.split(".", 1)
else:
parent, column = None, lhs

return {
"type": "enum_value",
"tables": extract_tables(full_text),
"column": column,
"parent": parent,
}


def _is_where_or_having(token: Token | None) -> bool:
return bool(token and token.value and token.value.lower() in ("where", "having"))


def _is_inside_quotes(text: str, pos: int) -> bool:
in_single = False
in_double = False
escaped = False

for ch in text[:pos]:
if escaped:
escaped = False
continue
if ch == "\\":
escaped = True
continue
if ch == "'" and not in_double:
in_single = not in_single
elif ch == '"' and not in_single:
in_double = not in_double

return in_single or in_double


def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]:
"""Takes the full_text that is typed so far and also the text before the
Expand Down Expand Up @@ -133,8 +184,13 @@ def suggest_based_on_last_token(
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
original_text = text_before_cursor
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
enum_suggestion = _enum_value_suggestion(original_text, full_text)
fallback = suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
if enum_suggestion and _is_where_or_having(prev_keyword):
return [enum_suggestion] + fallback
return fallback
elif token is None:
return [{"type": "keyword"}]
else:
Expand Down Expand Up @@ -291,11 +347,13 @@ def suggest_based_on_last_token(
elif token_v == "tableformat":
return [{"type": "table_format"}]
elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
original_text = text_before_cursor
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
else:
return []
enum_suggestion = _enum_value_suggestion(original_text, full_text)
fallback = suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) if prev_keyword else []
if enum_suggestion and _is_where_or_having(prev_keyword):
return [enum_suggestion] + fallback
return fallback
else:
return [{"type": "keyword"}]

Expand Down
71 changes: 70 additions & 1 deletion mycli/sqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,17 @@ def extend_columns(self, column_data: list[tuple[str, str]], kind: Literal['tabl
metadata[self.dbname][relname].append(column)
self.all_completions.add(column)

def extend_enum_values(self, enum_data: Iterable[tuple[str, str, list[str]]]) -> None:
metadata = self.dbmetadata["enum_values"]
if self.dbname not in metadata:
metadata[self.dbname] = {}

for relname, column, values in enum_data:
relname_escaped = self.escape_name(relname)
column_escaped = self.escape_name(column)
table_meta = metadata[self.dbname].setdefault(relname_escaped, {})
table_meta[column_escaped] = values

def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], builtin: bool = False) -> None:
# if 'builtin' is set this is extending the list of builtin functions
if builtin:
Expand Down Expand Up @@ -1048,7 +1059,7 @@ def reset_completions(self) -> None:
self.users: list[str] = []
self.show_items: list[Completion] = []
self.dbname = ""
self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}}
self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}, "enum_values": {}}
self.all_completions = set(self.keywords + self.functions)

@staticmethod
Expand Down Expand Up @@ -1217,6 +1228,15 @@ def get_completions(
fuzzy=True,
)
completions.extend(subcommands_m)
elif suggestion["type"] == "enum_value":
enum_values = self.populate_enum_values(
suggestion["tables"],
suggestion["column"],
suggestion.get("parent"),
)
if enum_values:
quoted_values = [self._quote_sql_string(value) for value in enum_values]
return list(self.find_matches(word_before_cursor, quoted_values))

return completions

Expand Down Expand Up @@ -1272,6 +1292,55 @@ def populate_scoped_cols(self, scoped_tbls: list[tuple[str | None, str, str | No

return columns

def populate_enum_values(
self,
scoped_tbls: list[tuple[str | None, str, str | None]],
column: str,
parent: str | None = None,
) -> list[str]:
values: list[str] = []
meta = self.dbmetadata["enum_values"]
column_key = self._escape_identifier(column)
parent_key = self._strip_backticks(parent) if parent else None

for schema, relname, alias in scoped_tbls:
if parent_key and not self._matches_parent(parent_key, schema, relname, alias):
continue

schema = schema or self.dbname
table_meta = meta.get(schema, {})
escaped_relname = self.escape_name(relname)

for rel_key in {relname, escaped_relname}:
columns = table_meta.get(rel_key)
if columns and column_key in columns:
values.extend(columns[column_key])

return list(dict.fromkeys(values))

def _escape_identifier(self, name: str) -> str:
return self.escape_name(self._strip_backticks(name))

@staticmethod
def _strip_backticks(name: str | None) -> str:
if name and name[0] == "`" and name[-1] == "`":
return name[1:-1]
return name or ""

@staticmethod
def _matches_parent(parent: str, schema: str | None, relname: str, alias: str | None) -> bool:
if alias and parent == alias:
return True
if parent == relname:
return True
if schema and parent == f"{schema}.{relname}":
return True
return False

@staticmethod
def _quote_sql_string(value: str) -> str:
return "'" + value.replace("'", "''") + "'"

def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]:
"""Returns list of tables or functions for a (optional) schema"""
metadata = self.dbmetadata[obj_type]
Expand Down
51 changes: 51 additions & 0 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,48 @@ class SQLExecute:
where table_schema = '%s'
order by table_name,ordinal_position"""

enum_values_query = """select TABLE_NAME, COLUMN_NAME, COLUMN_TYPE from information_schema.columns
where table_schema = '%s' and data_type = 'enum'
order by table_name,ordinal_position"""

now_query = """SELECT NOW()"""

@staticmethod
def _parse_enum_values(column_type: str) -> list[str]:
if not column_type or not column_type.lower().startswith("enum("):
return []

values = []
current = []
in_quote = False
i = column_type.find("(") + 1

while i < len(column_type):
ch = column_type[i]

if not in_quote:
if ch == "'":
in_quote = True
current = []
elif ch == ")":
break
else:
if ch == "\\" and i + 1 < len(column_type):
current.append(column_type[i + 1])
i += 1
elif ch == "'":
if i + 1 < len(column_type) and column_type[i + 1] == "'":
current.append("'")
i += 1
else:
values.append("".join(current))
in_quote = False
else:
current.append(ch)
i += 1

return values

def __init__(
self,
database: str | None,
Expand Down Expand Up @@ -375,6 +415,17 @@ def table_columns(self) -> Generator[tuple[str, str], None, None]:
for row in cur:
yield row

def enum_values(self) -> Generator[tuple[str, str, list[str]], None, None]:
"""Yields (table name, column name, enum values) tuples"""
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
_logger.debug("Enum Values Query. sql: %r", self.enum_values_query)
cur.execute(self.enum_values_query % self.dbname)
for table_name, column_name, column_type in cur:
values = self._parse_enum_values(column_type)
if values:
yield (table_name, column_name, values)

def databases(self) -> list[str]:
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
Expand Down
13 changes: 12 additions & 1 deletion test/test_completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_select_suggests_cols_with_qualified_table_scope():
[
"SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE (",
"SELECT * FROM tabl WHERE foo = ",
"SELECT * FROM tabl WHERE bar OR ",
"SELECT * FROM tabl WHERE foo = 1 AND ",
"SELECT * FROM tabl WHERE (bar > 10 AND ",
Expand All @@ -55,6 +54,18 @@ def test_where_suggests_columns_functions(expression):
])


def test_where_equals_suggests_enum_values_first():
expression = "SELECT * FROM tabl WHERE foo = "
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{"type": "enum_value", "tables": [(None, "tabl", None)], "column": "foo", "parent": None},
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
])


@pytest.mark.parametrize(
"expression",
[
Expand Down
12 changes: 11 additions & 1 deletion test/test_completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,17 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
expected_handlers = [
"databases",
"schemata",
"tables",
"enum_values",
"users",
"functions",
"special_commands",
"show_commands",
"keywords",
]
assert expected_handlers == actual_handlers


Expand Down
4 changes: 3 additions & 1 deletion test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ def stub_terminal_size():
assert isinstance(mycli.get_reserved_space(), int)


def test_list_dsn():
def test_list_dsn(monkeypatch):
monkeypatch.setattr(MyCli, "system_config_files", [])
monkeypatch.setattr(MyCli, "pwd_config_file", os.path.join(test_dir, "does_not_exist.myclirc"))
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w", delete=False) as myclirc:
Expand Down
11 changes: 11 additions & 0 deletions test/test_smart_completion_public_schema_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def completer():
comp.extend_schemata("test")
comp.extend_relations(tables, kind="tables")
comp.extend_columns(columns, kind="tables")
comp.extend_enum_values([("orders", "status", ["pending", "shipped"])])
comp.extend_special_commands(special.COMMANDS)

return comp
Expand Down Expand Up @@ -84,6 +85,16 @@ def test_table_completion(completer, complete_event):
]


def test_enum_value_completion(completer, complete_event):
text = "SELECT * FROM orders WHERE status = "
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="'pending'", start_position=0),
Completion(text="'shipped'", start_position=0),
]


def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
Expand Down
Loading