Skip to content

Commit 992909d

Browse files
authored
Fix custom constructor edge case for variable-length positional arguments (#258)
* Fix custom constructor edge case for variable-length positional arguments * Python 3.8
1 parent 4ad01de commit 992909d

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

src/tyro/_calling.py

+5
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def get_value_from_arg(
9797
consumed_keywords.add(name_maybe_prefixed)
9898
if not arg.lowered.is_fixed():
9999
value, value_found = get_value_from_arg(name_maybe_prefixed, arg)
100+
should_cast = False
100101

101102
if value in _fields.MISSING_AND_MISSING_NONPROP:
102103
value = arg.field.default
@@ -114,14 +115,18 @@ def get_value_from_arg(
114115
and arg.lowered.nargs in ("?", "*")
115116
):
116117
value = []
118+
should_cast = True
117119
elif value_found:
118120
# Value was found from the CLI, so we need to cast it with instance_from_str.
121+
should_cast = True
119122
any_arguments_provided = True
120123
if arg.lowered.nargs == "?":
121124
# Special case for optional positional arguments: this is the
122125
# only time that arguments don't come back as a list.
123126
value = [value]
124127

128+
# Attempt to cast the value to the correct type.
129+
if should_cast:
125130
try:
126131
assert arg.lowered.instance_from_str is not None
127132
value = arg.lowered.instance_from_str(value)

tests/test_custom_constructors.py

+34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Dict, List, Union
55

66
import numpy as np
7+
import pytest
78
from typing_extensions import Annotated, Literal, get_args
89

910
import tyro
@@ -102,3 +103,36 @@ def main(
102103
).dtype
103104
== np.float32
104105
)
106+
107+
108+
def make_list_of_strings_with_minimum_length(args: List[str]) -> List[str]:
109+
if len(args) == 0:
110+
raise ValueError("Expected at least one string")
111+
return args
112+
113+
114+
ListOfStringsWithMinimumLength = Annotated[
115+
List[str],
116+
tyro.constructors.PrimitiveConstructorSpec(
117+
nargs="*",
118+
metavar="STR [STR ...]",
119+
is_instance=lambda x: isinstance(x, list)
120+
and all(isinstance(i, str) for i in x),
121+
instance_from_str=make_list_of_strings_with_minimum_length,
122+
str_from_instance=lambda args: args,
123+
),
124+
]
125+
126+
127+
def test_min_length_custom_constructor() -> None:
128+
def main(
129+
field1: ListOfStringsWithMinimumLength, field2: int = 3
130+
) -> ListOfStringsWithMinimumLength:
131+
del field2
132+
return field1
133+
134+
with pytest.raises(SystemExit):
135+
tyro.cli(main, args=[])
136+
with pytest.raises(SystemExit):
137+
tyro.cli(main, args=["--field1"])
138+
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]

tests/test_py311_generated/test_custom_constructors_generated.py

+34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Annotated, Any, Dict, List, Literal, get_args
55

66
import numpy as np
7+
import pytest
78

89
import tyro
910

@@ -101,3 +102,36 @@ def main(
101102
).dtype
102103
== np.float32
103104
)
105+
106+
107+
def make_list_of_strings_with_minimum_length(args: List[str]) -> List[str]:
108+
if len(args) == 0:
109+
raise ValueError("Expected at least one string")
110+
return args
111+
112+
113+
ListOfStringsWithMinimumLength = Annotated[
114+
List[str],
115+
tyro.constructors.PrimitiveConstructorSpec(
116+
nargs="*",
117+
metavar="STR [STR ...]",
118+
is_instance=lambda x: isinstance(x, list)
119+
and all(isinstance(i, str) for i in x),
120+
instance_from_str=make_list_of_strings_with_minimum_length,
121+
str_from_instance=lambda args: args,
122+
),
123+
]
124+
125+
126+
def test_min_length_custom_constructor() -> None:
127+
def main(
128+
field1: ListOfStringsWithMinimumLength, field2: int = 3
129+
) -> ListOfStringsWithMinimumLength:
130+
del field2
131+
return field1
132+
133+
with pytest.raises(SystemExit):
134+
tyro.cli(main, args=[])
135+
with pytest.raises(SystemExit):
136+
tyro.cli(main, args=["--field1"])
137+
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]

0 commit comments

Comments
 (0)