Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
Upcoming (TBD)
==============

Features
--------

* Add enum value completions for WHERE/HAVING clauses. (#790)


1.43.1 (2026/01/03)
==============

Expand Down
5 changes: 5 additions & 0 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_columns(table_columns_dbresult, kind="tables")


@refresher("enum_values")
def refresh_enum_values(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_enum_values(executor.enum_values())


@refresher("users")
def refresh_users(completer: SQLCompleter, executor: SQLExecute) -> None:
completer.extend_users(executor.users())
Expand Down
68 changes: 63 additions & 5 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any

import sqlparse
Expand All @@ -6,6 +7,56 @@
from mycli.packages.parseutils import extract_tables, find_prev_keyword, last_word
from mycli.packages.special.main import parse_special_command

_ENUM_VALUE_RE = re.compile(
r"(?P<lhs>(?:`[^`]+`|[\w$]+)(?:\.(?:`[^`]+`|[\w$]+))?)\s*=\s*$",
re.IGNORECASE,
)


def _enum_value_suggestion(text_before_cursor: str, full_text: str) -> dict[str, Any] | None:
match = _ENUM_VALUE_RE.search(text_before_cursor)
if not match:
return None
if _is_inside_quotes(text_before_cursor, match.start("lhs")):
return None

lhs = match.group("lhs")
if "." in lhs:
parent, column = lhs.split(".", 1)
else:
parent, column = None, lhs

return {
"type": "enum_value",
"tables": extract_tables(full_text),
"column": column,
"parent": parent,
}


def _is_where_or_having(token: Token | None) -> bool:
return bool(token and token.value and token.value.lower() in ("where", "having"))


def _is_inside_quotes(text: str, pos: int) -> bool:
in_single = False
in_double = False
escaped = False

for ch in text[:pos]:
if escaped:
escaped = False
continue
if ch == "\\":
escaped = True
continue
if ch == "'" and not in_double:
in_single = not in_single
elif ch == '"' and not in_single:
in_double = not in_double

return in_single or in_double


def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]:
"""Takes the full_text that is typed so far and also the text before the
Expand Down Expand Up @@ -133,8 +184,13 @@ def suggest_based_on_last_token(
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
original_text = text_before_cursor
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
enum_suggestion = _enum_value_suggestion(original_text, full_text)
fallback = suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
if enum_suggestion and _is_where_or_having(prev_keyword):
return [enum_suggestion] + fallback
return fallback
elif token is None:
return [{"type": "keyword"}]
else:
Expand Down Expand Up @@ -291,11 +347,13 @@ def suggest_based_on_last_token(
elif token_v == "tableformat":
return [{"type": "table_format"}]
elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
original_text = text_before_cursor
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
else:
return []
enum_suggestion = _enum_value_suggestion(original_text, full_text)
fallback = suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) if prev_keyword else []
if enum_suggestion and _is_where_or_having(prev_keyword):
return [enum_suggestion] + fallback
return fallback
else:
return [{"type": "keyword"}]

Expand Down
71 changes: 70 additions & 1 deletion mycli/sqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,17 @@ def extend_columns(self, column_data: list[tuple[str, str]], kind: Literal['tabl
metadata[self.dbname][relname].append(column)
self.all_completions.add(column)

def extend_enum_values(self, enum_data: Iterable[tuple[str, str, list[str]]]) -> None:
metadata = self.dbmetadata["enum_values"]
if self.dbname not in metadata:
metadata[self.dbname] = {}

for relname, column, values in enum_data:
relname_escaped = self.escape_name(relname)
column_escaped = self.escape_name(column)
table_meta = metadata[self.dbname].setdefault(relname_escaped, {})
table_meta[column_escaped] = values

def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], builtin: bool = False) -> None:
# if 'builtin' is set this is extending the list of builtin functions
if builtin:
Expand Down Expand Up @@ -1048,7 +1059,7 @@ def reset_completions(self) -> None:
self.users: list[str] = []
self.show_items: list[Completion] = []
self.dbname = ""
self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}}
self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}, "enum_values": {}}
self.all_completions = set(self.keywords + self.functions)

@staticmethod
Expand Down Expand Up @@ -1217,6 +1228,15 @@ def get_completions(
fuzzy=True,
)
completions.extend(subcommands_m)
elif suggestion["type"] == "enum_value":
enum_values = self.populate_enum_values(
suggestion["tables"],
suggestion["column"],
suggestion.get("parent"),
)
if enum_values:
quoted_values = [self._quote_sql_string(value) for value in enum_values]
return list(self.find_matches(word_before_cursor, quoted_values))

return completions

Expand Down Expand Up @@ -1272,6 +1292,55 @@ def populate_scoped_cols(self, scoped_tbls: list[tuple[str | None, str, str | No

return columns

def populate_enum_values(
self,
scoped_tbls: list[tuple[str | None, str, str | None]],
column: str,
parent: str | None = None,
) -> list[str]:
values: list[str] = []
meta = self.dbmetadata["enum_values"]
column_key = self._escape_identifier(column)
parent_key = self._strip_backticks(parent) if parent else None

for schema, relname, alias in scoped_tbls:
if parent_key and not self._matches_parent(parent_key, schema, relname, alias):
continue

schema = schema or self.dbname
table_meta = meta.get(schema, {})
escaped_relname = self.escape_name(relname)

for rel_key in {relname, escaped_relname}:
columns = table_meta.get(rel_key)
if columns and column_key in columns:
values.extend(columns[column_key])

return list(dict.fromkeys(values))

def _escape_identifier(self, name: str) -> str:
return self.escape_name(self._strip_backticks(name))

@staticmethod
def _strip_backticks(name: str | None) -> str:
if name and name[0] == "`" and name[-1] == "`":
return name[1:-1]
return name or ""

@staticmethod
def _matches_parent(parent: str, schema: str | None, relname: str, alias: str | None) -> bool:
if alias and parent == alias:
return True
if parent == relname:
return True
if schema and parent == f"{schema}.{relname}":
return True
return False

@staticmethod
def _quote_sql_string(value: str) -> str:
return "'" + value.replace("'", "''") + "'"

def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]:
"""Returns list of tables or functions for a (optional) schema"""
metadata = self.dbmetadata[obj_type]
Expand Down
51 changes: 51 additions & 0 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,48 @@ class SQLExecute:
where table_schema = '%s'
order by table_name,ordinal_position"""

enum_values_query = """select TABLE_NAME, COLUMN_NAME, COLUMN_TYPE from information_schema.columns
where table_schema = '%s' and data_type = 'enum'
order by table_name,ordinal_position"""

now_query = """SELECT NOW()"""

@staticmethod
def _parse_enum_values(column_type: str) -> list[str]:
if not column_type or not column_type.lower().startswith("enum("):
return []

values: list[str] = []
current: list[str] = []
in_quote = False
i = column_type.find("(") + 1

while i < len(column_type):
ch = column_type[i]

if not in_quote:
if ch == "'":
in_quote = True
current = []
elif ch == ")":
break
else:
if ch == "\\" and i + 1 < len(column_type):
current.append(column_type[i + 1])
i += 1
elif ch == "'":
if i + 1 < len(column_type) and column_type[i + 1] == "'":
current.append("'")
i += 1
else:
values.append("".join(current))
in_quote = False
else:
current.append(ch)
i += 1

return values

def __init__(
self,
database: str | None,
Expand Down Expand Up @@ -375,6 +415,17 @@ def table_columns(self) -> Generator[tuple[str, str], None, None]:
for row in cur:
yield row

def enum_values(self) -> Generator[tuple[str, str, list[str]], None, None]:
"""Yields (table name, column name, enum values) tuples"""
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
_logger.debug("Enum Values Query. sql: %r", self.enum_values_query)
cur.execute(self.enum_values_query % self.dbname)
for table_name, column_name, column_type in cur:
values = self._parse_enum_values(column_type)
if values:
yield (table_name, column_name, values)

def databases(self) -> list[str]:
assert isinstance(self.conn, Connection)
with self.conn.cursor() as cur:
Expand Down
13 changes: 12 additions & 1 deletion test/test_completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_select_suggests_cols_with_qualified_table_scope():
[
"SELECT * FROM tabl WHERE ",
"SELECT * FROM tabl WHERE (",
"SELECT * FROM tabl WHERE foo = ",
"SELECT * FROM tabl WHERE bar OR ",
"SELECT * FROM tabl WHERE foo = 1 AND ",
"SELECT * FROM tabl WHERE (bar > 10 AND ",
Expand All @@ -55,6 +54,18 @@ def test_where_suggests_columns_functions(expression):
])


def test_where_equals_suggests_enum_values_first():
expression = "SELECT * FROM tabl WHERE foo = "
suggestions = suggest_type(expression, expression)
assert sorted_dicts(suggestions) == sorted_dicts([
{"type": "enum_value", "tables": [(None, "tabl", None)], "column": "foo", "parent": None},
{"type": "alias", "aliases": ["tabl"]},
{"type": "column", "tables": [(None, "tabl", None)]},
{"type": "function", "schema": []},
{"type": "keyword"},
])


@pytest.mark.parametrize(
"expression",
[
Expand Down
12 changes: 11 additions & 1 deletion test/test_completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,17 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
expected_handlers = [
"databases",
"schemata",
"tables",
"enum_values",
"users",
"functions",
"special_commands",
"show_commands",
"keywords",
]
assert expected_handlers == actual_handlers


Expand Down
4 changes: 3 additions & 1 deletion test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ def stub_terminal_size():
assert isinstance(mycli.get_reserved_space(), int)


def test_list_dsn():
def test_list_dsn(monkeypatch):
monkeypatch.setattr(MyCli, "system_config_files", [])
monkeypatch.setattr(MyCli, "pwd_config_file", os.path.join(test_dir, "does_not_exist.myclirc"))
runner = CliRunner()
# keep Windows from locking the file with delete=False
with NamedTemporaryFile(mode="w", delete=False) as myclirc:
Expand Down
11 changes: 11 additions & 0 deletions test/test_smart_completion_public_schema_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def completer():
comp.extend_schemata("test")
comp.extend_relations(tables, kind="tables")
comp.extend_columns(columns, kind="tables")
comp.extend_enum_values([("orders", "status", ["pending", "shipped"])])
comp.extend_special_commands(special.COMMANDS)

return comp
Expand Down Expand Up @@ -84,6 +85,16 @@ def test_table_completion(completer, complete_event):
]


def test_enum_value_completion(completer, complete_event):
text = "SELECT * FROM orders WHERE status = "
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="'pending'", start_position=0),
Completion(text="'shipped'", start_position=0),
]


def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
Expand Down