Skip to content

Commit 4d18b8e

Browse files
authored
Fix attribute inheritance regression (introduced in 0.9.9) (#240)
* Fix attribute inheritance regression * Add test * Workaround for Python 3.8 __annotations__ behavior * Remove unnecessary filter * Additional test * ruff
1 parent 690e04e commit 4d18b8e

File tree

3 files changed

+194
-1
lines changed

3 files changed

+194
-1
lines changed

src/tyro/_resolver.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,13 @@ def get_hints_for_bound_method(cls) -> Dict[str, Any]:
673673
for x, t in _get_type_hints_backported_syntax(
674674
obj, include_extras=include_extras
675675
).items()
676-
if x in obj.__annotations__
676+
# Only include type hints that are explicitly defined in this class.
677+
#
678+
# Why `cls.__dict__.__annotations__` instead of `cls.__annotations__`?
679+
# Because in Python 3.8 and earlier, `cls.__annotations__`
680+
# recursively merges parent class annotations.
681+
# See this issue: https://github.com/python/cpython/issues/99535
682+
if x in obj.__dict__.get("__annotations__", {})
677683
}
678684

679685
# We need to recurse into base classes in order to correctly resolve superclass parameters.
@@ -689,6 +695,11 @@ def get_hints_for_bound_method(cls) -> Dict[str, Any]:
689695
{
690696
x: TypeParamResolver.concretize_type_params(t)
691697
for x, t in base_hints.items()
698+
# Include type hints that are not assigned earlier in the MRO.
699+
#
700+
# This needs to be recursive (include parents of parents),
701+
# so we shouldn't filter by local __annotations__.
702+
if x not in out
692703
}
693704
)
694705

tests/test_conf.py

