diff --git a/CHANGES.rst b/CHANGES.rst index 057d700df..874692407 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -88,13 +88,16 @@ Unreleased - Parameters cannot be required nor prompted or an error is raised. - A warning will be printed when something deprecated is used. -- Add a ``catch_exceptions`` parameter to :class:`CliRunner`. If - ``catch_exceptions`` is not passed to :meth:`CliRunner.invoke`, +- Add a ``catch_exceptions`` parameter to :class:``CliRunner``. If + ``catch_exceptions`` is not passed to :meth:``CliRunner.invoke``, the value from :class:`CliRunner`. :issue:`2817` :pr:`2818` - ``Option.flag_value`` will no longer have a default value set based on ``Option.default`` if ``Option.is_flag`` is ``False``. This results in ``Option.default`` not needing to implement `__bool__`. :pr:`2829` - Incorrect ``click.edit`` typing has been corrected. :pr:`2804` +- :class:``Choice`` is now generic and supports any iterable value. + This allows you to use enums and other non-``str`` values. :pr:`2796` + :issue:`605` Version 8.1.8 ------------- diff --git a/docs/options.rst b/docs/options.rst index 39f66c4cd..a4791234d 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -375,15 +375,22 @@ In that case you can use :class:`Choice` type. It can be instantiated with a list of valid values. The originally passed choice will be returned, not the str passed on the command line. Token normalization functions and ``case_sensitive=False`` can cause the two to be different but still match. +:meth:`Choice.normalize_choice` for more info. Example: .. click:example:: + import enum + + class HashType(enum.Enum): + MD5 = 'MD5' + SHA1 = 'SHA1' + @click.command() @click.option('--hash-type', - type=click.Choice(['MD5', 'SHA1'], case_sensitive=False)) - def digest(hash_type): + type=click.Choice(HashType, case_sensitive=False)) + def digest(hash_type: HashType): click.echo(hash_type) What it looks like: @@ -398,15 +405,16 @@ What it looks like: println() invoke(digest, args=['--help']) -Only pass the choices as list or tuple. Other iterables (like -generators) may lead to unexpected results. +Since version 8.2.0 any iterable may be passed to :class:`Choice`, here +an ``Enum`` is used which will result in all enum values to be valid +choices. Choices work with options that have ``multiple=True``. If a ``default`` value is given with ``multiple=True``, it should be a list or tuple of valid choices. -Choices should be unique after considering the effects of -``case_sensitive`` and any specified token normalization function. +Choices should be unique after normalization, see +:meth:`Choice.normalize_choice` for more info. .. versionchanged:: 7.1 The resulting value from an option will always be one of the diff --git a/src/click/core.py b/src/click/core.py index e783729c7..176a7ca60 100644 --- a/src/click/core.py +++ b/src/click/core.py @@ -2192,11 +2192,11 @@ def human_readable_name(self) -> str: """ return self.name # type: ignore - def make_metavar(self) -> str: + def make_metavar(self, ctx: Context) -> str: if self.metavar is not None: return self.metavar - metavar = self.type.get_metavar(self) + metavar = self.type.get_metavar(param=self, ctx=ctx) if metavar is None: metavar = self.type.name.upper() @@ -2775,7 +2775,7 @@ def _write_opts(opts: cabc.Sequence[str]) -> str: any_prefix_is_slash = True if not self.is_flag and not self.count: - rv += f" {self.make_metavar()}" + rv += f" {self.make_metavar(ctx=ctx)}" return rv @@ -3056,10 +3056,10 @@ def human_readable_name(self) -> str: return self.metavar return self.name.upper() # type: ignore - def make_metavar(self) -> str: + def make_metavar(self, ctx: Context) -> str: if self.metavar is not None: return self.metavar - var = self.type.get_metavar(self) + var = self.type.get_metavar(param=self, ctx=ctx) if not var: var = self.name.upper() # type: ignore if self.deprecated: @@ -3088,10 +3088,10 @@ def _parse_decls( return name, [arg], [] def get_usage_pieces(self, ctx: Context) -> list[str]: - return [self.make_metavar()] + return [self.make_metavar(ctx)] def get_error_hint(self, ctx: Context) -> str: - return f"'{self.make_metavar()}'" + return f"'{self.make_metavar(ctx)}'" def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None: parser.add_argument(dest=self.name, nargs=self.nargs, obj=self) diff --git a/src/click/exceptions.py b/src/click/exceptions.py index c41c20676..f141a832e 100644 --- a/src/click/exceptions.py +++ b/src/click/exceptions.py @@ -174,7 +174,9 @@ def format_message(self) -> str: msg = self.message if self.param is not None: - msg_extra = self.param.type.get_missing_message(self.param) + msg_extra = self.param.type.get_missing_message( + param=self.param, ctx=self.ctx + ) if msg_extra: if msg: msg += f". {msg_extra}" diff --git a/src/click/types.py b/src/click/types.py index 354c7e381..d0a2715d2 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections.abc as cabc +import enum import os import stat import sys @@ -23,6 +24,8 @@ from .core import Parameter from .shell_completion import CompletionItem +ParamTypeValue = t.TypeVar("ParamTypeValue") + class ParamType: """Represents the type of a parameter. Validates and converts values @@ -86,10 +89,10 @@ def __call__( if value is not None: return self.convert(value, param, ctx) - def get_metavar(self, param: Parameter) -> str | None: + def get_metavar(self, param: Parameter, ctx: Context) -> str | None: """Returns the metavar default for this param if it provides one.""" - def get_missing_message(self, param: Parameter) -> str | None: + def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | None: """Optionally might return extra information about a missing parameter. @@ -227,29 +230,35 @@ def __repr__(self) -> str: return "STRING" -class Choice(ParamType): +class Choice(ParamType, t.Generic[ParamTypeValue]): """The choice type allows a value to be checked against a fixed set - of supported values. All of these values have to be strings. - - You should only pass a list or tuple of choices. Other iterables - (like generators) may lead to surprising results. + of supported values. - The resulting value will always be one of the originally passed choices - regardless of ``case_sensitive`` or any ``ctx.token_normalize_func`` - being specified. + You may pass any iterable value which will be converted to a tuple + and thus will only be iterated once. - See :ref:`choice-opts` for an example. + The resulting value will always be one of the originally passed choices. + See :meth:`normalize_choice` for more info on the mapping of strings + to choices. See :ref:`choice-opts` for an example. :param case_sensitive: Set to false to make choices case insensitive. Defaults to true. + + .. versionchanged:: 8.2.0 + Non-``str`` ``choices`` are now supported. It can additionally be any + iterable. Before you were not recommended to pass anything but a list or + tuple. + + .. versionadded:: 8.2.0 + Choice normalization can be overridden via :meth:`normalize_choice`. """ name = "choice" def __init__( - self, choices: cabc.Sequence[str], case_sensitive: bool = True + self, choices: cabc.Iterable[ParamTypeValue], case_sensitive: bool = True ) -> None: - self.choices = choices + self.choices: cabc.Sequence[ParamTypeValue] = tuple(choices) self.case_sensitive = case_sensitive def to_info_dict(self) -> dict[str, t.Any]: @@ -258,14 +267,54 @@ def to_info_dict(self) -> dict[str, t.Any]: info_dict["case_sensitive"] = self.case_sensitive return info_dict - def get_metavar(self, param: Parameter) -> str: + def _normalized_mapping( + self, ctx: Context | None = None + ) -> cabc.Mapping[ParamTypeValue, str]: + """ + Returns mapping where keys are the original choices and the values are + the normalized values that are accepted via the command line. + + This is a simple wrapper around :meth:`normalize_choice`, use that + instead which is supported. + """ + return { + choice: self.normalize_choice( + choice=choice, + ctx=ctx, + ) + for choice in self.choices + } + + def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str: + """ + Normalize a choice value, used to map a passed string to a choice. + Each choice must have a unique normalized value. + + By default uses :meth:`Context.token_normalize_func` and if not case + sensitive, convert it to a casefolded value. + + .. versionadded:: 8.2.0 + """ + normed_value = choice.name if isinstance(choice, enum.Enum) else str(choice) + + if ctx is not None and ctx.token_normalize_func is not None: + normed_value = ctx.token_normalize_func(normed_value) + + if not self.case_sensitive: + normed_value = normed_value.casefold() + + return normed_value + + def get_metavar(self, param: Parameter, ctx: Context) -> str | None: if param.param_type_name == "option" and not param.show_choices: # type: ignore choice_metavars = [ convert_type(type(choice)).name.upper() for choice in self.choices ] choices_str = "|".join([*dict.fromkeys(choice_metavars)]) else: - choices_str = "|".join([str(i) for i in self.choices]) + choices_str = "|".join( + [str(i) for i in self._normalized_mapping(ctx=ctx).values()] + ) # Use curly braces to indicate a required argument. if param.required and param.param_type_name == "argument": @@ -274,46 +323,48 @@ def get_metavar(self, param: Parameter) -> str: # Use square braces to indicate an option or optional argument. return f"[{choices_str}]" - def get_missing_message(self, param: Parameter) -> str: - return _("Choose from:\n\t{choices}").format(choices=",\n\t".join(self.choices)) + def get_missing_message(self, param: Parameter, ctx: Context | None) -> str: + """ + Message shown when no choice is passed. + + .. versionchanged:: 8.2.0 Added ``ctx`` argument. + """ + return _("Choose from:\n\t{choices}").format( + choices=",\n\t".join(self._normalized_mapping(ctx=ctx).values()) + ) def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: - # Match through normalization and case sensitivity - # first do token_normalize_func, then lowercase - # preserve original `value` to produce an accurate message in - # `self.fail` - normed_value = value - normed_choices = {choice: choice for choice in self.choices} - - if ctx is not None and ctx.token_normalize_func is not None: - normed_value = ctx.token_normalize_func(value) - normed_choices = { - ctx.token_normalize_func(normed_choice): original - for normed_choice, original in normed_choices.items() - } - - if not self.case_sensitive: - normed_value = normed_value.casefold() - normed_choices = { - normed_choice.casefold(): original - for normed_choice, original in normed_choices.items() - } - - if normed_value in normed_choices: - return normed_choices[normed_value] + ) -> ParamTypeValue: + """ + For a given value from the parser, normalize it and find its + matching normalized value in the list of choices. Then return the + matched "original" choice. + """ + normed_value = self.normalize_choice(choice=value, ctx=ctx) + normalized_mapping = self._normalized_mapping(ctx=ctx) - self.fail(self.get_invalid_choice_message(value), param, ctx) + try: + return next( + original + for original, normalized in normalized_mapping.items() + if normalized == normed_value + ) + except StopIteration: + self.fail( + self.get_invalid_choice_message(value=value, ctx=ctx), + param=param, + ctx=ctx, + ) - def get_invalid_choice_message(self, value: t.Any) -> str: + def get_invalid_choice_message(self, value: t.Any, ctx: Context | None) -> str: """Get the error message when the given choice is invalid. :param value: The invalid value. .. versionadded:: 8.2 """ - choices_str = ", ".join(map(repr, self.choices)) + choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values())) return ngettext( "{value!r} is not {choice}.", "{value!r} is not one of {choices}.", @@ -382,7 +433,7 @@ def to_info_dict(self) -> dict[str, t.Any]: info_dict["formats"] = self.formats return info_dict - def get_metavar(self, param: Parameter) -> str: + def get_metavar(self, param: Parameter, ctx: Context) -> str | None: return f"[{'|'.join(self.formats)}]" def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None: diff --git a/tests/test_basic.py b/tests/test_basic.py index d68b96299..b84ae73d6 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import enum import os from itertools import chain @@ -403,6 +406,82 @@ def cli(method): assert "{foo|bar|baz}" in result.output +def test_choice_argument_enum(runner): + class MyEnum(str, enum.Enum): + FOO = "foo-value" + BAR = "bar-value" + BAZ = "baz-value" + + @click.command() + @click.argument("method", type=click.Choice(MyEnum, case_sensitive=False)) + def cli(method: MyEnum): + assert isinstance(method, MyEnum) + click.echo(method) + + result = runner.invoke(cli, ["foo"]) + assert result.output == "foo-value\n" + assert not result.exception + + result = runner.invoke(cli, ["meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '{foo|bar|baz}': 'meh' is not one of 'foo'," + " 'bar', 'baz'." in result.output + ) + + result = runner.invoke(cli, ["--help"]) + assert "{foo|bar|baz}" in result.output + + +def test_choice_argument_custom_type(runner): + class MyClass: + def __init__(self, value: str) -> None: + self.value = value + + def __str__(self) -> str: + return self.value + + @click.command() + @click.argument( + "method", type=click.Choice([MyClass("foo"), MyClass("bar"), MyClass("baz")]) + ) + def cli(method: MyClass): + assert isinstance(method, MyClass) + click.echo(method) + + result = runner.invoke(cli, ["foo"]) + assert not result.exception + assert result.output == "foo\n" + + result = runner.invoke(cli, ["meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '{foo|bar|baz}': 'meh' is not one of 'foo'," + " 'bar', 'baz'." in result.output + ) + + result = runner.invoke(cli, ["--help"]) + assert "{foo|bar|baz}" in result.output + + +def test_choice_argument_none(runner): + @click.command() + @click.argument( + "method", type=click.Choice(["not-none", None], case_sensitive=False) + ) + def cli(method: str | None): + assert isinstance(method, str) or method is None + click.echo(method) + + result = runner.invoke(cli, ["not-none"]) + assert not result.exception + assert result.output == "not-none\n" + + # None is not yet supported. + result = runner.invoke(cli, ["none"]) + assert result.exception + + def test_datetime_option_default(runner): @click.command() @click.option("--start_date", type=click.DateTime()) diff --git a/tests/test_formatting.py b/tests/test_formatting.py index fe7e7bad1..c79f6577f 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -248,7 +248,7 @@ def cmd(arg): def test_formatting_custom_type_metavar(runner): class MyType(click.ParamType): - def get_metavar(self, param): + def get_metavar(self, param: click.Parameter, ctx: click.Context): return "MY_TYPE" @click.command("foo") diff --git a/tests/test_info_dict.py b/tests/test_info_dict.py index 11b670311..20fe68cc1 100644 --- a/tests/test_info_dict.py +++ b/tests/test_info_dict.py @@ -106,11 +106,11 @@ ), pytest.param(*STRING_PARAM_TYPE, id="STRING ParamType"), pytest.param( - click.Choice(["a", "b"]), + click.Choice(("a", "b")), { "param_type": "Choice", "name": "choice", - "choices": ["a", "b"], + "choices": ("a", "b"), "case_sensitive": True, }, id="Choice ParamType", diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 502e654a3..442b638f4 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -17,12 +17,37 @@ def cli(foo, x): def test_choice_normalization(runner): @click.command(context_settings=CONTEXT_SETTINGS) - @click.option("--choice", type=click.Choice(["Foo", "Bar"])) - def cli(choice): - click.echo(choice) - - result = runner.invoke(cli, ["--CHOICE", "FOO"]) - assert result.output == "Foo\n" + @click.option( + "--method", + type=click.Choice( + ["SCREAMING_SNAKE_CASE", "snake_case", "PascalCase", "kebab-case"], + case_sensitive=False, + ), + ) + def cli(method): + click.echo(method) + + result = runner.invoke(cli, ["--METHOD=snake_case"]) + assert not result.exception, result.output + assert result.output == "snake_case\n" + + # Even though it's case sensitive, the choice's original value is preserved + result = runner.invoke(cli, ["--method=pascalcase"]) + assert not result.exception, result.output + assert result.output == "PascalCase\n" + + result = runner.invoke(cli, ["--method=meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '--method': 'meh' is not one of " + "'screaming_snake_case', 'snake_case', 'pascalcase', 'kebab-case'." + ) in result.output + + result = runner.invoke(cli, ["--help"]) + assert ( + "--method [screaming_snake_case|snake_case|pascalcase|kebab-case]" + in result.output + ) def test_command_normalization(runner): diff --git a/tests/test_types.py b/tests/test_types.py index 667953a47..c287e371c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -248,5 +248,5 @@ def test_invalid_path_with_esc_sequence(): def test_choice_get_invalid_choice_message(): choice = click.Choice(["a", "b", "c"]) - message = choice.get_invalid_choice_message("d") + message = choice.get_invalid_choice_message("d", ctx=None) assert message == "'d' is not one of 'a', 'b', 'c'."