Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent behavior for implicit and explicit selection of optional subcommands #226

Merged
merged 7 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/06_custom_constructors/04_struct_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ def _(
if isinstance(type_info.default, Bounds):
# If the default value is a `Bounds` instance, we don't need to generate a constructor.
default = (type_info.default.bounds[0], type_info.default.bounds[1])
is_default_overridden = True
else:
# Otherwise, the default value is missing. We'll mark the child defaults as missing as well.
assert type_info.default in (
tyro.constructors.MISSING,
tyro.constructors.MISSING_NONPROP,
)
default = (tyro.MISSING, tyro.MISSING)
is_default_overridden = False

# If the rule applies, we return the constructor spec.
return tyro.constructors.StructConstructorSpec(
Expand All @@ -62,14 +60,12 @@ def _(
name="lower",
type=int,
default=default[0],
is_default_overridden=is_default_overridden,
helptext="Lower bound." "",
),
tyro.constructors.StructFieldSpec(
name="upper",
type=int,
default=default[1],
is_default_overridden=is_default_overridden,
helptext="Upper bound." "",
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING

__version__ = "0.9.6"
__version__ = "0.9.7"


from . import conf as conf
Expand Down
9 changes: 0 additions & 9 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ class FieldDefinition:
"""Full type, including runtime annotations."""
type_stripped: TypeForm[Any] | Callable
default: Any
# We need to record whether defaults are from default instances to
# determine if they should override the default in
# tyro.conf.subcommand(default=...).
is_default_from_default_instance: bool
helptext: Optional[str]
markers: Set[Any]
custom_constructor: bool
Expand Down Expand Up @@ -66,7 +62,6 @@ def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition:
name=field_spec.name,
typ=field_spec.type,
default=field_spec.default,
is_default_from_default_instance=field_spec.is_default_overridden,
helptext=field_spec.helptext,
call_argname_override=field_spec._call_argname,
)
Expand All @@ -76,7 +71,6 @@ def make(
name: str,
typ: Union[TypeForm[Any], Callable],
default: Any,
is_default_from_default_instance: bool,
helptext: Optional[str],
call_argname_override: Optional[Any] = None,
):
Expand Down Expand Up @@ -133,7 +127,6 @@ def make(
type=typ,
type_stripped=type_stripped,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
markers=set(markers),
custom_constructor=argconf.constructor_factory is not None,
Expand Down Expand Up @@ -243,7 +236,6 @@ def field_list_from_type_or_callable(
name="value",
typ=f,
default=default_instance,
is_default_from_default_instance=True,
helptext="",
)
],
Expand Down Expand Up @@ -371,7 +363,6 @@ def _field_list_from_function(
if default_instance in MISSING_AND_MISSING_NONPROP
else Annotated[(typ, _markers._OPTIONAL_GROUP)], # type: ignore
default=default if default is not param.empty else MISSING_NONPROP,
is_default_from_default_instance=False,
helptext=helptext,
)
)
Expand Down
3 changes: 1 addition & 2 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,7 @@ def from_field(

# If names match, borrow subcommand default from field default.
if default_name == subcommand_name and (
field.is_default_from_default_instance
or subcommand_config.default in _singleton.MISSING_AND_MISSING_NONPROP
field.default not in _singleton.MISSING_AND_MISSING_NONPROP
):
subcommand_config = dataclasses.replace(
subcommand_config, default=field.default
Expand Down
26 changes: 22 additions & 4 deletions src/tyro/conf/_confstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def subcommand(
.. code-block:: python

tyro.cli(
Union[NestedTypeA, NestedTypeB]
Union[StructTypeA, StructTypeB]
)

This will create two subcommands: `nested-type-a` and `nested-type-b`.
Expand All @@ -75,18 +75,36 @@ def subcommand(
tyro.cli(
Union[
Annotated[
NestedTypeA, subcommand("a", ...)
StructTypeA, subcommand("a", ...)
],
Annotated[
NestedTypeB, subcommand("b", ...)
StructTypeB, subcommand("b", ...)
],
]
)

If we have a default value both in the annotation and attached to the field
itself (eg, RHS of `=` within function or dataclass signature), the field
default will take precedence.

.. code-block:: python

# For the first subcommand, StructType(1) will be used as the default.
# The second subcommand, whose type is inconsistent with the field
# default, will be unaffected.
x: Union[
Annotated[
StructTypeA, subcommand(default=StructTypeA(0)
],
Annotated[
StructTypeB, subcommand(default=StructTypeB(0)
],
] = StructTypeA(1)

Arguments:
name: The name of the subcommand in the CLI.
default: A default value for the subcommand, for struct-like types. (eg
dataclasses)
dataclasses).
description: Description of this option to use in the helptext. Defaults to
docstring.
prefix_name: Whether to prefix the name of the subcommand based on where it
Expand Down
60 changes: 19 additions & 41 deletions src/tyro/constructors/_struct_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ class StructFieldSpec:
"""The type of the field. Can be either a primitive or a nested struct type."""
default: Any
"""The default value of the field."""
is_default_overridden: bool = False
"""Whether the default value was overridden by the default instance. Should
be set to False if the default value was assigned by the field itself."""
helptext: str | None = None
"""Helpjext for the field."""
# TODO: it's theoretically possible to override the argname with `None`.
_call_argname: Any = None
"""Private: the name of the argument to pass to the callable. This is used
for dictionary types."""
is_default_overridden: None = None
"""Deprecated. No longer used."""


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -174,9 +173,7 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
if is_flax_module and dc_field.name in ("name", "parent"):
continue

default, is_default_from_default_instance = _get_dataclass_field_default(
dc_field, info.default
)
default = _get_dataclass_field_default(dc_field, info.default)

# Try to get helptext from field metadata. This is also intended to be
# compatible with HuggingFace-style config objects.
Expand All @@ -194,7 +191,6 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=dc_field.name,
type=dc_field.type,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand Down Expand Up @@ -222,10 +218,8 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
cls, include_extras=True
).items():
typ_origin = get_origin(typ)
is_default_from_default_instance = False
if valid_default_instance and name in cast(dict, info.default):
default = cast(dict, info.default)[name]
is_default_from_default_instance = True
elif typ_origin is Required and total is False:
# Support total=False.
default = MISSING
Expand Down Expand Up @@ -261,7 +255,6 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=typ,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(cls, name),
)
)
Expand Down Expand Up @@ -296,11 +289,9 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
# Default handling.
name = attr_field.name
default = attr_field.default
is_default_from_default_instance = False
if info.default not in MISSING_AND_MISSING_NONPROP:
assert hasattr(info.default, name)
default = getattr(info.default, name)
is_default_from_default_instance = True
elif default is attr.NOTHING:
default = MISSING_NONPROP
elif isinstance(default, attr.Factory): # type: ignore
Expand All @@ -312,7 +303,6 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=attr_field.type,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(info.type, name),
)
)
Expand Down Expand Up @@ -345,7 +335,6 @@ def dict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(k) if not isinstance(k, enum.Enum) else k.name,
type=type(v),
default=v,
is_default_overridden=True,
helptext=None,
_call_argname=k,
)
Expand All @@ -368,13 +357,11 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
info.type, include_extras=True
).items():
default = field_defaults.get(name, MISSING_NONPROP)
is_default_from_default_instance = False

if info.default not in MISSING_AND_MISSING_NONPROP and hasattr(
info.default, name
):
default = getattr(info.default, name)
is_default_from_default_instance = True
elif info.default is MISSING:
default = MISSING

Expand All @@ -383,7 +370,6 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=name,
type=typ,
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(info.type, name),
)
)
Expand Down Expand Up @@ -417,7 +403,6 @@ def sequence_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(i),
type=cast(type, contained_type),
default=default_i,
is_default_overridden=True,
helptext="",
)
)
Expand Down Expand Up @@ -461,7 +446,6 @@ def tuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
name=str(i),
type=child,
default=default_i,
is_default_overridden=True,
helptext="",
)
)
Expand Down Expand Up @@ -537,17 +521,14 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
info.type, pd1_field.name
)

default, is_default_from_default_instance = (
_get_pydantic_v1_field_default(
pd1_field.name, pd1_field, info.default
)
default = _get_pydantic_v1_field_default(
pd1_field.name, pd1_field, info.default
)
field_list.append(
StructFieldSpec(
name=pd1_field.name,
type=hints[pd1_field.name],
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand All @@ -558,9 +539,7 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
if helptext is None:
helptext = _docstrings.get_field_docstring(info.type, name)

default, is_default_from_default_instance = (
_get_pydantic_v2_field_default(name, pd2_field, info.default)
)
default = _get_pydantic_v2_field_default(name, pd2_field, info.default)
field_list.append(
StructFieldSpec(
name=name,
Expand All @@ -572,7 +551,6 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
else pd2_field.annotation
),
default=default,
is_default_overridden=is_default_from_default_instance,
helptext=helptext,
)
)
Expand All @@ -597,12 +575,12 @@ def _ensure_dataclass_instance_used_as_default_is_frozen(

def _get_dataclass_field_default(
field: dataclasses.Field, parent_default_instance: Any
) -> tuple[Any, bool]:
) -> Any:
"""Helper for getting the default instance for a dataclass field."""
# If the dataclass's parent is explicitly marked MISSING, mark this field as missing
# as well.
if parent_default_instance is MISSING:
return MISSING, False
return MISSING

# Try grabbing default from parent instance.
if (
Expand All @@ -611,7 +589,7 @@ def _get_dataclass_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, field.name):
return getattr(parent_default_instance, field.name), True
return getattr(parent_default_instance, field.name)

# Try grabbing default from dataclass field.
if (
Expand All @@ -623,7 +601,7 @@ def _get_dataclass_field_default(
# _types_, not just instances.
if type(default) is not type and dataclasses.is_dataclass(default):
_ensure_dataclass_instance_used_as_default_is_frozen(field, default)
return default, False
return default

# Populate default from `dataclasses.field(default_factory=...)`.
if field.default_factory is not dataclasses.MISSING and not (
Expand All @@ -637,10 +615,10 @@ def _get_dataclass_field_default(
# before this method is called.
dataclasses.is_dataclass(field.type) and field.default_factory is field.type
):
return field.default_factory(), False
return field.default_factory()

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP


if TYPE_CHECKING:
Expand All @@ -662,20 +640,20 @@ def _get_pydantic_v1_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name), True
return getattr(parent_default_instance, name)

if not field.required:
return field.get_default(), False
return field.get_default()

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP


def _get_pydantic_v2_field_default(
name: str,
field: pydantic.fields.FieldInfo,
parent_default_instance: Any,
) -> tuple[Any, bool]:
) -> Any:
"""Helper for getting the default instance for a Pydantic field."""

# Try grabbing default from parent instance.
Expand All @@ -685,10 +663,10 @@ def _get_pydantic_v2_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name), True
return getattr(parent_default_instance, name)

if not field.is_required():
return field.get_default(call_default_factory=True), False
return field.get_default(call_default_factory=True)

# Otherwise, no default.
return MISSING_NONPROP, False
return MISSING_NONPROP
Loading
Loading