Skip to content

Commit

Permalink
Consistent behavior for implicit and explicit selection of optional s…
Browse files Browse the repository at this point in the history
…ubcommands (#226)

* Consistent behavior for implicit and explicit selection of default subcommands, docs

* One more test

* Generate tests

* Bump version

* Housekeeping

* ruff

* spelling
  • Loading branch information
brentyi authored Jan 7, 2025
1 parent 31b8930 commit e7156a0
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 69 deletions.
4 changes: 0 additions & 4 deletions examples/06_custom_constructors/04_struct_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ def _(
if isinstance(type_info.default, Bounds):
# If the default value is a `Bounds` instance, we don't need to generate a constructor.
default = (type_info.default.bounds[0], type_info.default.bounds[1])
is_default_overridden = True
else:
# Otherwise, the default value is missing. We'll mark the child defaults as missing as well.
assert type_info.default in (
tyro.constructors.MISSING,
tyro.constructors.MISSING_NONPROP,
)
default = (tyro.MISSING, tyro.MISSING)
is_default_overridden = False

# If the rule applies, we return the constructor spec.
return tyro.constructors.StructConstructorSpec(
Expand All @@ -62,14 +60,12 @@ def _(
name="lower",
type=int,
default=default[0],
is_default_overridden=is_default_overridden,
helptext="Lower bound." "",
),
tyro.constructors.StructFieldSpec(
name="upper",
type=int,
default=default[1],
is_default_overridden=is_default_overridden,
helptext="Upper bound." "",
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING

__version__ = "0.9.6"
__version__ = "0.9.7"


from . import conf as conf
Expand Down
9 changes: 0 additions & 9 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ class FieldDefinition:
"""Full type, including runtime annotations."""
type_stripped: TypeForm[Any] | Callable
default: Any
# We need to record whether defaults are from default instances to
# determine if they should override the default in
# tyro.conf.subcommand(default=...).
is_default_from_default_instance: bool
helptext: Optional[str]
markers: Set[Any]
custom_constructor: bool
Expand Down Expand Up @@ -66,7 +62,6 @@ def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition:
name=field_spec.name,
typ=field_spec.type,
default=field_spec.default,
is_default_from_default_instance=field_spec.is_default_overridden,
helptext=field_spec.helptext,
call_argname_override=field_spec._call_argname,
)
Expand All @@ -76,7 +71,6 @@ def make(
name: str,
typ: Union[TypeForm[Any], Callable],
default: Any,
is_default_from_default_instance: bool,
helptext: Optional[str],
call_argname_override: Optional[Any] = None,
):
Expand Down Expand Up @@ -133,7 +127,6 @@ def make(
type=typ,
type_stripped=type_stripped,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
markers=set(markers),
custom_constructor=argconf.constructor_factory is not None,
Expand Down Expand Up @@ -243,7 +236,6 @@ def field_list_from_type_or_callable(
name="value",
typ=f,
default=default_instance,
is_default_from_default_instance=True,
helptext="",
)
],
Expand Down Expand Up @@ -371,7 +363,6 @@ def _field_list_from_function(
if default_instance in MISSING_AND_MISSING_NONPROP
else Annotated[(typ, _markers._OPTIONAL_GROUP)], # type: ignore
default=default if default is not param.empty else MISSING_NONPROP,
is_default_from_default_instance=False,
helptext=helptext,
)
)
Expand Down
3 changes: 1 addition & 2 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,7 @@ def from_field(

# If names match, borrow subcommand default from field default.
if default_name == subcommand_name and (
field.is_default_from_default_instance
or subcommand_config.default in _singleton.MISSING_AND_MISSING_NONPROP
field.default not in _singleton.MISSING_AND_MISSING_NONPROP
):
subcommand_config = dataclasses.replace(
subcommand_config, default=field.default
Expand Down
26 changes: 22 additions & 4 deletions src/tyro/conf/_confstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def subcommand(
.. code-block:: python
tyro.cli(
Union[NestedTypeA, NestedTypeB]
Union[StructTypeA, StructTypeB]
)
This will create two subcommands: `nested-type-a` and `nested-type-b`.
Expand All @@ -75,18 +75,36 @@ def subcommand(
tyro.cli(
Union[
Annotated[
NestedTypeA, subcommand("a", ...)
StructTypeA, subcommand("a", ...)
],
Annotated[
NestedTypeB, subcommand("b", ...)
StructTypeB, subcommand("b", ...)
],
]
)
If we have a default value both in the annotation and attached to the field
itself (eg, RHS of `=` within function or dataclass signature), the field
default will take precedence.
.. code-block:: python
# For the first subcommand, StructType(1) will be used as the default.
# The second subcommand, whose type is inconsistent with the field
# default, will be unaffected.
x: Union[
Annotated[
StructTypeA, subcommand(default=StructTypeA(0)
],
Annotated[
StructTypeB, subcommand(default=StructTypeB(0)
],
] = StructTypeA(1)
Arguments:
name: The name of the subcommand in the CLI.
default: A default value for the subcommand, for struct-like types. (eg
dataclasses)
dataclasses).
description: Description of this option to use in the helptext. Defaults to
docstring.
prefix_name: Whether to prefix the name of the subcommand based on where it
Expand Down
60 changes: 19 additions & 41 deletions src/tyro/constructors/_struct_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ class StructFieldSpec:
"""The type of the field. Can be either a primitive or a nested struct type."""
default: Any
"""The default value of the field."""
is_default_overridden: bool = False
"""Whether the default value was overridden by the default instance. Should
be set to False if the default value was assigned by the field itself."""
helptext: str | None = None
"""Helpjext for the field."""
# TODO: it's theoretically possible to override the argname with `None`.
_call_argname: Any = None
"""Private: the name of the argument to pass to the callable. This is used
for dictionary types."""
is_default_overridden: None = None
"""Deprecated. No longer used."""


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -174,9 +173,7 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
if is_flax_module and dc_field.name in ("name", "parent"):
continue

default, is_default_from_default_instance = _get_dataclass_field_default(
dc_field, info.default
)
default = _get_dataclass_field_default(dc_field, info.default)

# Try to get helptext from field metadata. This is also intended to be
# compatible with HuggingFace-style config objects.
Expand All @@ -194,7 +191,6 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=dc_field.name,
type=dc_field.type,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand Down Expand Up @@ -222,10 +218,8 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
cls, include_extras=True
).items():
typ_origin = get_origin(typ)
is_default_from_default_instance = False
if valid_default_instance and name in cast(dict, info.default):
default = cast(dict, info.default)[name]
is_default_from_default_instance = True
elif typ_origin is Required and total is False:
# Support total=False.
default = MISSING
Expand Down Expand Up @@ -261,7 +255,6 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=typ,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(cls, name),
)
)
Expand Down Expand Up @@ -296,11 +289,9 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
# Default handling.
name = attr_field.name
default = attr_field.default
is_default_from_default_instance = False
if info.default not in MISSING_AND_MISSING_NONPROP:
assert hasattr(info.default, name)
default = getattr(info.default, name)
is_default_from_default_instance = True
elif default is attr.NOTHING:
default = MISSING_NONPROP
elif isinstance(default, attr.Factory): # type: ignore
Expand All @@ -312,7 +303,6 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=attr_field.type,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(info.type, name),
)
)
Expand Down Expand Up @@ -345,7 +335,6 @@ def dict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(k) if not isinstance(k, enum.Enum) else k.name,
type=type(v),
default=v,
is_default_overridden=True,
helptext=None,
_call_argname=k,
)
Expand All @@ -368,13 +357,11 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
info.type, include_extras=True
).items():
default = field_defaults.get(name, MISSING_NONPROP)
is_default_from_default_instance = False

