|
20 | 20 | import shtab
|
21 | 21 | from typing_extensions import Literal
|
22 | 22 |
|
| 23 | +from tyro._resolver import TypeParamResolver |
| 24 | + |
23 | 25 | from . import _argparse as argparse
|
24 | 26 | from . import (
|
25 | 27 | _argparse_formatter,
|
@@ -316,104 +318,109 @@ def _cli_impl(
|
316 | 318 | stacklevel=2,
|
317 | 319 | )
|
318 | 320 |
|
319 |
| - # Internally, we distinguish between two concepts: |
320 |
| - # - "default", which is used for individual arguments. |
321 |
| - # - "default_instance", which is used for _fields_ (which may be broken down into |
322 |
| - # one or many arguments, depending on various factors). |
323 |
| - # |
324 |
| - # This could be revisited. |
325 |
| - default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = ( |
326 |
| - _fields.MISSING_NONPROP if default is None else default |
327 |
| - ) |
328 |
| - |
329 |
| - # We wrap our type with a dummy dataclass if it can't be treated as a nested type. |
330 |
| - # For example: passing in f=int will result in a dataclass with a single field |
331 |
| - # typed as int. |
332 |
| - if not _fields.is_nested_type(cast(type, f), default_instance_internal): |
333 |
| - dummy_field = cast( |
334 |
| - dataclasses.Field, |
335 |
| - dataclasses.field(), |
| 321 | + resolve_context = TypeParamResolver.get_assignment_context(f) |
| 322 | + with resolve_context: |
| 323 | + f = resolve_context.origin_type |
| 324 | + |
| 325 | + # Internally, we distinguish between two concepts: |
| 326 | + # - "default", which is used for individual arguments. |
| 327 | + # - "default_instance", which is used for _fields_ (which may be broken down into |
| 328 | + # one or many arguments, depending on various factors). |
| 329 | + # |
| 330 | + # This could be revisited. |
| 331 | + default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = ( |
| 332 | + _fields.MISSING_NONPROP if default is None else default |
336 | 333 | )
|
337 |
| - f = dataclasses.make_dataclass( |
338 |
| - cls_name="dummy", |
339 |
| - fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)], |
340 |
| - frozen=True, |
341 |
| - ) |
342 |
| - default_instance_internal = f(default_instance_internal) # type: ignore |
343 |
| - dummy_wrapped = True |
344 |
| - else: |
345 |
| - dummy_wrapped = False |
346 |
| - |
347 |
| - # Read and fix arguments. If the user passes in --field_name instead of |
348 |
| - # --field-name, correct for them. |
349 |
| - args = list(sys.argv[1:]) if args is None else list(args) |
350 |
| - |
351 |
| - # Fix arguments. This will modify all option-style arguments replacing |
352 |
| - # underscores with hyphens, or vice versa if use_underscores=True. |
353 |
| - # If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error. |
354 |
| - modified_args: Dict[str, str] = {} |
355 |
| - for index, arg in enumerate(args): |
356 |
| - if not arg.startswith("--"): |
357 |
| - continue |
358 |
| - |
359 |
| - if "=" in arg: |
360 |
| - arg, _, val = arg.partition("=") |
361 |
| - fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val |
| 334 | + |
| 335 | + # We wrap our type with a dummy dataclass if it can't be treated as a nested type. |
| 336 | + # For example: passing in f=int will result in a dataclass with a single field |
| 337 | + # typed as int. |
| 338 | + if not _fields.is_nested_type(cast(type, f), default_instance_internal): |
| 339 | + dummy_field = cast( |
| 340 | + dataclasses.Field, |
| 341 | + dataclasses.field(), |
| 342 | + ) |
| 343 | + f = dataclasses.make_dataclass( |
| 344 | + cls_name="dummy", |
| 345 | + fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)], |
| 346 | + frozen=True, |
| 347 | + ) |
| 348 | + default_instance_internal = f(default_instance_internal) # type: ignore |
| 349 | + dummy_wrapped = True |
362 | 350 | else:
|
363 |
| - fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) |
364 |
| - if ( |
365 |
| - return_unknown_args |
366 |
| - and fixed in modified_args |
367 |
| - and modified_args[fixed] != arg |
368 |
| - ): |
369 |
| - raise RuntimeError( |
370 |
| - "Ambiguous arguments: " + modified_args[fixed] + " and " + arg |
| 351 | + dummy_wrapped = False |
| 352 | + |
| 353 | + # Read and fix arguments. If the user passes in --field_name instead of |
| 354 | + # --field-name, correct for them. |
| 355 | + args = list(sys.argv[1:]) if args is None else list(args) |
| 356 | + |
| 357 | + # Fix arguments. This will modify all option-style arguments replacing |
| 358 | + # underscores with hyphens, or vice versa if use_underscores=True. |
| 359 | + # If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error. |
| 360 | + modified_args: Dict[str, str] = {} |
| 361 | + for index, arg in enumerate(args): |
| 362 | + if not arg.startswith("--"): |
| 363 | + continue |
| 364 | + |
| 365 | + if "=" in arg: |
| 366 | + arg, _, val = arg.partition("=") |
| 367 | + fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val |
| 368 | + else: |
| 369 | + fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) |
| 370 | + if ( |
| 371 | + return_unknown_args |
| 372 | + and fixed in modified_args |
| 373 | + and modified_args[fixed] != arg |
| 374 | + ): |
| 375 | + raise RuntimeError( |
| 376 | + "Ambiguous arguments: " + modified_args[fixed] + " and " + arg |
| 377 | + ) |
| 378 | + modified_args[fixed] = arg |
| 379 | + args[index] = fixed |
| 380 | + |
| 381 | + # If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn |
| 382 | + # formatting tags, and get the shell we want to generate a completion script for |
| 383 | + # (bash/zsh/tcsh). |
| 384 | + # |
| 385 | + # shtab also offers an add_argument_to() functions that fulfills a similar goal, but |
| 386 | + # manual parsing of argv is convenient for turning off formatting. |
| 387 | + # |
| 388 | + # Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone |
| 389 | + # to errors from accidental logging, print statements, etc. |
| 390 | + print_completion = False |
| 391 | + write_completion = False |
| 392 | + if len(args) >= 2: |
| 393 | + # We replace underscores with hyphens to accomodate for `use_undercores`. |
| 394 | + print_completion = args[0].replace("_", "-") == "--tyro-print-completion" |
| 395 | + write_completion = ( |
| 396 | + len(args) >= 3 |
| 397 | + and args[0].replace("_", "-") == "--tyro-write-completion" |
371 | 398 | )
|
372 |
| - modified_args[fixed] = arg |
373 |
| - args[index] = fixed |
374 |
| - |
375 |
| - # If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn |
376 |
| - # formatting tags, and get the shell we want to generate a completion script for |
377 |
| - # (bash/zsh/tcsh). |
378 |
| - # |
379 |
| - # shtab also offers an add_argument_to() functions that fulfills a similar goal, but |
380 |
| - # manual parsing of argv is convenient for turning off formatting. |
381 |
| - # |
382 |
| - # Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone |
383 |
| - # to errors from accidental logging, print statements, etc. |
384 |
| - print_completion = False |
385 |
| - write_completion = False |
386 |
| - if len(args) >= 2: |
387 |
| - # We replace underscores with hyphens to accomodate for `use_undercores`. |
388 |
| - print_completion = args[0].replace("_", "-") == "--tyro-print-completion" |
389 |
| - write_completion = ( |
390 |
| - len(args) >= 3 and args[0].replace("_", "-") == "--tyro-write-completion" |
391 |
| - ) |
392 | 399 |
|
393 |
| - # Note: setting USE_RICH must happen before the parser specification is generated. |
394 |
| - # TODO: revisit this. Ideally we should be able to eliminate the global state |
395 |
| - # changes. |
396 |
| - completion_shell = None |
397 |
| - completion_target_path = None |
398 |
| - if print_completion or write_completion: |
399 |
| - completion_shell = args[1] |
400 |
| - if write_completion: |
401 |
| - completion_target_path = pathlib.Path(args[2]) |
402 |
| - if print_completion or write_completion or return_parser: |
403 |
| - _arguments.USE_RICH = False |
404 |
| - else: |
405 |
| - _arguments.USE_RICH = True |
406 |
| - |
407 |
| - # Map a callable to the relevant CLI arguments + subparsers. |
408 |
| - parser_spec = _parsers.ParserSpecification.from_callable_or_type( |
409 |
| - f, |
410 |
| - description=description, |
411 |
| - parent_classes=set(), # Used for recursive calls. |
412 |
| - default_instance=default_instance_internal, # Overrides for default values. |
413 |
| - intern_prefix="", # Used for recursive calls. |
414 |
| - extern_prefix="", # Used for recursive calls. |
415 |
| - subcommand_prefix="", # Used for recursive calls. |
416 |
| - ) |
| 400 | + # Note: setting USE_RICH must happen before the parser specification is generated. |
| 401 | + # TODO: revisit this. Ideally we should be able to eliminate the global state |
| 402 | + # changes. |
| 403 | + completion_shell = None |
| 404 | + completion_target_path = None |
| 405 | + if print_completion or write_completion: |
| 406 | + completion_shell = args[1] |
| 407 | + if write_completion: |
| 408 | + completion_target_path = pathlib.Path(args[2]) |
| 409 | + if print_completion or write_completion or return_parser: |
| 410 | + _arguments.USE_RICH = False |
| 411 | + else: |
| 412 | + _arguments.USE_RICH = True |
| 413 | + |
| 414 | + # Map a callable to the relevant CLI arguments + subparsers. |
| 415 | + parser_spec = _parsers.ParserSpecification.from_callable_or_type( |
| 416 | + f, |
| 417 | + description=description, |
| 418 | + parent_classes=set(), # Used for recursive calls. |
| 419 | + default_instance=default_instance_internal, # Overrides for default values. |
| 420 | + intern_prefix="", # Used for recursive calls. |
| 421 | + extern_prefix="", # Used for recursive calls. |
| 422 | + subcommand_prefix="", # Used for recursive calls. |
| 423 | + ) |
417 | 424 |
|
418 | 425 | # Generate parser!
|
419 | 426 | with _argparse_formatter.ansi_context():
|
|
0 commit comments