+91
Original file line numberDiff line numberDiff line change
@@ -1794,3 +1794,94 @@ def main(
17941794
with pytest.raises(SystemExit):
17951795
# ConfigB has a required argument.
17961796
assert tyro.cli(main, args=["x:config-b"])
1797+
1798+
1799+
_dataset_map = {
1800+
"alpaca": "tatsu-lab/alpaca",
1801+
"alpaca_clean": "yahma/alpaca-cleaned",
1802+
"alpaca_gpt4": "vicgalle/alpaca-gpt4",
1803+
}
1804+
_inv_dataset_map = {value: key for key, value in _dataset_map.items()}
1805+
_datasets = list(_dataset_map.keys())
1806+
1807+
HFDataset = Annotated[
1808+
str,
1809+
tyro.constructors.PrimitiveConstructorSpec(
1810+
nargs=1,
1811+
metavar="{" + ",".join(_datasets) + "}",
1812+
instance_from_str=lambda args: _dataset_map[args[0]],
1813+
is_instance=lambda instance: isinstance(instance, str)
1814+
and instance in _inv_dataset_map,
1815+
str_from_instance=lambda instance: [_inv_dataset_map[instance]],
1816+
choices=tuple(_datasets),
1817+
),
1818+
tyro.conf.arg(
1819+
help_behavior_hint=lambda df: f"(default: {df}, run datasets.py for full options)"
1820+
),
1821+
]
1822+
1823+
1824+
def test_annotated_attribute_inheritance() -> None:
1825+
"""From @mirceamironenco.
1826+
1827+
https://github.com/brentyi/tyro/issues/239"""
1828+
1829+
@dataclasses.dataclass(frozen=True)
1830+
class TrainConfig:
1831+
dataset: str = "vicgalle/alpaca-gpt4"
1832+
1833+
@dataclasses.dataclass(frozen=True)
1834+
class CLITrainerConfig(TrainConfig):
1835+
dataset: HFDataset = "vicgalle/alpaca-gpt4"
1836+
1837+
assert "{alpaca,alpaca_clean,alpaca_gpt4}" in get_helptext_with_checks(
1838+
CLITrainerConfig
1839+
)
1840+
assert (
1841+
"default: alpaca_gpt4, run datasets.py for full options"
1842+
in get_helptext_with_checks(CLITrainerConfig)
1843+
)
1844+
1845+
1846+
@dataclasses.dataclass(frozen=True)
1847+
class OptimizerConfig:
1848+
lr: float = 1e-1
1849+
1850+
1851+
@dataclasses.dataclass(frozen=True)
1852+
class AdamConfig(OptimizerConfig):
1853+
adam_foo: float = 1.0
1854+
1855+
1856+
@dataclasses.dataclass(frozen=True)
1857+
class SGDConfig(OptimizerConfig):
1858+
sgd_foo: float = 1.0
1859+
1860+
1861+
@dataclasses.dataclass
1862+
class TrainConfig:
1863+
optimizer: OptimizerConfig = AdamConfig()
1864+
1865+
1866+
def _dummy_constructor() -> Type[OptimizerConfig]:
1867+
return Union[AdamConfig, SGDConfig] # type: ignore
1868+
1869+
1870+
CLIOptimizerConfig = Annotated[
1871+
OptimizerConfig,
1872+
tyro.conf.arg(constructor_factory=_dummy_constructor),
1873+
]
1874+
1875+
1876+
def test_attribute_inheritance_2() -> None:
1877+
"""From @mirceamironenco.
1878+
1879+
https://github.com/brentyi/tyro/issues/239"""
1880+
1881+
@dataclasses.dataclass
1882+
class CLITrainerConfig(TrainConfig):
1883+
optimizer: CLIOptimizerConfig = SGDConfig()
1884+
1885+
assert "[{optimizer:adam-config,optimizer:sgd-config}]" in get_helptext_with_checks(
1886+
CLITrainerConfig
1887+
)

tests/test_py311_generated/test_conf_generated.py

+91
Original file line numberDiff line numberDiff line change
@@ -1800,3 +1800,94 @@ def main(
18001800
with pytest.raises(SystemExit):
18011801
# ConfigB has a required argument.
18021802
assert tyro.cli(main, args=["x:config-b"])
1803+
1804+
1805+
_dataset_map = {
1806+
"alpaca": "tatsu-lab/alpaca",
1807+
"alpaca_clean": "yahma/alpaca-cleaned",
1808+
"alpaca_gpt4": "vicgalle/alpaca-gpt4",
1809+
}
1810+
_inv_dataset_map = {value: key for key, value in _dataset_map.items()}
1811+
_datasets = list(_dataset_map.keys())
1812+
1813+
HFDataset = Annotated[
1814+
str,
1815+
tyro.constructors.PrimitiveConstructorSpec(
1816+
nargs=1,
1817+
metavar="{" + ",".join(_datasets) + "}",
1818+
instance_from_str=lambda args: _dataset_map[args[0]],
1819+
is_instance=lambda instance: isinstance(instance, str)
1820+
and instance in _inv_dataset_map,
1821+
str_from_instance=lambda instance: [_inv_dataset_map[instance]],
1822+
choices=tuple(_datasets),
1823+
),
1824+
tyro.conf.arg(
1825+
help_behavior_hint=lambda df: f"(default: {df}, run datasets.py for full options)"
1826+
),
1827+
]
1828+
1829+
1830+
def test_annotated_attribute_inheritance() -> None:
1831+
"""From @mirceamironenco.
1832+
1833+
https://github.com/brentyi/tyro/issues/239"""
1834+
1835+
@dataclasses.dataclass(frozen=True)
1836+
class TrainConfig:
1837+
dataset: str = "vicgalle/alpaca-gpt4"
1838+
1839+
@dataclasses.dataclass(frozen=True)
1840+
class CLITrainerConfig(TrainConfig):
1841+
dataset: HFDataset = "vicgalle/alpaca-gpt4"
1842+
1843+
assert "{alpaca,alpaca_clean,alpaca_gpt4}" in get_helptext_with_checks(
1844+
CLITrainerConfig
1845+
)
1846+
assert (
1847+
"default: alpaca_gpt4, run datasets.py for full options"
1848+
in get_helptext_with_checks(CLITrainerConfig)
1849+
)
1850+
1851+
1852+
@dataclasses.dataclass(frozen=True)
1853+
class OptimizerConfig:
1854+
lr: float = 1e-1
1855+
1856+
1857+
@dataclasses.dataclass(frozen=True)
1858+
class AdamConfig(OptimizerConfig):
1859+
adam_foo: float = 1.0
1860+
1861+
1862+
@dataclasses.dataclass(frozen=True)
1863+
class SGDConfig(OptimizerConfig):
1864+
sgd_foo: float = 1.0
1865+
1866+
1867+
@dataclasses.dataclass
1868+
class TrainConfig:
1869+
optimizer: OptimizerConfig = AdamConfig()
1870+
1871+
1872+
def _dummy_constructor() -> Type[OptimizerConfig]:
1873+
return AdamConfig | SGDConfig # type: ignore
1874+
1875+
1876+
CLIOptimizerConfig = Annotated[
1877+
OptimizerConfig,
1878+
tyro.conf.arg(constructor_factory=_dummy_constructor),
1879+
]
1880+
1881+
1882+
def test_attribute_inheritance_2() -> None:
1883+
"""From @mirceamironenco.
1884+
1885+
https://github.com/brentyi/tyro/issues/239"""
1886+
1887+
@dataclasses.dataclass
1888+
class CLITrainerConfig(TrainConfig):
1889+
optimizer: CLIOptimizerConfig = SGDConfig()
1890+
1891+
assert "[{optimizer:adam-config,optimizer:sgd-config}]" in get_helptext_with_checks(
1892+
CLITrainerConfig
1893+
)

0 commit comments

Comments
 (0)