Skip to content

Commit 5c8e7fc

Browse files
authored
Fix list[dict] regression (#228)
* Fix `list[dict]` edge case * Bump version * Take UseAppendAction into account * ruff * docs sync * Fix UseAppendAction properly * cast for mypy * ruff * nits
1 parent 0c2cb26 commit 5c8e7fc

File tree

12 files changed

+160
-28
lines changed

12 files changed

+160
-28
lines changed

docs/source/examples/custom_constructors.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -343,13 +343,13 @@ structs.
343343
name="lower",
344344
type=int,
345345
default=default[0],
346-
helptext="Lower bound." "",
346+
helptext="Lower bound.",
347347
),
348348
tyro.constructors.StructFieldSpec(
349349
name="upper",
350350
type=int,
351351
default=default[1],
352-
helptext="Upper bound." "",
352+
helptext="Upper bound.",
353353
),
354354
),
355355
)

examples/06_custom_constructors/04_struct_registry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def _(
6060
name="lower",
6161
type=int,
6262
default=default[0],
63-
helptext="Lower bound." "",
63+
helptext="Lower bound.",
6464
),
6565
tyro.constructors.StructFieldSpec(
6666
name="upper",
6767
type=int,
6868
default=default[1],
69-
helptext="Upper bound." "",
69+
helptext="Upper bound.",
7070
),
7171
),
7272
)

src/tyro/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING
22

3-
__version__ = "0.9.7"
3+
__version__ = "0.9.8"
44

55

66
from . import conf as conf

