Skip to content

New method to get link targets #715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Added
(`#698 <https://github.com/omni-us/jsonargparse/pull/698>`__).
- Option to enable validation of default values (`#711
<https://github.com/omni-us/jsonargparse/pull/711>`__).
- New method to get a list of link targets (`#715
<https://github.com/omni-us/jsonargparse/pull/715>`__).

Changed
^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,7 @@ def link_arguments(
ValueError: If an invalid parameter is given.
"""
ActionLink(self, source, target, compute_fn, apply_on)

def get_link_targets(self, apply_on: str) -> List[str]:
"""Get all keys that are targets of links."""
return [a.target[0] for a in get_link_actions(self, apply_on)] # type: ignore[arg-type]
20 changes: 20 additions & 0 deletions jsonargparse_tests/test_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_on_parse_shallow_print_config(parser):
parser.link_arguments("a", "b")
out = get_parse_args_stdout(parser, ["--print_config"])
assert json_or_yaml_load(out) == {"a": 0}
assert parser.get_link_targets("parse") == ["b"]


def test_on_parse_subcommand_failing_compute_fn(parser, subparser, subtests):
Expand Down Expand Up @@ -107,6 +108,9 @@ def test_on_parse_compute_fn_subclass_spec(parser, subtests):
assert cfg.cal1.init_args.firstweekday == 2
assert cfg.cal2.init_args.firstweekday == 3

with subtests.test("get_link_targets"):
assert parser.get_link_targets("parse") == ["cal2.init_args.firstweekday"]

with subtests.test("invalid init parameter"):
parser.set_defaults(cal1=None)
with pytest.raises(ArgumentError) as ctx:
Expand Down Expand Up @@ -164,6 +168,9 @@ def test_on_parse_add_class_arguments(subtests):
dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False))
assert dump == {"a": {"v1": 11, "v2": 7}, "b": {"v3": 2, "v1": 7, "v2": 18}}

with subtests.test("get_link_targets"):
assert parser.get_link_targets("parse") == ["b.v1", "b.v2"]

with subtests.test("argument error"):
pytest.raises(ArgumentError, lambda: parser.parse_args(["--b.v1=5"]))

Expand Down Expand Up @@ -209,6 +216,9 @@ def add(v1, v2):
dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False))
assert dump["s2"] == {"class_path": f"{__name__}.ClassS2", "init_args": {"v3": 4}}

with subtests.test("get_link_targets"):
assert parser.get_link_targets("parse") == ["s2.init_args.v3"]

with subtests.test("compute_fn invalid result type"):
s1_value["init_args"] = {"v1": "a", "v2": "b"}
with pytest.raises(ArgumentError):
Expand Down Expand Up @@ -237,6 +247,7 @@ def test_on_parse_subclass_target_in_union(parser):
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=Logger"])
assert cfg.trainer.save_dir == "logs"
assert cfg.trainer.logger.init_args == Namespace(save_dir="logs")
assert parser.get_link_targets("parse") == ["trainer.logger.init_args.save_dir"]


class TrainerLoggerList:
Expand Down Expand Up @@ -530,6 +541,7 @@ def test_on_instantiate_link_instance_attribute():
init = parser.instantiate_classes(cfg)
assert init.x.x1 == 6
assert init.y.y3 == '"8"'
assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "y.y3"]


def test_on_instantiate_link_all_group_arguments():
Expand All @@ -542,6 +554,7 @@ def test_on_instantiate_link_all_group_arguments():
assert init["x"].x2 == 7
help_str = get_parser_help(parser)
assert "Group 'x': All arguments are derived from links" in help_str
assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "x.x2"]


class FailingComputeFn1:
Expand Down Expand Up @@ -600,6 +613,7 @@ def test_on_parse_and_instantiate_link_entire_instance(parser):
assert isinstance(init.n, Namespace)
assert isinstance(init.c, Calendar)
assert init.c is init.n.calendar
assert parser.get_link_targets("instantiate") == ["n.calendar"]


class ClassM:
Expand Down Expand Up @@ -645,6 +659,7 @@ def test_on_instantiate_link_object_in_attribute(parser):
init = parser.instantiate_classes(cfg)
assert init.p.calendar is init.q.calendar
assert init.q.calendar.firstweekday == 2
assert parser.get_link_targets("instantiate") == ["q.calendar"]


def test_on_parse_link_entire_subclass(parser):
Expand All @@ -656,6 +671,7 @@ def test_on_parse_link_entire_subclass(parser):
cfg = parser.parse_args([f"--n.calendar={json.dumps(cal)}", "--q.q2=7"])
assert cfg.n.calendar == cfg.q.calendar
assert cfg.q.q2 == 7
assert parser.get_link_targets("parse") == ["q.calendar"]


class ClassV:
Expand Down Expand Up @@ -789,6 +805,7 @@ def test_on_instantiate_within_deep_subclass(parser, caplog):
assert isinstance(init.model.decoder, WithinDeepTarget)
assert init.model.decoder.input_channels == 16
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text
assert parser.get_link_targets("instantiate") == ["model.init_args.decoder.init_args.input_channels"]


class WithinDeeperSystem:
Expand Down Expand Up @@ -824,6 +841,9 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog):
assert isinstance(init.system.model.decoder, WithinDeepTarget)
assert init.system.model.decoder.input_channels == 16
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text
assert parser.get_link_targets("instantiate") == [
"system.init_args.model.init_args.decoder.init_args.input_channels"
]


class SourceA:
Expand Down