Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ Unreleased
- Parameters cannot be required nor prompted or an error is raised.
- A warning will be printed when something deprecated is used.

- Add a ``catch_exceptions`` parameter to :class:`CliRunner`. If
``catch_exceptions`` is not passed to :meth:`CliRunner.invoke`,
- Add a ``catch_exceptions`` parameter to :class:``CliRunner``. If
``catch_exceptions`` is not passed to :meth:``CliRunner.invoke``,
the value from :class:`CliRunner`. :issue:`2817` :pr:`2818`
- ``Option.flag_value`` will no longer have a default value set based on
``Option.default`` if ``Option.is_flag`` is ``False``. This results in
``Option.default`` not needing to implement `__bool__`. :pr:`2829`
- Incorrect ``click.edit`` typing has been corrected. :pr:`2804`
- :class:``Choice`` is now generic and supports any iterable value.
This allows you to use enums and other non-``str`` values. :pr:`2796`
:issue:`605`

Version 8.1.8
-------------
Expand Down
20 changes: 14 additions & 6 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -375,15 +375,22 @@ In that case you can use :class:`Choice` type. It can be instantiated
with a list of valid values. The originally passed choice will be returned,
not the str passed on the command line. Token normalization functions and
``case_sensitive=False`` can cause the two to be different but still match.
:meth:`Choice.normalize_choice` for more info.

Example:

.. click:example::

import enum

class HashType(enum.Enum):
MD5 = 'MD5'
SHA1 = 'SHA1'

@click.command()
@click.option('--hash-type',
type=click.Choice(['MD5', 'SHA1'], case_sensitive=False))
def digest(hash_type):
type=click.Choice(HashType, case_sensitive=False))
def digest(hash_type: HashType):
click.echo(hash_type)

What it looks like:
Expand All @@ -398,15 +405,16 @@ What it looks like:
println()
invoke(digest, args=['--help'])

Only pass the choices as list or tuple. Other iterables (like
generators) may lead to unexpected results.
Since version 8.2.0 any iterable may be passed to :class:`Choice`, here
an ``Enum`` is used which will result in all enum values to be valid
choices.

Choices work with options that have ``multiple=True``. If a ``default``
value is given with ``multiple=True``, it should be a list or tuple of
valid choices.

Choices should be unique after considering the effects of
``case_sensitive`` and any specified token normalization function.
Choices should be unique after normalization, see
:meth:`Choice.normalize_choice` for more info.