src/tyro/_fields.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,7 @@ def is_positional_call(self) -> bool:
182182
def is_struct_type(typ: Union[TypeForm[Any], Callable], default_instance: Any) -> bool:
183183
"""Determine whether a type should be treated as a 'struct type', where a single
184184
type can be broken down into multiple fields (eg for nested dataclasses or
185-
classes).
186-
187-
TODO: we should come up with a better name than 'struct type', which is a little bit
188-
misleading."""
185+
classes)."""
189186

190187
list_or_error = field_list_from_type_or_callable(
191188
typ, default_instance, support_single_arg_types=False

src/tyro/_parsers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def handle_field(
373373

374374
# (2) Handle nested callables.
375375
if force_primitive == "struct" or _fields.is_struct_type(
376-
field.type_stripped, field.default
376+
field.type, field.default
377377
):
378378
field = field.with_new_type_stripped(
379379
_resolver.narrow_subtypes(

src/tyro/conf/_confstruct.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def subcommand(
116116
will be used in place of the argument's type for parsing arguments.
117117
For more configurability, see :mod:`tyro.constructors`.
118118
"""
119-
assert not (
120-
constructor is not None and constructor_factory is not None
121-
), "`constructor` and `constructor_factory` cannot both be set."
119+
assert not (constructor is not None and constructor_factory is not None), (
120+
"`constructor` and `constructor_factory` cannot both be set."
121+
)
122122
return _SubcommandConfig(
123123
name,
124124
default,
@@ -220,9 +220,9 @@ def arg(
220220
Returns:
221221
Object to attach via `typing.Annotated[]`.
222222
"""
223-
assert not (
224-
constructor is not None and constructor_factory is not None
225-
), "`constructor` and `constructor_factory` cannot both be set."
223+
assert not (constructor is not None and constructor_factory is not None), (
224+
"`constructor` and `constructor_factory` cannot both be set."
225+
)
226226

227227
if aliases is not None:
228228
for alias in aliases:

src/tyro/constructors/_primitive_spec.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,15 @@ def sequence_rule(
331331
if container_type is collections.abc.Sequence:
332332
container_type = list
333333

334+
args = get_args(type_info.type)
334335
if container_type is tuple:
335-
(contained_type, ell) = get_args(type_info.type)
336+
assert len(args) == 2
337+
(contained_type, ell) = args
336338
assert ell == Ellipsis
339+
elif len(args) == 1:
340+
(contained_type,) = args
337341
else:
338-
(contained_type,) = get_args(type_info.type)
342+
contained_type = Any
339343

340344
inner_spec = ConstructorRegistry.get_primitive_spec(
341345
PrimitiveTypeInfo.make(
@@ -687,7 +691,9 @@ def str_from_instance(instance: Any) -> List[str]:
687691
if fuzzy_match is not None:
688692
return fuzzy_match.str_from_instance(instance)
689693

690-
assert False, f"could not match default value {instance} with any types in union {options}"
694+
assert False, (
695+
f"could not match default value {instance} with any types in union {options}"
696+
)
691697

692698
return PrimitiveConstructorSpec(
693699
nargs=nargs,

src/tyro/constructors/_struct_spec.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,9 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
244244

245245
if typ_origin in (Required, NotRequired):
246246
args = get_args(typ)
247-
assert (
248-
len(args) == 1
249-
), "typing.Required[] and typing.NotRequired[T] require a concrete type T."
247+
assert len(args) == 1, (
248+
"typing.Required[] and typing.NotRequired[T] require a concrete type T."
249+
)
250250
typ = args[0]
251251
del args
252252

@@ -377,7 +377,9 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
377377
return StructConstructorSpec(instantiate=info.type, fields=tuple(field_list))
378378

379379
@registry.struct_rule
380-
def sequence_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
380+
def variable_length_sequence_rule(
381+
info: StructTypeInfo,
382+
) -> StructConstructorSpec | None:
381383
if get_origin(info.type) not in (
382384
list,
383385
set,
@@ -387,13 +389,34 @@ def sequence_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
387389
) or not isinstance(info.default, Iterable):
388390
return None
389391

390-
contained_type = get_args(info.type)[0] if get_args(info.type) else Any
392+
# Cast is for mypy.
393+
contained_type = cast(
394+
type, get_args(info.type)[0] if get_args(info.type) else Any
395+
)
391396

392397
# If the inner type is a primitive, we'll just treat the whole type as
393398
# a primitive.
394-
from ._registry import ConstructorRegistry
399+
from ._registry import (
400+
ConstructorRegistry,
401+
PrimitiveConstructorSpec,
402+
PrimitiveTypeInfo,
403+
)
395404

396-
if ConstructorRegistry._is_primitive_type(contained_type, set(info.markers)):
405+
contained_primitive_spec = ConstructorRegistry.get_primitive_spec(
406+
PrimitiveTypeInfo.make(contained_type, set(info.markers))
407+
)
408+
if (
409+
isinstance(contained_primitive_spec, PrimitiveConstructorSpec)
410+
# Why do we check nargs?
411+
# Because for primitives, we can't nest variable-length collections.
412+
#
413+
# For example, list[list[str]] can't be parsed as a single primitive.
414+
#
415+
# However, list[list[str]] can be parsed if the outer type is
416+
# handled as a struct (and a default value is provided, which we
417+
# check above).
418+
and contained_primitive_spec.nargs != "*"
419+
):
397420
return None
398421

399422
field_list = []

tests/test_collections.py

+53
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,56 @@ def test_list_narrowing_direct() -> None:
653653
def test_tuple_direct() -> None:
654654
assert tyro.cli(Tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore
655655
assert tyro.cli(Tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore
656+
657+
658+
def test_nested_dict_in_list() -> None:
659+
"""https://github.com/nerfstudio-project/nerfstudio/pull/3567"""
660+
661+
@dataclasses.dataclass
662+
class Args:
663+
proposal_net_args_list: List[Dict] = dataclasses.field(
664+
default_factory=lambda: [
665+
{
666+
"hidden_dim": 16,
667+
},
668+
{
669+
"hidden_dim": 16,
670+
},
671+
]
672+
)
673+
proposal_net_args_list2: Tuple[Dict[str, List], Dict[str, List]] = (
674+
dataclasses.field(
675+
default_factory=lambda: (
676+
{
677+
"hidden_dim": [16, 32],
678+
},
679+
{
680+
"hidden_dim": [16, 32],
681+
},
682+
)
683+
)
684+
)
685+
686+
assert tyro.cli(Args, args=[]) == Args()
687+
assert tyro.cli(
688+
Args, args=["--proposal-net-args-list.0.hidden-dim", "32"]
689+
).proposal_net_args_list == (
690+
[
691+
{
692+
"hidden_dim": 32,
693+
},
694+
{
695+
"hidden_dim": 16,
696+
},
697+
]
698+
)
699+
assert tyro.cli(
700+
Args, args=["--proposal-net-args-list2.1.hidden-dim", "32", "64"]
701+
).proposal_net_args_list2 == (
702+
{
703+
"hidden_dim": [16, 32],
704+
},
705+
{
706+
"hidden_dim": [32, 64],
707+
},
708+
)

tests/test_nested.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ class ModelSettings:
282282
tyro.conf.OmitSubcommandPrefixes[
283283
tyro.conf.ConsolidateSubcommandArgs[ModelSettings]
284284
],
285-
args=("output-head-settings None" " --number-of-outputs 5".split(" ")),
285+
args=("output-head-settings None --number-of-outputs 5".split(" ")),
286286
) == ModelSettings(OutputHeadSettings(5), None)
287287

288288
assert tyro.cli(

tests/test_py311_generated/test_collections_generated.py

+53
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,56 @@ def test_list_narrowing_direct() -> None:
652652
def test_tuple_direct() -> None:
653653
assert tyro.cli(Tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore
654654
assert tyro.cli(Tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore
655+
656+
657+
def test_nested_dict_in_list() -> None:
658+
"""https://github.com/nerfstudio-project/nerfstudio/pull/3567"""
659+
660+
@dataclasses.dataclass
661+
class Args:
662+
proposal_net_args_list: List[Dict] = dataclasses.field(
663+
default_factory=lambda: [
664+
{
665+
"hidden_dim": 16,
666+
},
667+
{
668+
"hidden_dim": 16,
669+
},
670+
]
671+
)
672+
proposal_net_args_list2: Tuple[Dict[str, List], Dict[str, List]] = (
673+
dataclasses.field(
674+
default_factory=lambda: (
675+
{
676+
"hidden_dim": [16, 32],
677+
},
678+
{
679+
"hidden_dim": [16, 32],
680+
},
681+
)
682+
)
683+
)
684+
685+
assert tyro.cli(Args, args=[]) == Args()
686+
assert tyro.cli(
687+
Args, args=["--proposal-net-args-list.0.hidden-dim", "32"]
688+
).proposal_net_args_list == (
689+
[
690+
{
691+
"hidden_dim": 32,
692+
},
693+
{
694+
"hidden_dim": 16,
695+
},
696+
]
697+
)
698+
assert tyro.cli(
699+
Args, args=["--proposal-net-args-list2.1.hidden-dim", "32", "64"]
700+
).proposal_net_args_list2 == (
701+
{
702+
"hidden_dim": [16, 32],
703+
},
704+
{
705+
"hidden_dim": [32, 64],
706+
},
707+
)

tests/test_py311_generated/test_nested_generated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class ModelSettings:
292292
tyro.conf.OmitSubcommandPrefixes[
293293
tyro.conf.ConsolidateSubcommandArgs[ModelSettings]
294294
],
295-
args=("output-head-settings None" " --number-of-outputs 5".split(" ")),
295+
args=("output-head-settings None --number-of-outputs 5".split(" ")),
296296
) == ModelSettings(OutputHeadSettings(5), None)
297297

298298
assert tyro.cli(

0 commit comments

Comments
 (0)