Skip to content

Commit bd54f77

Browse files
authored
(refactor) Context-based resolution for generic types (#180)
* Start refactor (broken) * Fixes * Fix generic shim edge case * New alias tests, edge case fixes * Add type ignore * CI fixes * Revert example change * Fix example numbering * Appease mypy * Remove unnecessary type alias unwrap
1 parent 89ff57a commit bd54f77

16 files changed

+830
-418
lines changed

docs/source/examples/04_additional/12_counters.rst docs/source/examples/04_additional/13_counters.rst

+8-8
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,30 @@ Repeatable 'counter' arguments can be specified via :data:`tyro.conf.UseCounterA
3838

3939
.. raw:: html
4040

41-
<kbd>python 04_additional/12_counters.py --help</kbd>
41+
<kbd>python 04_additional/13_counters.py --help</kbd>
4242

43-
.. program-output:: python ../../examples/04_additional/12_counters.py --help
43+
.. program-output:: python ../../examples/04_additional/13_counters.py --help
4444

4545
------------
4646

4747
.. raw:: html
4848

49-
<kbd>python 04_additional/12_counters.py --verbosity</kbd>
49+
<kbd>python 04_additional/13_counters.py --verbosity</kbd>
5050

51-
.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity
51+
.. program-output:: python ../../examples/04_additional/13_counters.py --verbosity
5252

5353
------------
5454

5555
.. raw:: html
5656

57-
<kbd>python 04_additional/12_counters.py --verbosity --verbosity</kbd>
57+
<kbd>python 04_additional/13_counters.py --verbosity --verbosity</kbd>
5858

59-
.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity --verbosity
59+
.. program-output:: python ../../examples/04_additional/13_counters.py --verbosity --verbosity
6060

6161
------------
6262

6363
.. raw:: html
6464

65-
<kbd>python 04_additional/12_counters.py -vvv</kbd>
65+
<kbd>python 04_additional/13_counters.py -vvv</kbd>
6666

67-
.. program-output:: python ../../examples/04_additional/12_counters.py -vvv
67+
.. program-output:: python ../../examples/04_additional/13_counters.py -vvv

docs/source/examples/04_additional/13_type_statement.rst docs/source/examples/04_additional/16_type_statement.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,6 @@ In Python 3.12, the :code:`type` statement is introduced to create type aliases.
4545

4646
.. raw:: html
4747

48-
<kbd>python 04_additional/13_type_statement.py --help</kbd>
48+
<kbd>python 04_additional/16_type_statement.py --help</kbd>
4949

50-
.. program-output:: python ../../examples/04_additional/13_type_statement.py --help
50+
.. program-output:: python ../../examples/04_additional/16_type_statement.py --help
File renamed without changes.

src/tyro/_arguments.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
TYPE_CHECKING,
1313
Any,
1414
Callable,
15-
Dict,
1615
Iterable,
1716
Mapping,
1817
Optional,
@@ -27,8 +26,7 @@
2726
import shtab
2827

2928
from . import _argparse as argparse
30-
from . import _fields, _instantiators, _resolver, _strings
31-
from ._typing import TypeForm
29+
from . import _fields, _instantiators, _strings
3230
from .conf import _markers
3331

3432
if TYPE_CHECKING:
@@ -108,7 +106,6 @@ class ArgumentDefinition:
108106
extern_prefix: str # User-facing prefix.
109107
subcommand_prefix: str # Prefix for nesting.
110108
field: _fields.FieldDefinition
111-
type_from_typevar: Dict[TypeVar, TypeForm[Any]]
112109

113110
def add_argument(
114111
self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]
@@ -254,12 +251,7 @@ def _rule_handle_boolean_flags(
254251
arg: ArgumentDefinition,
255252
lowered: LoweredArgumentDefinition,
256253
) -> None:
257-
if (
258-
_resolver.apply_type_from_typevar(
259-
arg.field.type_or_callable, arg.type_from_typevar
260-
)
261-
is not bool
262-
):
254+
if arg.field.type_or_callable is not bool:
263255
return
264256

265257
if (
@@ -305,7 +297,6 @@ def _rule_recursive_instantiator_from_type(
305297
try:
306298
instantiator, metadata = _instantiators.instantiator_from_type(
307299
arg.field.type_or_callable,
308-
arg.type_from_typevar,
309300
arg.field.markers,
310301
)
311302
except _instantiators.UnsupportedTypeAnnotationError as e:

src/tyro/_cli.py

+101-94
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import shtab
2121
from typing_extensions import Literal
2222

23+
from tyro._resolver import TypeParamResolver
24+
2325
from . import _argparse as argparse
2426
from . import (
2527
_argparse_formatter,
@@ -316,104 +318,109 @@ def _cli_impl(
316318
stacklevel=2,
317319
)
318320

319-
# Internally, we distinguish between two concepts:
320-
# - "default", which is used for individual arguments.
321-
# - "default_instance", which is used for _fields_ (which may be broken down into
322-
# one or many arguments, depending on various factors).
323-
#
324-
# This could be revisited.
325-
default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = (
326-
_fields.MISSING_NONPROP if default is None else default
327-
)
328-
329-
# We wrap our type with a dummy dataclass if it can't be treated as a nested type.
330-
# For example: passing in f=int will result in a dataclass with a single field
331-
# typed as int.
332-
if not _fields.is_nested_type(cast(type, f), default_instance_internal):
333-
dummy_field = cast(
334-
dataclasses.Field,
335-
dataclasses.field(),
321+
resolve_context = TypeParamResolver.get_assignment_context(f)
322+
with resolve_context:
323+
f = resolve_context.origin_type
324+
325+
# Internally, we distinguish between two concepts:
326+
# - "default", which is used for individual arguments.
327+
# - "default_instance", which is used for _fields_ (which may be broken down into
328+
# one or many arguments, depending on various factors).
329+
#
330+
# This could be revisited.
331+
default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = (
332+
_fields.MISSING_NONPROP if default is None else default
336333
)
337-
f = dataclasses.make_dataclass(
338-
cls_name="dummy",
339-
fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)],
340-
frozen=True,
341-
)
342-
default_instance_internal = f(default_instance_internal) # type: ignore
343-
dummy_wrapped = True
344-
else:
345-
dummy_wrapped = False
346-
347-
# Read and fix arguments. If the user passes in --field_name instead of
348-
# --field-name, correct for them.
349-
args = list(sys.argv[1:]) if args is None else list(args)
350-
351-
# Fix arguments. This will modify all option-style arguments replacing
352-
# underscores with hyphens, or vice versa if use_underscores=True.
353-
# If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error.
354-
modified_args: Dict[str, str] = {}
355-
for index, arg in enumerate(args):
356-
if not arg.startswith("--"):
357-
continue
358-
359-
if "=" in arg:
360-
arg, _, val = arg.partition("=")
361-
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val
334+
335+
# We wrap our type with a dummy dataclass if it can't be treated as a nested type.
336+
# For example: passing in f=int will result in a dataclass with a single field
337+
# typed as int.
338+
if not _fields.is_nested_type(cast(type, f), default_instance_internal):
339+
dummy_field = cast(
340+
dataclasses.Field,
341+
dataclasses.field(),
342+
)
343+
f = dataclasses.make_dataclass(
344+
cls_name="dummy",
345+
fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)],
346+
frozen=True,
347+
)
348+
default_instance_internal = f(default_instance_internal) # type: ignore
349+
dummy_wrapped = True
362350
else:
363-
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:])
364-
if (
365-
return_unknown_args
366-
and fixed in modified_args
367-
and modified_args[fixed] != arg
368-
):
369-
raise RuntimeError(
370-
"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
351+
dummy_wrapped = False
352+
353+
# Read and fix arguments. If the user passes in --field_name instead of
354+
# --field-name, correct for them.
355+
args = list(sys.argv[1:]) if args is None else list(args)
356+
357+
# Fix arguments. This will modify all option-style arguments replacing
358+
# underscores with hyphens, or vice versa if use_underscores=True.
359+
# If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error.
360+
modified_args: Dict[str, str] = {}
361+
for index, arg in enumerate(args):
362+
if not arg.startswith("--"):
363+
continue
364+
365+
if "=" in arg:
366+
arg, _, val = arg.partition("=")
367+
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val
368+
else:
369+
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:])
370+
if (
371+
return_unknown_args
372+
and fixed in modified_args
373+
and modified_args[fixed] != arg
374+
):
375+
raise RuntimeError(
376+
"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
377+
)
378+
modified_args[fixed] = arg
379+
args[index] = fixed
380+
381+
# If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn
382+
# formatting tags, and get the shell we want to generate a completion script for
383+
# (bash/zsh/tcsh).
384+
#
385+
# shtab also offers an add_argument_to() functions that fulfills a similar goal, but
386+
# manual parsing of argv is convenient for turning off formatting.
387+
#
388+
# Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone
389+
# to errors from accidental logging, print statements, etc.
390+
print_completion = False
391+
write_completion = False
392+
if len(args) >= 2:
393+
# We replace underscores with hyphens to accomodate for `use_undercores`.
394+
print_completion = args[0].replace("_", "-") == "--tyro-print-completion"
395+
write_completion = (
396+
len(args) >= 3
397+
and args[0].replace("_", "-") == "--tyro-write-completion"
371398
)
372-
modified_args[fixed] = arg
373-
args[index] = fixed
374-
375-
# If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn
376-
# formatting tags, and get the shell we want to generate a completion script for
377-
# (bash/zsh/tcsh).
378-
#
379-
# shtab also offers an add_argument_to() functions that fulfills a similar goal, but
380-
# manual parsing of argv is convenient for turning off formatting.
381-
#
382-
# Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone
383-
# to errors from accidental logging, print statements, etc.
384-
print_completion = False
385-
write_completion = False
386-
if len(args) >= 2:
387-
# We replace underscores with hyphens to accomodate for `use_undercores`.
388-
print_completion = args[0].replace("_", "-") == "--tyro-print-completion"
389-
write_completion = (
390-
len(args) >= 3 and args[0].replace("_", "-") == "--tyro-write-completion"
391-
)
392399

393-
# Note: setting USE_RICH must happen before the parser specification is generated.
394-
# TODO: revisit this. Ideally we should be able to eliminate the global state
395-
# changes.
396-
completion_shell = None
397-
completion_target_path = None
398-
if print_completion or write_completion:
399-
completion_shell = args[1]
400-
if write_completion:
401-
completion_target_path = pathlib.Path(args[2])
402-
if print_completion or write_completion or return_parser:
403-
_arguments.USE_RICH = False
404-
else:
405-
_arguments.USE_RICH = True
406-
407-
# Map a callable to the relevant CLI arguments + subparsers.
408-
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
409-
f,
410-
description=description,
411-
parent_classes=set(), # Used for recursive calls.
412-
default_instance=default_instance_internal, # Overrides for default values.
413-
intern_prefix="", # Used for recursive calls.
414-
extern_prefix="", # Used for recursive calls.
415-
subcommand_prefix="", # Used for recursive calls.
416-
)
400+
# Note: setting USE_RICH must happen before the parser specification is generated.
401+
# TODO: revisit this. Ideally we should be able to eliminate the global state
402+
# changes.
403+
completion_shell = None
404+
completion_target_path = None
405+
if print_completion or write_completion:
406+
completion_shell = args[1]
407+
if write_completion:
408+
completion_target_path = pathlib.Path(args[2])
409+
if print_completion or write_completion or return_parser:
410+
_arguments.USE_RICH = False
411+
else:
412+
_arguments.USE_RICH = True
413+
414+
# Map a callable to the relevant CLI arguments + subparsers.
415+
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
416+
f,
417+
description=description,
418+
parent_classes=set(), # Used for recursive calls.
419+
default_instance=default_instance_internal, # Overrides for default values.
420+
intern_prefix="", # Used for recursive calls.
421+
extern_prefix="", # Used for recursive calls.
422+
subcommand_prefix="", # Used for recursive calls.
423+
)
417424

418425
# Generate parser!
419426
with _argparse_formatter.ansi_context():

src/tyro/_docstrings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def get_callable_description(f: Callable) -> str:
301301
the fields of the class if a docstring is not specified; this helper will ignore
302302
these docstrings."""
303303

304-
f, _unused = _resolver.resolve_generic_types(f)
304+
f, _ = _resolver.resolve_generic_types(f)
305305
f = _resolver.unwrap_origin_strip_extras(f)
306306
if f in _callable_description_blocklist:
307307
return ""

0 commit comments

Comments
 (0)