.. versionchanged:: 7.1
The resulting value from an option will always be one of the
Expand Down
14 changes: 7 additions & 7 deletions src/click/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,11 +2192,11 @@ def human_readable_name(self) -> str:
"""
return self.name # type: ignore

def make_metavar(self) -> str:
def make_metavar(self, ctx: Context) -> str:
if self.metavar is not None:
return self.metavar

metavar = self.type.get_metavar(self)
metavar = self.type.get_metavar(param=self, ctx=ctx)

if metavar is None:
metavar = self.type.name.upper()
Expand Down Expand Up @@ -2775,7 +2775,7 @@ def _write_opts(opts: cabc.Sequence[str]) -> str:
any_prefix_is_slash = True

if not self.is_flag and not self.count:
rv += f" {self.make_metavar()}"
rv += f" {self.make_metavar(ctx=ctx)}"

return rv

Expand Down Expand Up @@ -3056,10 +3056,10 @@ def human_readable_name(self) -> str:
return self.metavar
return self.name.upper() # type: ignore

def make_metavar(self) -> str:
def make_metavar(self, ctx: Context) -> str:
if self.metavar is not None:
return self.metavar
var = self.type.get_metavar(self)
var = self.type.get_metavar(param=self, ctx=ctx)
if not var:
var = self.name.upper() # type: ignore
if self.deprecated:
Expand Down Expand Up @@ -3088,10 +3088,10 @@ def _parse_decls(
return name, [arg], []

def get_usage_pieces(self, ctx: Context) -> list[str]:
return [self.make_metavar()]
return [self.make_metavar(ctx)]

def get_error_hint(self, ctx: Context) -> str:
return f"'{self.make_metavar()}'"
return f"'{self.make_metavar(ctx)}'"

def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None:
parser.add_argument(dest=self.name, nargs=self.nargs, obj=self)
Expand Down
4 changes: 3 additions & 1 deletion src/click/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def format_message(self) -> str:

msg = self.message
if self.param is not None:
msg_extra = self.param.type.get_missing_message(self.param)
msg_extra = self.param.type.get_missing_message(
param=self.param, ctx=self.ctx
)
if msg_extra:
if msg:
msg += f". {msg_extra}"
Expand Down
141 changes: 96 additions & 45 deletions src/click/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections.abc as cabc
import enum
import os
import stat
import sys
Expand All @@ -23,6 +24,8 @@
from .core import Parameter
from .shell_completion import CompletionItem

ParamTypeValue = t.TypeVar("ParamTypeValue")


class ParamType:
"""Represents the type of a parameter. Validates and converts values
Expand Down Expand Up @@ -86,10 +89,10 @@ def __call__(
if value is not None:
return self.convert(value, param, ctx)

def get_metavar(self, param: Parameter) -> str | None:
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
"""Returns the metavar default for this param if it provides one."""

def get_missing_message(self, param: Parameter) -> str | None:
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | None:
"""Optionally might return extra information about a missing
parameter.

Expand Down Expand Up @@ -227,29 +230,35 @@ def __repr__(self) -> str:
return "STRING"


class Choice(ParamType):
class Choice(ParamType, t.Generic[ParamTypeValue]):
"""The choice type allows a value to be checked against a fixed set
of supported values. All of these values have to be strings.

You should only pass a list or tuple of choices. Other iterables
(like generators) may lead to surprising results.
of supported values.

The resulting value will always be one of the originally passed choices
regardless of ``case_sensitive`` or any ``ctx.token_normalize_func``
being specified.
You may pass any iterable value which will be converted to a tuple
and thus will only be iterated once.

See :ref:`choice-opts` for an example.
The resulting value will always be one of the originally passed choices.
See :meth:`normalize_choice` for more info on the mapping of strings
to choices. See :ref:`choice-opts` for an example.

:param case_sensitive: Set to false to make choices case
insensitive. Defaults to true.

.. versionchanged:: 8.2.0
Non-``str`` ``choices`` are now supported. It can additionally be any
iterable. Before you were not recommended to pass anything but a list or
tuple.

.. versionadded:: 8.2.0
Choice normalization can be overridden via :meth:`normalize_choice`.
"""

name = "choice"

def __init__(
self, choices: cabc.Sequence[str], case_sensitive: bool = True
self, choices: cabc.Iterable[ParamTypeValue], case_sensitive: bool = True
) -> None:
self.choices = choices
self.choices: cabc.Sequence[ParamTypeValue] = tuple(choices)
self.case_sensitive = case_sensitive

def to_info_dict(self) -> dict[str, t.Any]:
Expand All @@ -258,14 +267,54 @@ def to_info_dict(self) -> dict[str, t.Any]:
info_dict["case_sensitive"] = self.case_sensitive
return info_dict

def get_metavar(self, param: Parameter) -> str:
def _normalized_mapping(
self, ctx: Context | None = None
) -> cabc.Mapping[ParamTypeValue, str]:
"""
Returns mapping where keys are the original choices and the values are
the normalized values that are accepted via the command line.

This is a simple wrapper around :meth:`normalize_choice`, use that
instead which is supported.
"""
return {
choice: self.normalize_choice(
choice=choice,
ctx=ctx,
)
for choice in self.choices
}

def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str:
"""
Normalize a choice value, used to map a passed string to a choice.
Each choice must have a unique normalized value.

By default uses :meth:`Context.token_normalize_func` and if not case
sensitive, convert it to a casefolded value.

.. versionadded:: 8.2.0
"""
normed_value = choice.name if isinstance(choice, enum.Enum) else str(choice)

if ctx is not None and ctx.token_normalize_func is not None:
normed_value = ctx.token_normalize_func(normed_value)

if not self.case_sensitive:
normed_value = normed_value.casefold()

return normed_value

def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
if param.param_type_name == "option" and not param.show_choices: # type: ignore
choice_metavars = [
convert_type(type(choice)).name.upper() for choice in self.choices
]
choices_str = "|".join([*dict.fromkeys(choice_metavars)])
else:
choices_str = "|".join([str(i) for i in self.choices])
choices_str = "|".join(
[str(i) for i in self._normalized_mapping(ctx=ctx).values()]
)

# Use curly braces to indicate a required argument.
if param.required and param.param_type_name == "argument":
Expand All @@ -274,46 +323,48 @@ def get_metavar(self, param: Parameter) -> str:
# Use square braces to indicate an option or optional argument.
return f"[{choices_str}]"

def get_missing_message(self, param: Parameter) -> str:
return _("Choose from:\n\t{choices}").format(choices=",\n\t".join(self.choices))
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str:
"""
Message shown when no choice is passed.

.. versionchanged:: 8.2.0 Added ``ctx`` argument.
"""
return _("Choose from:\n\t{choices}").format(
choices=",\n\t".join(self._normalized_mapping(ctx=ctx).values())
)

def convert(
self, value: t.Any, param: Parameter | None, ctx: Context | None
) -> t.Any:
# Match through normalization and case sensitivity
# first do token_normalize_func, then lowercase
# preserve original `value` to produce an accurate message in
# `self.fail`
normed_value = value
normed_choices = {choice: choice for choice in self.choices}

if ctx is not None and ctx.token_normalize_func is not None:
normed_value = ctx.token_normalize_func(value)
normed_choices = {
ctx.token_normalize_func(normed_choice): original
for normed_choice, original in normed_choices.items()
}

if not self.case_sensitive:
normed_value = normed_value.casefold()
normed_choices = {
normed_choice.casefold(): original
for normed_choice, original in normed_choices.items()
}

if normed_value in normed_choices:
return normed_choices[normed_value]
) -> ParamTypeValue:
"""
For a given value from the parser, normalize it and find its
matching normalized value in the list of choices. Then return the
matched "original" choice.
"""
normed_value = self.normalize_choice(choice=value, ctx=ctx)
normalized_mapping = self._normalized_mapping(ctx=ctx)

self.fail(self.get_invalid_choice_message(value), param, ctx)
try:
return next(
original
for original, normalized in normalized_mapping.items()
if normalized == normed_value
)
except StopIteration:
self.fail(
self.get_invalid_choice_message(value=value, ctx=ctx),
param=param,
ctx=ctx,
)

def get_invalid_choice_message(self, value: t.Any) -> str:
def get_invalid_choice_message(self, value: t.Any, ctx: Context | None) -> str:
"""Get the error message when the given choice is invalid.

:param value: The invalid value.

.. versionadded:: 8.2
"""
choices_str = ", ".join(map(repr, self.choices))
choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values()))
return ngettext(
"{value!r} is not {choice}.",
"{value!r} is not one of {choices}.",
Expand Down Expand Up @@ -382,7 +433,7 @@ def to_info_dict(self) -> dict[str, t.Any]:
info_dict["formats"] = self.formats
return info_dict

def get_metavar(self, param: Parameter) -> str:
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
return f"[{'|'.join(self.formats)}]"

def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None:
Expand Down
Loading
Loading