if info.default not in MISSING_AND_MISSING_NONPROP and hasattr(
info.default, name
):
default = getattr(info.default, name)
is_default_from_default_instance = True
elif info.default is MISSING:
default = MISSING

Expand All @@ -383,7 +370,6 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=typ,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(info.type, name),
)
)
Expand Down Expand Up @@ -417,7 +403,6 @@ def sequence_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(i),
type=cast(type, contained_type),
default=default_i,
is_default_overridden=True,
helptext="",
)
)
Expand Down Expand Up @@ -461,7 +446,6 @@ def tuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(i),
type=child,
default=default_i,
is_default_overridden=True,
helptext="",
)
)
Expand Down Expand Up @@ -537,17 +521,14 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
info.type, pd1_field.name
)

default, is_default_from_default_instance = (
_get_pydantic_v1_field_default(
pd1_field.name, pd1_field, info.default
)
default = _get_pydantic_v1_field_default(
pd1_field.name, pd1_field, info.default
)
field_list.append(
StructFieldSpec(
name=pd1_field.name,
type=hints[pd1_field.name],
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand All @@ -558,9 +539,7 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
if helptext is None:
helptext = _docstrings.get_field_docstring(info.type, name)

default, is_default_from_default_instance = (
_get_pydantic_v2_field_default(name, pd2_field, info.default)
)
default = _get_pydantic_v2_field_default(name, pd2_field, info.default)
field_list.append(
StructFieldSpec(
name=name,
Expand All @@ -572,7 +551,6 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
else pd2_field.annotation
),
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand All @@ -597,12 +575,12 @@ def _ensure_dataclass_instance_used_as_default_is_frozen(

def _get_dataclass_field_default(
field: dataclasses.Field, parent_default_instance: Any
) -> tuple[Any, bool]:
) -> Any:
"""Helper for getting the default instance for a dataclass field."""
# If the dataclass's parent is explicitly marked MISSING, mark this field as missing
# as well.
if parent_default_instance is MISSING:
return MISSING, False
return MISSING

# Try grabbing default from parent instance.
if (
Expand All @@ -611,7 +589,7 @@ def _get_dataclass_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, field.name):
return getattr(parent_default_instance, field.name), True
return getattr(parent_default_instance, field.name)

# Try grabbing default from dataclass field.
if (
Expand All @@ -623,7 +601,7 @@ def _get_dataclass_field_default(
# _types_, not just instances.
if type(default) is not type and dataclasses.is_dataclass(default):
_ensure_dataclass_instance_used_as_default_is_frozen(field, default)
return default, False
return default

# Populate default from `dataclasses.field(default_factory=...)`.
if field.default_factory is not dataclasses.MISSING and not (
Expand All @@ -637,10 +615,10 @@ def _get_dataclass_field_default(
# before this method is called.
dataclasses.is_dataclass(field.type) and field.default_factory is field.type
):
return field.default_factory(), False
return field.default_factory()

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP


if TYPE_CHECKING:
Expand All @@ -662,20 +640,20 @@ def _get_pydantic_v1_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name), True
return getattr(parent_default_instance, name)

if not field.required:
return field.get_default(), False
return field.get_default()

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP


def _get_pydantic_v2_field_default(
name: str,
field: pydantic.fields.FieldInfo,
parent_default_instance: Any,
) -> tuple[Any, bool]:
) -> Any:
"""Helper for getting the default instance for a Pydantic field."""

# Try grabbing default from parent instance.
Expand All @@ -685,10 +663,10 @@ def _get_pydantic_v2_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name), True
return getattr(parent_default_instance, name)

if not field.is_required():
return field.get_default(call_default_factory=True), False
return field.get_default(call_default_factory=True)

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP
Loading

0 comments on commit e7156a0

Please sign in to comment.