Skip to content

Commit 3c6034e

Browse files
authored
Merge pull request #38 from csdms/mcflugen/add-parameters-to-stage
Add parameters keyword to stage function
2 parents 6925e55 + 1b054f1 commit 3c6034e

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

src/model_metadata/api.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterable
4+
from typing import Any
5+
6+
from model_metadata.errors import UnknownKeyError
37
from model_metadata.model_setup import FileSystemLoader
48
from model_metadata.model_setup import OldFileSystemLoader
59
from model_metadata.modelmetadata import ModelMetadata
@@ -52,7 +56,10 @@ def query(model: str, var: str) -> ModelMetadata:
5256

5357

5458
def stage(
55-
model: str, dest: str = ".", old_style_templates: bool = False
59+
model: str,
60+
dest: str = ".",
61+
old_style_templates: bool = False,
62+
parameters: dict[str, Any] | None = None,
5663
) -> tuple[str, ...]:
5764
"""Stage a model by setting up its input files.
5865
@@ -62,18 +69,39 @@ def stage(
6269
The model is interpreted either as a path to a folder that
6370
contains metadata, the name of a model component, or a
6471
model object.
65-
dest : str
72+
dest : str, optional
6673
Path to a folder within which to stage the model.
74+
parameters : dict[str, Any], optional
75+
A dictionary of parameters that overrides the default
76+
values.
6777
"""
78+
parameters = {} if parameters is None else parameters
79+
6880
defaults = {}
6981
mmd = ModelMetadata.find(model)
7082
meta = ModelMetadata(mmd)
7183
for param, item in meta.parameters.items():
7284
defaults[param] = item["value"]["default"]
7385

86+
try:
87+
_check_for_unknown_keys(defaults.keys(), parameters.keys())
88+
except UnknownKeyError as e:
89+
e.add_note(
90+
f"valid parameter{'s' if len(defaults) > 1 else ''}:"
91+
f" {', '.join(sorted(repr(k) for k in defaults))}"
92+
)
93+
raise
94+
95+
defaults = {**defaults, **parameters}
96+
7497
if old_style_templates:
7598
manifest = OldFileSystemLoader(mmd).stage_all(dest, **defaults)
7699
else:
77100
manifest = FileSystemLoader(mmd).stage_all(dest, **defaults)
78101

79102
return manifest
103+
104+
105+
def _check_for_unknown_keys(allowed: Iterable[str], user: Iterable[str]) -> None:
106+
if unknown_keys := (set(user) - set(allowed)):
107+
raise UnknownKeyError(unknown_keys)

src/model_metadata/errors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#! /usr/bin/env python
22
from __future__ import annotations
33

4+
from collections.abc import Iterable
5+
46

57
class ModelMetadataError(Exception):
68
"""Base error for model_metadata package."""
@@ -55,3 +57,16 @@ def __init__(self, entry_point: str, msg: str | None = None):
5557

5658
def __str__(self) -> str:
5759
return self._entry_point + f": {self._msg}" if self._msg else ""
60+
61+
62+
class UnknownKeyError(ModelMetadataError):
63+
"""Raise if a dictionary contains one or more unrecognized keys."""
64+
65+
def __init__(self, unknown: Iterable[str]) -> None:
66+
super().__init__(*sorted(set(unknown)))
67+
68+
def __str__(self) -> str:
69+
return (
70+
f"unknown key{'s' if len(self.args) > 1 else ''}:"
71+
f" {', '.join(repr(key) for key in self.args)}"
72+
)

tests/api_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import itertools
34
import os
45
import pathlib
56

@@ -10,6 +11,7 @@
1011
from model_metadata.errors import MetadataNotFoundError
1112
from model_metadata.errors import MissingSectionError
1213
from model_metadata.errors import MissingValueError
14+
from model_metadata.errors import UnknownKeyError
1315

1416

1517
class Model:
@@ -85,6 +87,26 @@ def test_stage_is_filled_out(tmpdir, shared_datadir):
8587
assert "}}" not in contents
8688

8789

90+
def test_stage_with_parameters(tmpdir, shared_datadir):
91+
with tmpdir.as_cwd():
92+
stage(str(shared_datadir), parameters={"run_duration": 999})
93+
94+
file_contains_runtime = False
95+
with open("child.in") as fp:
96+
for this, next_ in itertools.pairwise(fp):
97+
file_contains_runtime = this.startswith("RUNTIME")
98+
if file_contains_runtime:
99+
assert next_.startswith("999")
100+
break
101+
assert file_contains_runtime
102+
103+
104+
@pytest.mark.parametrize("params", ({"foo": "bar"}, {"foo": "bar", "baz": "foobar"}))
105+
def test_stage_with_unknown_parameters(tmpdir, shared_datadir, params):
106+
with tmpdir.as_cwd(), pytest.raises(UnknownKeyError):
107+
stage(str(shared_datadir), parameters=params)
108+
109+
88110
def test_query_with_bad_section(shared_datadir):
89111
with pytest.raises(MissingSectionError):
90112
query(str(shared_datadir), "not-a-section.version")

0 commit comments

Comments
 (0)