From 73e24292d80364ee82e16da5ec9810c727e0db5f Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 02:43:15 -0800 Subject: [PATCH 1/7] Consistent behavior for implicit and explicit selection of default subcommands, docs --- src/tyro/_parsers.py | 5 ++-- src/tyro/conf/_confstruct.py | 26 ++++++++++++++++--- .../test_base_configs_nested_exclude_py313.py | 5 ++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 08549be0..0b5a33d6 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -553,8 +553,9 @@ def from_field( # If names match, borrow subcommand default from field default. if default_name == subcommand_name and ( - field.is_default_from_default_instance - or subcommand_config.default in _singleton.MISSING_AND_MISSING_NONPROP + # field.is_default_from_default_instance + # or subcommand_config.default in _singleton.MISSING_AND_MISSING_NONPROP + field.default not in _singleton.MISSING_AND_MISSING_NONPROP ): subcommand_config = dataclasses.replace( subcommand_config, default=field.default diff --git a/src/tyro/conf/_confstruct.py b/src/tyro/conf/_confstruct.py index d7043ae0..8f5395f8 100644 --- a/src/tyro/conf/_confstruct.py +++ b/src/tyro/conf/_confstruct.py @@ -61,7 +61,7 @@ def subcommand( .. code-block:: python tyro.cli( - Union[NestedTypeA, NestedTypeB] + Union[StructTypeA, StructTypeB] ) This will create two subcommands: `nested-type-a` and `nested-type-b`. @@ -75,18 +75,36 @@ def subcommand( tyro.cli( Union[ Annotated[ - NestedTypeA, subcommand("a", ...) + StructTypeA, subcommand("a", ...) ], Annotated[ - NestedTypeB, subcommand("b", ...) + StructTypeB, subcommand("b", ...) ], ] ) + If we have a default value both in the annotation and attached to the field + itself (eg, RHS of `=` within function or dataclass signature), the field + default will take precedence. + + .. code-block:: python + + # For the first subcommand, StructType(1) will be used as the default. + # The second subcommand, whose type is inconsistent with the field + # default, will be unaffected. + x: Union[ + Annotated[ + StructTypeA, subcommand(default=StructTypeA(0) + ], + Annotated[ + StructTypeB, subcommand(default=StructTypeB(0) + ], + ] = StructTypeA(1) + Arguments: name: The name of the subcommand in the CLI. default: A default value for the subcommand, for struct-like types. (eg - dataclasses) + dataclasses). description: Description of this option to use in the helptext. Defaults to docstring. prefix_name: Whether to prefix the name of the subcommand based on where it diff --git a/tests/test_base_configs_nested_exclude_py313.py b/tests/test_base_configs_nested_exclude_py313.py index fb738e05..959a81d4 100644 --- a/tests/test_base_configs_nested_exclude_py313.py +++ b/tests/test_base_configs_nested_exclude_py313.py @@ -117,7 +117,8 @@ class BaseConfig: experiment_config: AnnotatedExperimentParserUnion # The experiment configuration. - data_config: AnnotatedDataParserUnion = DataConfig() + # The default should get matched to small-data. + data_config: AnnotatedDataParserUnion = DataConfig(test=0) def test_base_configs_nested() -> None: @@ -158,7 +159,7 @@ def main(cfg: BaseConfig) -> BaseConfig: seed=0, activation=nn.ReLU, ), - DataConfig(2221), + DataConfig(0), ) assert tyro.cli( main, From 412ea21fff2a3f4b38ded5ca3f0e5a185dd0cf48 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 02:48:36 -0800 Subject: [PATCH 2/7] One more test --- tests/test_conf.py | 55 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/tests/test_conf.py b/tests/test_conf.py index 524d8503..d1f701c8 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -8,10 +8,10 @@ from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union import pytest -from helptext_utils import get_helptext_with_checks +import tyro from typing_extensions import Annotated, TypedDict -import tyro +from helptext_utils import get_helptext_with_checks def test_suppress_subcommand() -> None: @@ -1706,3 +1706,54 @@ class Trunk: assert tyro.cli( Trunk, default=Trunk(Branch2(x=tyro.MISSING)), args=["branch:branch1"] ) == Trunk(Branch1()) + + +def test_default_subcommand_consistency() -> None: + """https://github.com/brentyi/tyro/issues/221""" + + @dataclasses.dataclass(frozen=True) + class OptimizerConfig: + lr: float = 1e-1 + + @dataclasses.dataclass(frozen=True) + class AdamConfig(OptimizerConfig): + adam_foo: float = 1.0 + + @dataclasses.dataclass(frozen=True) + class SGDConfig(OptimizerConfig): + sgd_foo: float = 1.0 + + def _constructor() -> Any: + cfgs = [ + Annotated[SGDConfig, tyro.conf.subcommand(name="sgd", default=SGDConfig())], + Annotated[ + AdamConfig, tyro.conf.subcommand(name="adam", default=AdamConfig()) + ], + ] + return Union.__getitem__(tuple(cfgs)) # type: ignore + + CLIOptimizer = Annotated[ + OptimizerConfig, + tyro.conf.arg(constructor_factory=_constructor), + ] + + @dataclasses.dataclass + class Config: + optimizer: CLIOptimizer = AdamConfig(adam_foo=0.5) # type: ignore + foo: int = 1 + bar: str = "abc" + + assert tyro.cli(Config, args=[]) == Config() + assert ( + tyro.cli( + Config, + config=(tyro.conf.ConsolidateSubcommandArgs,), + args=["optimizer:adam"], + ) + == Config() + ) + assert ( + tyro.cli(Config, config=(tyro.conf.ConsolidateSubcommandArgs,), args=[]) + == Config() + ) + assert tyro.cli(Config, args=["optimizer:adam"]) == Config() From b75547f4fb4a8ba61dd893e56a96fbd7f052551f Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 02:48:49 -0800 Subject: [PATCH 3/7] Generate tests --- ..._configs_nested_exclude_py313_generated.py | 5 +- .../test_conf_generated.py | 62 +++++++++++++++++-- .../test_nested_generated.py | 47 ++++++++++++++ 3 files changed, 108 insertions(+), 6 deletions(-) diff --git a/tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py b/tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py index 6b1ff72c..13b5f683 100644 --- a/tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py +++ b/tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py @@ -116,7 +116,8 @@ class BaseConfig: experiment_config: AnnotatedExperimentParserUnion # The experiment configuration. - data_config: AnnotatedDataParserUnion = DataConfig() + # The default should get matched to small-data. + data_config: AnnotatedDataParserUnion = DataConfig(test=0) def test_base_configs_nested() -> None: @@ -157,7 +158,7 @@ def main(cfg: BaseConfig) -> BaseConfig: seed=0, activation=nn.ReLU, ), - DataConfig(2221), + DataConfig(0), ) assert tyro.cli( main, diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 7014f52c..28b8c2de 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1625,7 +1625,7 @@ class AdamConfig(OptimizerConfig): class SGDConfig(OptimizerConfig): sgd_foo: float = 1.0 - def _constructor() -> type[OptimizerConfig]: + def _constructor() -> Type[OptimizerConfig]: cfgs = [ Annotated[AdamConfig, tyro.conf.subcommand(name="adam")], Annotated[SGDConfig, tyro.conf.subcommand(name="sgd")], @@ -1637,7 +1637,8 @@ def _constructor() -> type[OptimizerConfig]: class Config1: x: int optimizer: Annotated[ - AdamConfig | SGDConfig, tyro.conf.arg(constructor_factory=_constructor) + AdamConfig | SGDConfig, + tyro.conf.arg(constructor_factory=_constructor), ] = AdamConfig() with pytest.raises(SystemExit): @@ -1647,7 +1648,8 @@ class Config1: @dataclasses.dataclass class Config2: optimizer: Annotated[ - AdamConfig | SGDConfig, tyro.conf.arg(constructor_factory=_constructor) + AdamConfig | SGDConfig, + tyro.conf.arg(constructor_factory=_constructor), ] with pytest.raises(SystemExit): @@ -1658,7 +1660,8 @@ class Config2: class Config3: x: int = 3 optimizer: Annotated[ - AdamConfig | SGDConfig, tyro.conf.arg(constructor_factory=_constructor) + AdamConfig | SGDConfig, + tyro.conf.arg(constructor_factory=_constructor), ] = AdamConfig() assert ( @@ -1709,3 +1712,54 @@ class Trunk: assert tyro.cli( Trunk, default=Trunk(Branch2(x=tyro.MISSING)), args=["branch:branch1"] ) == Trunk(Branch1()) + + +def test_default_subcommand_consistency() -> None: + """https://github.com/brentyi/tyro/issues/221""" + + @dataclasses.dataclass(frozen=True) + class OptimizerConfig: + lr: float = 1e-1 + + @dataclasses.dataclass(frozen=True) + class AdamConfig(OptimizerConfig): + adam_foo: float = 1.0 + + @dataclasses.dataclass(frozen=True) + class SGDConfig(OptimizerConfig): + sgd_foo: float = 1.0 + + def _constructor() -> Any: + cfgs = [ + Annotated[SGDConfig, tyro.conf.subcommand(name="sgd", default=SGDConfig())], + Annotated[ + AdamConfig, tyro.conf.subcommand(name="adam", default=AdamConfig()) + ], + ] + return Union.__getitem__(tuple(cfgs)) # type: ignore + + CLIOptimizer = Annotated[ + OptimizerConfig, + tyro.conf.arg(constructor_factory=_constructor), + ] + + @dataclasses.dataclass + class Config: + optimizer: CLIOptimizer = AdamConfig(adam_foo=0.5) # type: ignore + foo: int = 1 + bar: str = "abc" + + assert tyro.cli(Config, args=[]) == Config() + assert ( + tyro.cli( + Config, + config=(tyro.conf.ConsolidateSubcommandArgs,), + args=["optimizer:adam"], + ) + == Config() + ) + assert ( + tyro.cli(Config, config=(tyro.conf.ConsolidateSubcommandArgs,), args=[]) + == Config() + ) + assert tyro.cli(Config, args=["optimizer:adam"]) == Config() diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index b4073661..918d79e3 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -9,6 +9,7 @@ Optional, Tuple, TypeVar, + Union, ) import pytest @@ -1342,3 +1343,49 @@ def main(config: Config = ("hello", 5)) -> Any: # type: ignore assert tyro.cli( main, args="config:config --config.name world --config.age 27".split(" ") ) == Config(name="world", age=27) + + +def test_subcommand_default_with_conf_annotation() -> None: + """Adapted from @mirceamironenco. + + https://github.com/brentyi/tyro/issues/221#issuecomment-2572850582 + """ + + @dataclasses.dataclass(frozen=True) + class OptimizerConfig: + lr: float = 1e-1 + + @dataclasses.dataclass(frozen=True) + class AdamConfig(OptimizerConfig): + adam_foo: float = 1.0 + + @dataclasses.dataclass(frozen=True) + class SGDConfig(OptimizerConfig): + sgd_foo: float = 1.0 + + def _constructor() -> Any: + cfgs = [ + Annotated[SGDConfig, tyro.conf.subcommand(name="sgd")], + Annotated[AdamConfig, tyro.conf.subcommand(name="adam")], + ] + return Union.__getitem__(tuple(cfgs)) # type: ignore + + @dataclasses.dataclass(frozen=True) + class Config1: + optimizer: Annotated[ + OptimizerConfig, tyro.conf.arg(constructor_factory=_constructor) + ] = AdamConfig() + foo: int = 1 + bar: str = "abc" + + assert "(default: optimizer:adam)" in get_helptext_with_checks(Config1) + + @dataclasses.dataclass(frozen=True) + class Config2: + optimizer: Annotated[ + OptimizerConfig, tyro.conf.arg(constructor_factory=_constructor) + ] = SGDConfig() + foo: int = 1 + bar: str = "abc" + + assert "(default: optimizer:sgd)" in get_helptext_with_checks(Config2) From 4fc695bbc0a01d2a1e0ae37ff5eaf00ddc36c44c Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 02:49:00 -0800 Subject: [PATCH 4/7] Bump version --- src/tyro/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tyro/__init__.py b/src/tyro/__init__.py index 9c5b2941..ab80708f 100644 --- a/src/tyro/__init__.py +++ b/src/tyro/__init__.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -__version__ = "0.9.6" +__version__ = "0.9.7" from . import conf as conf From d0d92fe2a027750c3b1f9d1a16a1a59a8ac8e82b Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 02:59:55 -0800 Subject: [PATCH 5/7] Housekeeping --- .../04_struct_registry.py | 4 -- src/tyro/_fields.py | 9 --- src/tyro/_parsers.py | 2 - src/tyro/constructors/_struct_spec.py | 60 ++++++------------- 4 files changed, 19 insertions(+), 56 deletions(-) diff --git a/examples/06_custom_constructors/04_struct_registry.py b/examples/06_custom_constructors/04_struct_registry.py index b5b13f5c..c9acb3f4 100644 --- a/examples/06_custom_constructors/04_struct_registry.py +++ b/examples/06_custom_constructors/04_struct_registry.py @@ -43,7 +43,6 @@ def _( if isinstance(type_info.default, Bounds): # If the default value is a `Bounds` instance, we don't need to generate a constructor. default = (type_info.default.bounds[0], type_info.default.bounds[1]) - is_default_overridden = True else: # Otherwise, the default value is missing. We'll mark the child defaults as missing as well. assert type_info.default in ( @@ -51,7 +50,6 @@ def _( tyro.constructors.MISSING_NONPROP, ) default = (tyro.MISSING, tyro.MISSING) - is_default_overridden = False # If the rule applies, we return the constructor spec. return tyro.constructors.StructConstructorSpec( @@ -62,14 +60,12 @@ def _( name="lower", type=int, default=default[0], - is_default_overridden=is_default_overridden, helptext="Lower bound." "", ), tyro.constructors.StructFieldSpec( name="upper", type=int, default=default[1], - is_default_overridden=is_default_overridden, helptext="Upper bound." "", ), ), diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index aebc46e0..38005ef9 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -35,10 +35,6 @@ class FieldDefinition: """Full type, including runtime annotations.""" type_stripped: TypeForm[Any] | Callable default: Any - # We need to record whether defaults are from default instances to - # determine if they should override the default in - # tyro.conf.subcommand(default=...). - is_default_from_default_instance: bool helptext: Optional[str] markers: Set[Any] custom_constructor: bool @@ -66,7 +62,6 @@ def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition: name=field_spec.name, typ=field_spec.type, default=field_spec.default, - is_default_from_default_instance=field_spec.is_default_overridden, helptext=field_spec.helptext, call_argname_override=field_spec._call_argname, ) @@ -76,7 +71,6 @@ def make( name: str, typ: Union[TypeForm[Any], Callable], default: Any, - is_default_from_default_instance: bool, helptext: Optional[str], call_argname_override: Optional[Any] = None, ): @@ -133,7 +127,6 @@ def make( type=typ, type_stripped=type_stripped, default=default, - is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, markers=set(markers), custom_constructor=argconf.constructor_factory is not None, @@ -243,7 +236,6 @@ def field_list_from_type_or_callable( name="value", typ=f, default=default_instance, - is_default_from_default_instance=True, helptext="", ) ], @@ -371,7 +363,6 @@ def _field_list_from_function( if default_instance in MISSING_AND_MISSING_NONPROP else Annotated[(typ, _markers._OPTIONAL_GROUP)], # type: ignore default=default if default is not param.empty else MISSING_NONPROP, - is_default_from_default_instance=False, helptext=helptext, ) ) diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 0b5a33d6..d03edc3f 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -553,8 +553,6 @@ def from_field( # If names match, borrow subcommand default from field default. if default_name == subcommand_name and ( - # field.is_default_from_default_instance - # or subcommand_config.default in _singleton.MISSING_AND_MISSING_NONPROP field.default not in _singleton.MISSING_AND_MISSING_NONPROP ): subcommand_config = dataclasses.replace( diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py index 411674d8..503df115 100644 --- a/src/tyro/constructors/_struct_spec.py +++ b/src/tyro/constructors/_struct_spec.py @@ -56,15 +56,14 @@ class StructFieldSpec: """The type of the field. Can be either a primitive or a nested struct type.""" default: Any """The default value of the field.""" - is_default_overridden: bool = False - """Whether the default value was overridden by the default instance. Should - be set to False if the default value was assigned by the field itself.""" helptext: str | None = None """Helpjext for the field.""" # TODO: it's theoretically possible to override the argname with `None`. _call_argname: Any = None """Private: the name of the argument to pass to the callable. This is used for dictionary types.""" + is_default_overriden: None = None + """*Deprecated.*""" @dataclasses.dataclass(frozen=True) @@ -174,9 +173,7 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None: if is_flax_module and dc_field.name in ("name", "parent"): continue - default, is_default_from_default_instance = _get_dataclass_field_default( - dc_field, info.default - ) + default = _get_dataclass_field_default(dc_field, info.default) # Try to get helptext from field metadata. This is also intended to be # compatible with HuggingFace-style config objects. @@ -194,7 +191,6 @@ def dataclass_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=dc_field.name, type=dc_field.type, default=default, - is_default_overridden=is_default_from_default_instance, helptext=helptext, ) ) @@ -222,10 +218,8 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None: cls, include_extras=True ).items(): typ_origin = get_origin(typ) - is_default_from_default_instance = False if valid_default_instance and name in cast(dict, info.default): default = cast(dict, info.default)[name] - is_default_from_default_instance = True elif typ_origin is Required and total is False: # Support total=False. default = MISSING @@ -261,7 +255,6 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=name, type=typ, default=default, - is_default_overridden=is_default_from_default_instance, helptext=_docstrings.get_field_docstring(cls, name), ) ) @@ -296,11 +289,9 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None: # Default handling. name = attr_field.name default = attr_field.default - is_default_from_default_instance = False if info.default not in MISSING_AND_MISSING_NONPROP: assert hasattr(info.default, name) default = getattr(info.default, name) - is_default_from_default_instance = True elif default is attr.NOTHING: default = MISSING_NONPROP elif isinstance(default, attr.Factory): # type: ignore @@ -312,7 +303,6 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=name, type=attr_field.type, default=default, - is_default_overridden=is_default_from_default_instance, helptext=_docstrings.get_field_docstring(info.type, name), ) ) @@ -345,7 +335,6 @@ def dict_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=str(k) if not isinstance(k, enum.Enum) else k.name, type=type(v), default=v, - is_default_overridden=True, helptext=None, _call_argname=k, ) @@ -368,13 +357,11 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None: info.type, include_extras=True ).items(): default = field_defaults.get(name, MISSING_NONPROP) - is_default_from_default_instance = False if info.default not in MISSING_AND_MISSING_NONPROP and hasattr( info.default, name ): default = getattr(info.default, name) - is_default_from_default_instance = True elif info.default is MISSING: default = MISSING @@ -383,7 +370,6 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=name, type=typ, default=default, - is_default_overridden=is_default_from_default_instance, helptext=_docstrings.get_field_docstring(info.type, name), ) ) @@ -417,7 +403,6 @@ def sequence_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=str(i), type=cast(type, contained_type), default=default_i, - is_default_overridden=True, helptext="", ) ) @@ -461,7 +446,6 @@ def tuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None: name=str(i), type=child, default=default_i, - is_default_overridden=True, helptext="", ) ) @@ -537,17 +521,14 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None: info.type, pd1_field.name ) - default, is_default_from_default_instance = ( - _get_pydantic_v1_field_default( - pd1_field.name, pd1_field, info.default - ) + default = _get_pydantic_v1_field_default( + pd1_field.name, pd1_field, info.default ) field_list.append( StructFieldSpec( name=pd1_field.name, type=hints[pd1_field.name], default=default, - is_default_overridden=is_default_from_default_instance, helptext=helptext, ) ) @@ -558,9 +539,7 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None: if helptext is None: helptext = _docstrings.get_field_docstring(info.type, name) - default, is_default_from_default_instance = ( - _get_pydantic_v2_field_default(name, pd2_field, info.default) - ) + default = _get_pydantic_v2_field_default(name, pd2_field, info.default) field_list.append( StructFieldSpec( name=name, @@ -572,7 +551,6 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None: else pd2_field.annotation ), default=default, - is_default_overridden=is_default_from_default_instance, helptext=helptext, ) ) @@ -597,12 +575,12 @@ def _ensure_dataclass_instance_used_as_default_is_frozen( def _get_dataclass_field_default( field: dataclasses.Field, parent_default_instance: Any -) -> tuple[Any, bool]: +) -> Any: """Helper for getting the default instance for a dataclass field.""" # If the dataclass's parent is explicitly marked MISSING, mark this field as missing # as well. if parent_default_instance is MISSING: - return MISSING, False + return MISSING # Try grabbing default from parent instance. if ( @@ -611,7 +589,7 @@ def _get_dataclass_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, field.name): - return getattr(parent_default_instance, field.name), True + return getattr(parent_default_instance, field.name) # Try grabbing default from dataclass field. if ( @@ -623,7 +601,7 @@ def _get_dataclass_field_default( # _types_, not just instances. if type(default) is not type and dataclasses.is_dataclass(default): _ensure_dataclass_instance_used_as_default_is_frozen(field, default) - return default, False + return default # Populate default from `dataclasses.field(default_factory=...)`. if field.default_factory is not dataclasses.MISSING and not ( @@ -637,10 +615,10 @@ def _get_dataclass_field_default( # before this method is called. dataclasses.is_dataclass(field.type) and field.default_factory is field.type ): - return field.default_factory(), False + return field.default_factory() # Otherwise, no default. - return MISSING_NONPROP, False + return MISSING_NONPROP if TYPE_CHECKING: @@ -662,20 +640,20 @@ def _get_pydantic_v1_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, name): - return getattr(parent_default_instance, name), True + return getattr(parent_default_instance, name) if not field.required: - return field.get_default(), False + return field.get_default() # Otherwise, no default. - return MISSING_NONPROP, False + return MISSING_NONPROP def _get_pydantic_v2_field_default( name: str, field: pydantic.fields.FieldInfo, parent_default_instance: Any, -) -> tuple[Any, bool]: +) -> Any: """Helper for getting the default instance for a Pydantic field.""" # Try grabbing default from parent instance. @@ -685,10 +663,10 @@ def _get_pydantic_v2_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, name): - return getattr(parent_default_instance, name), True + return getattr(parent_default_instance, name) if not field.is_required(): - return field.get_default(call_default_factory=True), False + return field.get_default(call_default_factory=True) # Otherwise, no default. - return MISSING_NONPROP, False + return MISSING_NONPROP From 0e366401b9af113ffa1f5d06fee19b059ef11425 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 03:00:22 -0800 Subject: [PATCH 6/7] ruff --- tests/test_conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_conf.py b/tests/test_conf.py index d1f701c8..9e34915d 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -8,10 +8,10 @@ from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union import pytest -import tyro +from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, TypedDict -from helptext_utils import get_helptext_with_checks +import tyro def test_suppress_subcommand() -> None: From 20e48b541fb4503afa36eb720fa2da170bb0b1f8 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 7 Jan 2025 03:01:32 -0800 Subject: [PATCH 7/7] spelling --- src/tyro/constructors/_struct_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py index 503df115..a4663bfe 100644 --- a/src/tyro/constructors/_struct_spec.py +++ b/src/tyro/constructors/_struct_spec.py @@ -62,8 +62,8 @@ class StructFieldSpec: _call_argname: Any = None """Private: the name of the argument to pass to the callable. This is used for dictionary types.""" - is_default_overriden: None = None - """*Deprecated.*""" + is_default_overridden: None = None + """Deprecated. No longer used.""" @dataclasses.dataclass(frozen=True)