diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ef59b6d3..873f1044 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,8 @@ Added (`#698 `__). - Option to enable validation of default values (`#711 `__). +- New method to get a list of link targets (`#715 + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_link_arguments.py b/jsonargparse/_link_arguments.py index 2c08b0ba..7c37efc0 100644 --- a/jsonargparse/_link_arguments.py +++ b/jsonargparse/_link_arguments.py @@ -515,3 +515,7 @@ def link_arguments( ValueError: If an invalid parameter is given. """ ActionLink(self, source, target, compute_fn, apply_on) + + def get_link_targets(self, apply_on: str) -> List[str]: + """Get all keys that are targets of links.""" + return [a.target[0] for a in get_link_actions(self, apply_on)] # type: ignore[arg-type] diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index c4e076eb..da786e37 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -35,6 +35,7 @@ def test_on_parse_shallow_print_config(parser): parser.link_arguments("a", "b") out = get_parse_args_stdout(parser, ["--print_config"]) assert json_or_yaml_load(out) == {"a": 0} + assert parser.get_link_targets("parse") == ["b"] def test_on_parse_subcommand_failing_compute_fn(parser, subparser, subtests): @@ -107,6 +108,9 @@ def test_on_parse_compute_fn_subclass_spec(parser, subtests): assert cfg.cal1.init_args.firstweekday == 2 assert cfg.cal2.init_args.firstweekday == 3 + with subtests.test("get_link_targets"): + assert parser.get_link_targets("parse") == ["cal2.init_args.firstweekday"] + with subtests.test("invalid init parameter"): parser.set_defaults(cal1=None) with pytest.raises(ArgumentError) as ctx: @@ -164,6 +168,9 @@ def test_on_parse_add_class_arguments(subtests): dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False)) assert dump == {"a": {"v1": 11, "v2": 7}, "b": {"v3": 2, "v1": 7, "v2": 18}} + with subtests.test("get_link_targets"): + assert parser.get_link_targets("parse") == ["b.v1", "b.v2"] + with subtests.test("argument error"): pytest.raises(ArgumentError, lambda: parser.parse_args(["--b.v1=5"])) @@ -209,6 +216,9 @@ def add(v1, v2): dump = json_or_yaml_load(parser.dump(cfg, skip_link_targets=False)) assert dump["s2"] == {"class_path": f"{__name__}.ClassS2", "init_args": {"v3": 4}} + with subtests.test("get_link_targets"): + assert parser.get_link_targets("parse") == ["s2.init_args.v3"] + with subtests.test("compute_fn invalid result type"): s1_value["init_args"] = {"v1": "a", "v2": "b"} with pytest.raises(ArgumentError): @@ -237,6 +247,7 @@ def test_on_parse_subclass_target_in_union(parser): cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=Logger"]) assert cfg.trainer.save_dir == "logs" assert cfg.trainer.logger.init_args == Namespace(save_dir="logs") + assert parser.get_link_targets("parse") == ["trainer.logger.init_args.save_dir"] class TrainerLoggerList: @@ -530,6 +541,7 @@ def test_on_instantiate_link_instance_attribute(): init = parser.instantiate_classes(cfg) assert init.x.x1 == 6 assert init.y.y3 == '"8"' + assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "y.y3"] def test_on_instantiate_link_all_group_arguments(): @@ -542,6 +554,7 @@ def test_on_instantiate_link_all_group_arguments(): assert init["x"].x2 == 7 help_str = get_parser_help(parser) assert "Group 'x': All arguments are derived from links" in help_str + assert parser.get_link_targets("instantiate") == ["x.x1", "y.y1", "x.x2"] class FailingComputeFn1: @@ -600,6 +613,7 @@ def test_on_parse_and_instantiate_link_entire_instance(parser): assert isinstance(init.n, Namespace) assert isinstance(init.c, Calendar) assert init.c is init.n.calendar + assert parser.get_link_targets("instantiate") == ["n.calendar"] class ClassM: @@ -645,6 +659,7 @@ def test_on_instantiate_link_object_in_attribute(parser): init = parser.instantiate_classes(cfg) assert init.p.calendar is init.q.calendar assert init.q.calendar.firstweekday == 2 + assert parser.get_link_targets("instantiate") == ["q.calendar"] def test_on_parse_link_entire_subclass(parser): @@ -656,6 +671,7 @@ def test_on_parse_link_entire_subclass(parser): cfg = parser.parse_args([f"--n.calendar={json.dumps(cal)}", "--q.q2=7"]) assert cfg.n.calendar == cfg.q.calendar assert cfg.q.q2 == 7 + assert parser.get_link_targets("parse") == ["q.calendar"] class ClassV: @@ -789,6 +805,7 @@ def test_on_instantiate_within_deep_subclass(parser, caplog): assert isinstance(init.model.decoder, WithinDeepTarget) assert init.model.decoder.input_channels == 16 assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text + assert parser.get_link_targets("instantiate") == ["model.init_args.decoder.init_args.input_channels"] class WithinDeeperSystem: @@ -824,6 +841,9 @@ def test_on_instantiate_within_deeper_subclass(parser, caplog): assert isinstance(init.system.model.decoder, WithinDeepTarget) assert init.system.model.decoder.input_channels == 16 assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text + assert parser.get_link_targets("instantiate") == [ + "system.init_args.model.init_args.decoder.init_args.input_channels" + ] class SourceA: